11import abc
2+ import contextlib
23import multiprocessing as mp
34import os
45import queue
6+ import signal
57import sys
68import threading
79from multiprocessing .queues import JoinableQueue
8- from typing import Any , Callable , Union
10+ from typing import Any , Callable , Set , Union
911
1012from .logging import multiprocessing_breakpoint
1113
1214mp .set_start_method ("fork" )
1315
1416
1517class PoolBase (abc .ABC ):
18+ def __init__ (self ):
19+ with pools_lock :
20+ pools .add (self )
21+
1622 @abc .abstractmethod
1723 def submit (self , args ):
1824 pass
@@ -24,15 +30,20 @@ def process_until_done(self):
2430 def start (self ):
2531 pass
2632
27- def close (self ):
28- pass
33+ def close (self , * , immediate = False ): # noqa: ARG002
34+ with pools_lock :
35+ pools .remove (self )
2936
3037 def __enter__ (self ):
3138 self .start ()
3239 return self
3340
34- def __exit__ (self , * args ):
35- self .close ()
41+ def __exit__ (self , exc_type , _exc_value , _tb ):
42+ self .close (immediate = exc_type is not None )
43+
44+
45+ pools_lock = threading .Lock ()
46+ pools : Set [PoolBase ] = set ()
3647
3748
3849class Queue (JoinableQueue ):
@@ -53,9 +64,15 @@ class _Sentinel:
5364
5465
5566def _worker_process (handler , input_ , output ):
56- # Creates a new process group, making sure no signals are propagated from the main process to the worker processes.
67+ # Creates a new process group, making sure no signals are
68+ # propagated from the main process to the worker processes.
5769 os .setpgrp ()
5870
71+ # Restore default signal handlers, otherwise workers would inherit
72+ # them from main process
73+ signal .signal (signal .SIGTERM , signal .SIG_DFL )
74+ signal .signal (signal .SIGINT , signal .SIG_DFL )
75+
5976 sys .breakpointhook = multiprocessing_breakpoint
6077 while (args := input_ .get ()) is not _SENTINEL :
6178 result = handler (args )
@@ -71,11 +88,14 @@ def __init__(
7188 * ,
7289 result_callback : Callable [["MultiPool" , Any ], Any ],
7390 ):
91+ super ().__init__ ()
7492 if process_num <= 0 :
7593 raise ValueError ("At process_num must be greater than 0" )
7694
95+ self ._running = False
7796 self ._result_callback = result_callback
7897 self ._input = Queue (ctx = mp .get_context ())
98+ self ._input .cancel_join_thread ()
7999 self ._output = mp .SimpleQueue ()
80100 self ._procs = [
81101 mp .Process (
@@ -87,14 +107,32 @@ def __init__(
87107 self ._tid = threading .get_native_id ()
88108
89109 def start (self ):
110+ self ._running = True
90111 for p in self ._procs :
91112 p .start ()
92113
93- def close (self ):
94- self ._clear_input_queue ()
95- self ._request_workers_to_quit ()
96- self ._clear_output_queue ()
114+ def close (self , * , immediate = False ):
115+ if not self ._running :
116+ return
117+ self ._running = False
118+
119+ if immediate :
120+ self ._terminate_workers ()
121+ else :
122+ self ._clear_input_queue ()
123+ self ._request_workers_to_quit ()
124+ self ._clear_output_queue ()
125+
97126 self ._wait_for_workers_to_quit ()
127+ super ().close (immediate = immediate )
128+
129+ def _terminate_workers (self ):
130+ for proc in self ._procs :
131+ proc .terminate ()
132+
133+ self ._input .close ()
134+ if sys .version_info >= (3 , 9 ):
135+ self ._output .close ()
98136
99137 def _clear_input_queue (self ):
100138 try :
@@ -129,14 +167,16 @@ def submit(self, args):
129167 self ._input .put (args )
130168
131169 def process_until_done (self ):
132- while not self ._input .is_empty ():
133- result = self ._output .get ()
134- self ._result_callback (self , result )
135- self ._input .task_done ()
170+ with contextlib .suppress (EOFError ):
171+ while not self ._input .is_empty ():
172+ result = self ._output .get ()
173+ self ._result_callback (self , result )
174+ self ._input .task_done ()
136175
137176
138177class SinglePool (PoolBase ):
139178 def __init__ (self , handler , * , result_callback ):
179+ super ().__init__ ()
140180 self ._handler = handler
141181 self ._result_callback = result_callback
142182
@@ -157,3 +197,19 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP
157197 handler = handler ,
158198 result_callback = result_callback ,
159199 )
200+
201+
202+ orig_signal_handlers = {}
203+
204+
205+ def _on_terminate (signum , frame ):
206+ pools_snapshot = list (pools )
207+ for pool in pools_snapshot :
208+ pool .close (immediate = True )
209+
210+ if callable (orig_signal_handlers [signum ]):
211+ orig_signal_handlers [signum ](signum , frame )
212+
213+
214+ orig_signal_handlers [signal .SIGTERM ] = signal .signal (signal .SIGTERM , _on_terminate )
215+ orig_signal_handlers [signal .SIGINT ] = signal .signal (signal .SIGINT , _on_terminate )
0 commit comments