22
33import queue
44import threading
5+ from types import SimpleNamespace
56from typing import TYPE_CHECKING , Union
67
78from .queue_io import DequeueFactory , EnqueueFactory
89from ..util .sentinel import StopSentinel
9- from ..util .counter import MultiprocessingCounter , ThreadingCounter
1010
1111if TYPE_CHECKING :
1212 from multiprocessing .managers import SyncManager
@@ -24,9 +24,9 @@ def __init__(
2424 manager : SyncManager ,
2525 shutdown_event : Union [mpsync .Event , threading .Event ]):
2626 if task .workers > 1 :
27- raise RuntimeError (f"The first task in a pipeline ({ task .func . __qualname__ } ) cannot have more than 1 worker" )
27+ raise RuntimeError (f"The first task in a pipeline ({ task .func } ) cannot have more than 1 worker" )
2828 if task .join :
29- raise RuntimeError (f"The first task in a pipeline ({ task .func . __qualname__ } ) cannot join previous results" )
29+ raise RuntimeError (f"The first task in a pipeline ({ task .func } ) cannot join previous results" )
3030 self .q_out = manager .Queue (maxsize = task .throttle ) \
3131 if task .multiprocess or (next_task is not None and next_task .multiprocess ) \
3232 else queue .Queue (maxsize = task .throttle )
@@ -69,7 +69,8 @@ def __init__(
6969 self ._n_consumers = 1 if next_task is None else next_task .workers
7070 self ._dequeue = DequeueFactory (q_in , task )
7171 self ._enqueue = EnqueueFactory (self .q_out , task )
72- self ._workers_done = MultiprocessingCounter (0 , manager ) if task .multiprocess else ThreadingCounter (0 )
72+ self ._workers_done = manager .Value ('i' , 0 ) if task .multiprocess else SimpleNamespace (value = 0 )
73+ self ._workers_done_lock = manager .Lock () if task .multiprocess else threading .Lock ()
7374
7475 def _worker (self ):
7576 try :
@@ -80,11 +81,12 @@ def _worker(self):
8081 self ._shutdown_event .set ()
8182 raise
8283 finally :
83- if self ._workers_done .increment () == self ._n_workers :
84- for _ in range (self ._n_consumers ):
85- self .q_out .put (StopSentinel )
84+ with self ._workers_done_lock :
85+ self ._workers_done .value += 1
86+ if self ._workers_done .value == self ._n_workers :
87+ for _ in range (self ._n_consumers ):
88+ self .q_out .put (StopSentinel )
8689
87-
8890 def start (self , pool : WorkerPool , / ):
8991 for _ in range (self ._n_workers ):
9092 pool .submit (self ._worker )
0 commit comments