
Commit 4c59b9b

add fail fast mechanisms for sync api
1 parent cabfe8d commit 4c59b9b

3 files changed: +32 −12 lines changed

src/pyper/_core/sync_helper/output.py

Lines changed: 6 additions & 2 deletions
@@ -21,10 +21,10 @@ def _get_q_out(self, tp: ThreadPool, pp: ProcessPool, *args, **kwargs) -> queue.
         for task, next_task in zip(self.pipeline.tasks, self.pipeline.tasks[1:] + [None]):
             pool = pp if task.multiprocess else tp
             if q_out is None:
-                stage = Producer(task=task, next_task=next_task, q_err=pool.error_queue)
+                stage = Producer(task=task, next_task=next_task, q_err=pool.error_queue, shutdown_event=pool.shutdown_event)
                 stage.start(pool, *args, **kwargs)
             else:
-                stage = ProducerConsumer(q_in=q_out, task=task, next_task=next_task, q_err=pool.error_queue)
+                stage = ProducerConsumer(q_in=q_out, task=task, next_task=next_task, q_err=pool.error_queue, shutdown_event=pool.shutdown_event)
                 stage.start(pool)
             q_out = stage.q_out

@@ -45,3 +45,7 @@ def __call__(self, *args, **kwargs):
                 except queue.Empty:
                     tp.raise_error_if_exists()
                     pp.raise_error_if_exists()
+                except (KeyboardInterrupt, SystemExit):
+                    tp.shutdown_event.set()
+                    pp.shutdown_event.set()
+                    raise
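
Note on the consumer-side change above: it follows a simple fail-fast pattern in which an interrupt in the main thread sets the pools' shutdown events so workers stop producing, then re-raises. A minimal standalone sketch of the same idea using only the standard library (the names `producer` and `consume` are illustrative, not pyper APIs):

# Sketch of the consumer-side fail-fast pattern (illustrative only;
# `producer` and `consume` are not pyper names).
import queue
import threading

shutdown_event = threading.Event()
q_out = queue.Queue()

def producer():
    for i in range(1_000_000):
        if shutdown_event.is_set():   # stop producing once shutdown has been requested
            break
        q_out.put(i)
    q_out.put(None)                   # sentinel: no more data

def consume():
    threading.Thread(target=producer, daemon=True).start()
    try:
        while (data := q_out.get()) is not None:
            yield data
    except (KeyboardInterrupt, SystemExit):
        shutdown_event.set()          # fail fast: tell the producer to stop, then re-raise
        raise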

src/pyper/_core/sync_helper/stage.py

Lines changed: 22 additions & 5 deletions
@@ -9,12 +9,18 @@
 from ..util.sentinel import StopSentinel

 if TYPE_CHECKING:
+    from multiprocessing.synchronize import Event as MpEvent
     from ..util.worker_pool import WorkerPool
     from ..task import Task


 class Producer:
-    def __init__(self, task: Task, next_task: Task, q_err: Union[mp.Queue, queue.Queue]):
+    def __init__(
+            self,
+            task: Task,
+            next_task: Task,
+            q_err: Union[mp.Queue, queue.Queue],
+            shutdown_event: Union[MpEvent, threading.Event]):
         if task.concurrency > 1:
             raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have concurrency greater than 1")
         if task.join:

@@ -24,16 +30,16 @@ def __init__(self, task: Task, next_task: Task, q_err: Union[mp.Queue, queue.Que
             else queue.Queue(maxsize=task.throttle)

         self._q_err = q_err
+        self._shutdown_event = shutdown_event
         self._n_workers = task.concurrency
         self._n_consumers = 1 if next_task is None else next_task.concurrency
         self._enqueue = EnqueueFactory(self.q_out, task)

-        self._multiprocess = task.multiprocess
-
     def _worker(self, *args, **kwargs):
         try:
             self._enqueue(*args, **kwargs)
         except Exception as e:
+            self._shutdown_event.set()
             self._q_err.put(e)
         finally:
             for _ in range(self._n_consumers):

@@ -44,12 +50,21 @@ def start(self, pool: WorkerPool, /, *args, **kwargs):


 class ProducerConsumer:
-    def __init__(self, q_in: Union[mp.Queue, queue.Queue], task: Task, next_task: Task, q_err: Union[mp.Queue, queue.Queue]):
+    def __init__(
+            self,
+            q_in: Union[mp.Queue, queue.Queue],
+            task: Task,
+            next_task: Task,
+            q_err: Union[mp.Queue, queue.Queue],
+            shutdown_event: Union[MpEvent, threading.Event]):
+        # The output queue is shared between this task and the next. We optimize here by using queue.Queue wherever possible
+        # and only using multiprocess.Queue when the current task or the next task are multiprocessed
         self.q_out = mp.Queue(maxsize=task.throttle) \
             if task.multiprocess or (next_task is not None and next_task.multiprocess) \
             else queue.Queue(maxsize=task.throttle)

         self._q_err = q_err
+        self._shutdown_event = shutdown_event
         self._n_workers = task.concurrency
         self._n_consumers = 1 if next_task is None else next_task.concurrency
         self._dequeue = DequeueFactory(q_in, task)

@@ -80,8 +95,10 @@ def _finish(self):
     def _worker(self):
         try:
             for output in self._dequeue():
-                self._enqueue(output)
+                if not self._shutdown_event.is_set():
+                    self._enqueue(output)
         except Exception as e:
+            self._shutdown_event.set()
             self._q_err.put(e)
         finally:
             self._finish()
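
The stage-side half of the mechanism reduces to the sketch below: a worker that raises sets the shared shutdown event before reporting the error, while every other worker keeps draining its input queue but skips further work once the event is set, so the whole pipeline winds down quickly. Here `run_stage` and the doubling "task" are hypothetical stand-ins, not pyper internals:

# Simplified stand-in for a ProducerConsumer-style worker (illustrative only).
import queue
import threading

def run_stage(q_in: queue.Queue, q_out: queue.Queue, q_err: queue.Queue,
              shutdown_event: threading.Event) -> None:
    try:
        while (data := q_in.get()) is not None:
            if not shutdown_event.is_set():   # keep draining input, but stop doing work after a failure
                q_out.put(data * 2)           # stand-in for the user's task function
    except Exception as e:
        shutdown_event.set()                  # signal every other stage to fail fast
        q_err.put(e)                          # propagate the error to the main thread
    finally:
        q_out.put(None)                       # always release the downstream consumer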

src/pyper/_core/util/worker_pool.py

Lines changed: 4 additions & 5 deletions
@@ -13,14 +13,13 @@ class WorkerPool:
     2. Provides a mechanism to capture and propagate errors to the main thread/process
     3. Ensures safe tear-down of all workers
     """
-    worker_type: type = None
-
-    error_queue: Union[mp.Queue, queue.Queue]
-    _workers: List[Union[mp.Process, threading.Thread]]
+    worker_type = None

     def __init__(self):
         self.error_queue = mp.Queue(1) if self.worker_type is mp.Process else queue.Queue(1)
-        self._workers = []
+        self.shutdown_event = mp.Event() if self.worker_type is mp.Process else threading.Event()
+
+        self._workers: List[Union[mp.Process, threading.Thread]] = []

     def __enter__(self):
         return self
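
The event type mirrors the error queue: a multiprocessing.Event can be shared with child processes, whereas a threading.Event only works within one process, so each pool picks whichever matches its worker type. A sketch of how the concrete pools would pick this up; ThreadPool and ProcessPool are referenced in output.py above, but their definitions here are assumed rather than taken from this commit:

# Assumed shape of the concrete pools (not part of this diff).
import multiprocessing as mp
import queue
import threading

class WorkerPool:
    worker_type = None

    def __init__(self):
        # multiprocessing primitives are needed as soon as workers live in other processes
        self.error_queue = mp.Queue(1) if self.worker_type is mp.Process else queue.Queue(1)
        self.shutdown_event = mp.Event() if self.worker_type is mp.Process else threading.Event()

class ThreadPool(WorkerPool):
    worker_type = threading.Thread

class ProcessPool(WorkerPool):
    worker_type = mp.Process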
