99from ..util .sentinel import StopSentinel
1010
1111if TYPE_CHECKING :
12+ from multiprocessing .synchronize import Event as MpEvent
1213 from ..util .worker_pool import WorkerPool
1314 from ..task import Task
1415
1516
1617class Producer :
17- def __init__ (self , task : Task , next_task : Task , q_err : Union [mp .Queue , queue .Queue ]):
18+ def __init__ (
19+ self ,
20+ task : Task ,
21+ next_task : Task ,
22+ q_err : Union [mp .Queue , queue .Queue ],
23+ shutdown_event : Union [MpEvent , threading .Event ]):
1824 if task .concurrency > 1 :
1925 raise RuntimeError (f"The first task in a pipeline ({ task .func .__qualname__ } ) cannot have concurrency greater than 1" )
2026 if task .join :
@@ -24,16 +30,16 @@ def __init__(self, task: Task, next_task: Task, q_err: Union[mp.Queue, queue.Que
2430 else queue .Queue (maxsize = task .throttle )
2531
2632 self ._q_err = q_err
33+ self ._shutdown_event = shutdown_event
2734 self ._n_workers = task .concurrency
2835 self ._n_consumers = 1 if next_task is None else next_task .concurrency
2936 self ._enqueue = EnqueueFactory (self .q_out , task )
3037
31- self ._multiprocess = task .multiprocess
32-
3338 def _worker (self , * args , ** kwargs ):
3439 try :
3540 self ._enqueue (* args , ** kwargs )
3641 except Exception as e :
42+ self ._shutdown_event .set ()
3743 self ._q_err .put (e )
3844 finally :
3945 for _ in range (self ._n_consumers ):
@@ -44,12 +50,21 @@ def start(self, pool: WorkerPool, /, *args, **kwargs):
4450
4551
4652class ProducerConsumer :
47- def __init__ (self , q_in : Union [mp .Queue , queue .Queue ], task : Task , next_task : Task , q_err : Union [mp .Queue , queue .Queue ]):
53+ def __init__ (
54+ self ,
55+ q_in : Union [mp .Queue , queue .Queue ],
56+ task : Task ,
57+ next_task : Task ,
58+ q_err : Union [mp .Queue , queue .Queue ],
59+ shutdown_event : Union [MpEvent , threading .Event ]):
60+ # The output queue is shared between this task and the next. We optimize here by using queue.Queue wherever possible
61+ # and only using multiprocess.Queue when the current task or the next task are multiprocessed
4862 self .q_out = mp .Queue (maxsize = task .throttle ) \
4963 if task .multiprocess or (next_task is not None and next_task .multiprocess ) \
5064 else queue .Queue (maxsize = task .throttle )
5165
5266 self ._q_err = q_err
67+ self ._shutdown_event = shutdown_event
5368 self ._n_workers = task .concurrency
5469 self ._n_consumers = 1 if next_task is None else next_task .concurrency
5570 self ._dequeue = DequeueFactory (q_in , task )
@@ -80,8 +95,10 @@ def _finish(self):
8095 def _worker (self ):
8196 try :
8297 for output in self ._dequeue ():
83- self ._enqueue (output )
98+ if not self ._shutdown_event .is_set ():
99+ self ._enqueue (output )
84100 except Exception as e :
101+ self ._shutdown_event .set ()
85102 self ._q_err .put (e )
86103 finally :
87104 self ._finish ()
0 commit comments