15
15
import sys
16
16
import threading
17
17
import time
18
+ from concurrent .futures import ThreadPoolExecutor
18
19
from functools import partial
19
- from signal import SIGINT
20
20
from signal import SIGTERM
21
21
from subprocess import check_output
22
22
from subprocess import PIPE
@@ -85,6 +85,12 @@ class UnknownStatus(LauncherError):
85
85
class BaseLauncher (LoggingConfigurable ):
86
86
"""An abstraction for starting, stopping and signaling a process."""
87
87
88
# Configurable (config=True) number of seconds; defaults to 60.
# The help text below is the user-facing description of its purpose.
stop_timeout = Integer(
    60,
    config=True,
    help="The number of seconds to wait for a process to exit before raising a TimeoutError in stop",
)
93
+
88
94
# In all of the launchers, the work_dir is where child processes will be
89
95
# run. This will usually be the profile_dir, but may not be. any work_dir
90
96
# passed into the __init__ method will override the config value.
@@ -249,6 +255,10 @@ def signal(self, sig):
249
255
"""
250
256
raise NotImplementedError ('signal must be implemented in a subclass' )
251
257
258
def join(self, timeout=None):
    """Wait for the process to finish.

    This base implementation is a placeholder: it always raises, so
    concrete launchers have to provide their own ``join``.

    Parameters
    ----------
    timeout:
        Accepted for interface compatibility; unused here.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError('join must be implemented in a subclass')
