@@ -77,14 +77,18 @@ def get_valid_filename(s):
7777
7878
7979class TestRunner :
def __init__(self, test_name, cmd, env, test_output, timeout=None):
    """Store the test configuration and reset all per-run state.

    Args:
        test_name: Name of the test; used in the log message emitted by
            ``run`` when the test times out.
        cmd: Command handed to ``subprocess.Popen`` by ``run``.
        env: Environment handed to ``subprocess.Popen`` by ``run``.
        test_output: Output sink flushed by ``read_from_child``
            (presumably a file-like object -- confirm against the caller).
        timeout: Delay passed to ``asyncio.sleep`` by ``check_timeout``
            before the child is terminated. ``None`` (the default) means
            no timeout task is created.
    """
    self.test_name = test_name
    self.cmd = cmd
    self.env = env
    self.test_output = test_output
    self.timeout = timeout
    # Child process handle; assigned by run() once the process is started.
    self.p = None
    self.pdb_mode = False
    # Master side of the PTY; assigned by run() via pty.openpty().
    self.master_fd = None
    # asyncio tasks are created lazily inside the running event loop, so
    # they all start out unset here.
    self.write_task = None
    self.read_task = None
    self.timeout_task = None
8892
8993 def run (self ):
9094 """
@@ -97,7 +101,7 @@ def run(self):
97101 self .master_fd , slave_fd = pty .openpty ()
98102
99103 # Start child connected to the PTY
100- p = subprocess .Popen (
104+ self . p = subprocess .Popen (
101105 self .cmd ,
102106 env = self .env ,
103107 stdin = slave_fd ,
@@ -106,11 +110,28 @@ def run(self):
106110 )
107111 os .close (slave_fd )
108112
109- self .loop .run_until_complete (self .handle_inout ())
110- return p .wait ()
113+ try :
114+ asyncio .run (self .handle_inout ())
115+ except subprocess .TimeoutExpired :
116+ LOGGER .error (f"Test { self .test_name } timed out" )
117+ try :
118+ return self .p .wait (timeout = 30 )
119+ except subprocess .TimeoutExpired :
120+ # If SIGTERM is intercepted, do a hard kill
121+ self .p .kill ()
122+ return self .p .wait ()
111123
async def handle_inout(self):
    """Run the child-output reader and, if configured, the timeout watchdog.

    A reader task is always created from ``read_from_child``. When
    ``self.timeout`` is not ``None`` a second task is created from
    ``check_timeout``. Both are awaited together.

    Raises:
        subprocess.TimeoutExpired: propagated from ``check_timeout`` when
            the watchdog fires; ``run`` catches it.
    """
    self.read_task = asyncio.create_task(self.read_from_child())
    tasks = [self.read_task]
    if self.timeout is not None:
        self.timeout_task = asyncio.create_task(self.check_timeout())
        tasks.append(self.timeout_task)
    try:
        await asyncio.gather(*tasks)
    except asyncio.CancelledError:
        # read_from_child cancels the watchdog once the child's output is
        # exhausted; gather() reports that as CancelledError, which is the
        # normal completion path rather than a failure.
        pass
114135
115136 def output_line (self , line ):
116137 if self .pdb_mode :
@@ -143,15 +164,15 @@ async def read_from_child(self):
143164 buffer = b""
144165 while True :
145166 try :
146- data = await self . loop . run_in_executor ( None , os .read , self .master_fd , 1024 )
167+ data = await asyncio . to_thread ( os .read , self .master_fd , 1024 )
147168 except OSError :
148169 break
149170 if not data :
150171 break
151172 buffer += data
152173 buffer = self .process_buffer (buffer )
153174 if self .pdb_mode and self .write_task is None :
154- self .write_task = self . loop .create_task (self .write_to_child ())
175+ self .write_task = asyncio .create_task (self .write_to_child ())
155176 buffer = self .process_buffer (buffer , force_flush = True )
156177 self .test_output .flush ()
157178
@@ -164,6 +185,10 @@ async def read_from_child(self):
164185 self .write_task = None
165186 self .pdb_mode = False
166187
188+ if self .timeout_task is not None :
189+ self .timeout_task .cancel ()
190+ self .timeout_task = None
191+
167192 # Writer: forward our stdin to child tty
168193 async def write_to_child (self ):
169194 while True :
@@ -172,6 +197,16 @@ async def write_to_child(self):
172197 break
173198 os .write (self .master_fd , data )
174199
200+ # Kill the child process if the timeout is reached
async def check_timeout(self):
    """Terminate the child process if it is still running after the timeout.

    Sleeps for ``self.timeout``, then:

    * returns without doing anything if the runner is in pdb mode, so an
      interactive debugging session is never killed;
    * otherwise, if the child has not exited yet, calls ``terminate()`` on
      it and raises.

    Raises:
        subprocess.TimeoutExpired: when the child was still running at the
            deadline and has been asked to terminate.
    """
    await asyncio.sleep(self.timeout)
    if self.pdb_mode:
        # We don't want to kill the process if it's in pdb mode
        return
    if self.p.poll() is None:
        self.p.terminate()
        raise subprocess.TimeoutExpired(self.cmd, self.timeout)
209+
175210
176211def run_individual_python_test (target_dir , test_name , pyspark_python , keep_test_output ):
177212 """
@@ -248,29 +283,15 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_
248283 start_time = time .time ()
249284
250285 retcode = None
251- proc = None
252286 try :
253- if timeout :
254- proc = subprocess .Popen (cmd , stderr = per_test_output , stdout = per_test_output , env = env )
255- retcode = proc .wait (timeout = timeout )
256- else :
257- retcode = TestRunner (cmd , env , per_test_output ).run ()
287+ retcode = TestRunner (test_name , cmd , env , per_test_output , timeout ).run ()
258288 if not keep_test_output :
259289 # There exists a race condition in Python and it causes flakiness in MacOS
260290 # https://github.com/python/cpython/issues/73885
261291 if platform .system () == "Darwin" :
262292 os .system ("rm -rf " + tmp_dir )
263293 else :
264294 shutil .rmtree (tmp_dir , ignore_errors = True )
265- except subprocess .TimeoutExpired :
266- if timeout and proc :
267- LOGGER .exception (
268- "Got TimeoutExpired while running %s with %s" , test_name , pyspark_python
269- )
270- proc .terminate ()
271- proc .communicate (timeout = 60 )
272- else :
273- raise
274295 except BaseException :
275296 LOGGER .exception ("Got exception while running %s with %s" , test_name , pyspark_python )
276297 # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
0 commit comments