11from __future__ import annotations
22
3- import multiprocessing as mp
43import queue
4+ import threading
55from typing import TYPE_CHECKING , Union
66
77from .queue_io import DequeueFactory , EnqueueFactory
88from ..util .sentinel import StopSentinel
9- from ..util .value import ThreadingValue
9+ from ..util .counter import MultiprocessingCounter , ThreadingCounter
1010
1111if TYPE_CHECKING :
12- from multiprocessing .synchronize import Event as ProcessEvent
13- from threading import Event as ThreadEvent
12+ from multiprocessing .managers import SyncManager
13+ import multiprocessing .queues as mpq
14+ import multiprocessing .synchronize as mpsync
1415 from ..util .worker_pool import WorkerPool
1516 from ..task import Task
1617
@@ -20,17 +21,16 @@ def __init__(
2021 self ,
2122 task : Task ,
2223 next_task : Task ,
23- q_err : Union [ mp . Queue , queue . Queue ] ,
24- shutdown_event : Union [ProcessEvent , ThreadEvent ]):
24+ manager : SyncManager ,
25+ shutdown_event : Union [mpsync . Event , threading . Event ]):
2526 if task .workers > 1 :
2627 raise RuntimeError (f"The first task in a pipeline ({ task .func .__qualname__ } ) cannot have more than 1 worker" )
2728 if task .join :
2829 raise RuntimeError (f"The first task in a pipeline ({ task .func .__qualname__ } ) cannot join previous results" )
29- self .q_out = mp .Queue (maxsize = task .throttle ) \
30+ self .q_out = manager .Queue (maxsize = task .throttle ) \
3031 if task .multiprocess or (next_task is not None and next_task .multiprocess ) \
3132 else queue .Queue (maxsize = task .throttle )
3233
33- self ._q_err = q_err
3434 self ._shutdown_event = shutdown_event
3535 self ._n_workers = task .workers
3636 self ._n_consumers = 1 if next_task is None else next_task .workers
@@ -39,9 +39,9 @@ def __init__(
3939 def _worker (self , * args , ** kwargs ):
4040 try :
4141 self ._enqueue (* args , ** kwargs )
42- except Exception as e :
42+ except Exception :
4343 self ._shutdown_event .set ()
44- self . _q_err . put ( e )
44+ raise
4545 finally :
4646 for _ in range (self ._n_consumers ):
4747 self .q_out .put (StopSentinel )
@@ -53,45 +53,37 @@ def start(self, pool: WorkerPool, /, *args, **kwargs):
5353class ProducerConsumer :
5454 def __init__ (
5555 self ,
56- q_in : Union [mp .Queue , queue .Queue ],
56+ q_in : Union [mpq .Queue , queue .Queue ],
5757 task : Task ,
5858 next_task : Task ,
59- q_err : Union [ mp . Queue , queue . Queue ] ,
60- shutdown_event : Union [ProcessEvent , ThreadEvent ]):
59+ manager : SyncManager ,
60+ shutdown_event : Union [mpsync . Event , threading . Event ]):
6161 # The output queue is shared between this task and the next. We optimize here by using queue.Queue wherever possible
62- # and only using multiprocess. Queue when the current task or the next task are multiprocessed
63- self .q_out = mp .Queue (maxsize = task .throttle ) \
62+ # and only using a multiprocess Queue when the current task or the next task are multiprocessed
63+ self .q_out = manager .Queue (maxsize = task .throttle ) \
6464 if task .multiprocess or (next_task is not None and next_task .multiprocess ) \
6565 else queue .Queue (maxsize = task .throttle )
6666
67- self ._q_err = q_err
6867 self ._shutdown_event = shutdown_event
6968 self ._n_workers = task .workers
7069 self ._n_consumers = 1 if next_task is None else next_task .workers
7170 self ._dequeue = DequeueFactory (q_in , task )
7271 self ._enqueue = EnqueueFactory (self .q_out , task )
73- self ._workers_done = mp .Value ('i' , 0 ) if task .multiprocess else ThreadingValue (0 )
74-
75- def _increment_workers_done (self ):
76- with self ._workers_done .get_lock ():
77- self ._workers_done .value += 1
78- return self ._workers_done .value
79-
80- def _finish (self ):
81- if self ._increment_workers_done () == self ._n_workers :
82- for _ in range (self ._n_consumers ):
83- self .q_out .put (StopSentinel )
72+ self ._workers_done = MultiprocessingCounter (0 , manager ) if task .multiprocess else ThreadingCounter (0 )
8473
8574 def _worker (self ):
8675 try :
8776 for output in self ._dequeue ():
8877 if not self ._shutdown_event .is_set ():
8978 self ._enqueue (output )
90- except Exception as e :
79+ except Exception :
9180 self ._shutdown_event .set ()
92- self . _q_err . put ( e )
81+ raise
9382 finally :
94- self ._finish ()
83+ if self ._workers_done .increment () == self ._n_workers :
84+ for _ in range (self ._n_consumers ):
85+ self .q_out .put (StopSentinel )
86+
9587
9688 def start (self , pool : WorkerPool , / ):
9789 for _ in range (self ._n_workers ):
0 commit comments