261
+
252
262
output_limit = Integer (
253
263
100 ,
254
264
config = True ,
@@ -376,6 +386,12 @@ def _default_output_file(self):
376
386
os .makedirs (log_dir , exist_ok = True )
377
387
return os .path .join (log_dir , f'{ self .identifier } .log' )
378
388
389
# Configurable (config=True) grace period in seconds; defaults to 5.
# The help text below is the user-facing description of its purpose.
stop_seconds_until_kill = Integer(
    5,
    config=True,
    help="""The number of seconds to wait for a process to exit after sending SIGTERM before sending SIGKILL""",
)
394
+
379
395
stdout = None
380
396
stderr = None
381
397
process = None
@@ -446,6 +462,18 @@ def start(self):
446
462
if self .log .level <= logging .DEBUG :
447
463
self ._start_streaming ()
448
464
465
async def join(self, timeout=None):
    """Wait for the process to exit.

    The blocking ``self.process.wait(timeout)`` call runs in a
    single-worker thread pool and is awaited through
    ``asyncio.wrap_future``, so the event loop is not blocked while
    waiting.

    Parameters
    ----------
    timeout:
        Passed through as the positional argument of
        ``self.process.wait``; ``None`` by default.

    Raises
    ------
    TimeoutError
        If ``self.process.wait`` raises ``psutil.TimeoutExpired``.
    """
    wait_for_exit = partial(self.process.wait, timeout)
    # NOTE: leaving the ``with`` block shuts the pool down and waits for
    # the worker thread, so the pool never outlives this call.
    with ThreadPoolExecutor(1) as pool:
        try:
            await asyncio.wrap_future(pool.submit(wait_for_exit))
        except psutil.TimeoutExpired as exc:
            # Translate to the builtin exception type, keeping the
            # original error attached as the explicit cause.
            raise TimeoutError(
                f"Process {self.pid} did not complete in {timeout} seconds."
            ) from exc
476
+
449
477
def _stream_file (self , path ):
450
478
"""Stream one file"""
451
479
with open (path , 'r' ) as f :
@@ -460,7 +488,7 @@ def _stream_file(self, path):
460
488
time .sleep (0.1 )
461
489
462
490
def _start_streaming (self ):
463
- t = threading .Thread (
491
+ self . _stream_thread = t = threading .Thread (
464
492
target = partial (self ._stream_file , self .output_file ),
465
493
name = f"Stream Output { self .identifier } " ,
466
494
daemon = True ,
@@ -483,33 +511,46 @@ def get_output(self, remove=False):
483
511
484
512
if remove and os .path .isfile (self .output_file ):
485
513
self .log .debug (f"Removing { self .output_file } " )
486
- os .remove (self .output_file )
514
+ try :
515
+ os .remove (self .output_file )
516
+ except Exception as e :
517
+ # don't crash on failure to remove a file,
518
+ # e.g. due to another process having it open
519
+ self .log .error (f"Failed to remove { self .output_file } : { e } " )
487
520
488
521
return self ._output
489
522
490
- def stop (self ):
491
- return self .interrupt_then_kill ()
523
async def stop(self):
    """Stop the process: send TERM, wait, then KILL if it is still alive.

    Steps:

    1. Send ``SIGTERM``; a failure to signal is logged at debug level
       and otherwise ignored.
    2. Wait up to ``stop_seconds_until_kill`` seconds via ``join``.
       If the process exits in time, return.
    3. Otherwise log a warning, send ``SIGKILL`` (failures are again
       only logged), and wait up to ``stop_timeout`` seconds.

    Raises
    ------
    TimeoutError
        Propagated from the final ``join`` if it raises one.
    """
    try:
        self.signal(SIGTERM)
    except Exception as e:
        # best-effort: continue to the wait below even if signaling failed
        self.log.debug(f"TERM failed: {e!r}")

    try:
        await self.join(timeout=self.stop_seconds_until_kill)
    except TimeoutError:
        self.log.warning(
            f"Process {self.pid} did not exit in {self.stop_seconds_until_kill} seconds after TERM"
        )
    else:
        # exited within the grace period; nothing more to do
        return

    try:
        self.signal(SIGKILL)
    except Exception as e:
        self.log.debug(f"KILL failed: {e!r}")

    # not wrapped in try/except: a timeout here is reported to the caller
    await self.join(timeout=self.stop_timeout)
492
544
493
545
def signal(self, sig):
    """Send signal ``sig`` to the process, if it is running.

    Does nothing unless ``self.state == 'running'``.

    When ``WINDOWS`` is true and ``sig`` equals ``SIGKILL``, the
    process is terminated with ``taskkill /t /F`` instead of
    ``send_signal``. Any other combination is delivered with
    ``self.process.send_signal(sig)``.
    """
    if self.state != 'running':
        return

    if WINDOWS and sig == SIGKILL:
        # use Windows tree-kill for better child cleanup
        cmd = ['taskkill', '/pid', str(self.process.pid), '/t', '/F']
        check_output(cmd)
    else:
        self.process.send_signal(sig)
501
553
502
- def interrupt_then_kill (self , delay = 2.0 ):
503
- """Send TERM, wait a delay and then send KILL."""
504
- try :
505
- self .signal (SIGTERM )
506
- except Exception as e :
507
- self .log .debug (f"interrupt failed: { e !r} " )
508
- pass
509
- self .killer = asyncio .get_event_loop ().call_later (
510
- delay , lambda : self .signal (SIGKILL )
511
- )
512
-
513
554
# callbacks, etc:
514
555
515
556
def handle_stdout (self , fd , events ):
@@ -635,21 +676,18 @@ def find_args(self):
635
676
return ['engine set' ]
636
677
637
678
def signal(self, sig):
    """Forward signal ``sig`` to every launcher in ``self.launchers``.

    The values are copied into a list before iterating, so the loop
    is unaffected if ``self.launchers`` changes while signals are
    being delivered. Returns ``None``.
    """
    for el in list(self.launchers.values()):
        el.signal(sig)
643
681
644
- def interrupt_then_kill (self , delay = 1.0 ):
645
- dlist = []
646
- for el in itervalues (self .launchers ):
647
- d = el .interrupt_then_kill ( delay )
648
- dlist . append ( d )
649
- return dlist
682
async def stop(self):
    """Stop every launcher in ``self.launchers`` and wait for them.

    Each launcher's ``stop()`` is called in turn. Results that are
    awaitable are scheduled with ``asyncio.ensure_future`` and then
    awaited together. Non-awaitable results are treated as already
    finished.
    """
    pending = []
    # iterate over a snapshot so the loop is unaffected if
    # ``self.launchers`` changes while launchers are being stopped
    for el in list(self.launchers.values()):
        result = el.stop()
        if inspect.isawaitable(result):
            pending.append(asyncio.ensure_future(result))

    if pending:
        await asyncio.gather(*pending)
653
691
654
692
def _notice_engine_stopped (self , data ):
655
693
identifier = data ['identifier' ]
@@ -1144,6 +1182,12 @@ def wait_one(self, timeout):
1144
1182
raise TimeoutError ("still running" )
1145
1183
return int (values .get ("exit_code" , - 1 ))
1146
1184
1185
async def join(self, timeout=None):
    """Wait for the process to finish without blocking the event loop.

    Runs the blocking ``self.wait_one(timeout=timeout)`` in a
    single-worker thread pool and awaits its completion. Any
    exception raised by ``wait_one`` propagates to the caller.
    """
    blocking_wait = partial(self.wait_one, timeout=timeout)
    with ThreadPoolExecutor(1) as pool:
        await asyncio.wrap_future(pool.submit(blocking_wait))
1190
+
1147
1191
def signal (self , sig ):
1148
1192
if self .state == 'running' :
1149
1193
check_output (
0 commit comments