Skip to content

Commit 8585e5c

Browse files
committed
remove duplicate test + some tweaks
1 parent 8e95ae1 commit 8585e5c

File tree

3 files changed

+13
-16
lines changed

3 files changed

+13
-16
lines changed

src/pyper/_core/sync_helper/output.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
4-
import queue
3+
from typing import TYPE_CHECKING, Union
54

65
from .stage import Producer, ProducerConsumer
76
from ..util.sentinel import StopSentinel
87
from ..util.worker_pool import ProcessPool, ThreadPool
98

109
if TYPE_CHECKING:
10+
import multiprocessing as mp
11+
import queue
1112
from ..pipeline import Pipeline
1213

1314

1415
class PipelineOutput:
1516
def __init__(self, pipeline: Pipeline):
1617
self.pipeline = pipeline
1718

18-
def _get_q_out(self, tp: ThreadPool, pp: ProcessPool, *args, **kwargs) -> queue.Queue:
19+
def _get_q_out(self, tp: ThreadPool, pp: ProcessPool, *args, **kwargs) -> Union[mp.Queue, queue.Queue]:
1920
"""Feed forward each stage to the next, returning the output queue of the final stage."""
2021
q_out = None
2122
for task, next_task in zip(self.pipeline.tasks, self.pipeline.tasks[1:] + [None]):
@@ -34,5 +35,10 @@ def __call__(self, *args, **kwargs):
3435
"""Iterate through the pipeline, taking the inputs to the first task, and yielding each output from the last task."""
3536
with ThreadPool() as tp, ProcessPool() as pp:
3637
q_out = self._get_q_out(tp, pp, *args, **kwargs)
37-
while (data := q_out.get()) is not StopSentinel:
38-
yield data
38+
try:
39+
while (data := q_out.get()) is not StopSentinel:
40+
yield data
41+
except (KeyboardInterrupt, SystemExit): # pragma: no cover
42+
tp.shutdown_event.set()
43+
pp.shutdown_event.set()
44+
raise

src/pyper/_core/sync_helper/stage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from ..util.sentinel import StopSentinel
1010

1111
if TYPE_CHECKING:
12+
import multiprocessing as mp
1213
from multiprocessing.managers import SyncManager
13-
import multiprocessing.queues as mpq
1414
import multiprocessing.synchronize as mpsync
1515
from ..util.worker_pool import WorkerPool
1616
from ..task import Task
@@ -53,7 +53,7 @@ def start(self, pool: WorkerPool, /, *args, **kwargs):
5353
class ProducerConsumer:
5454
def __init__(
5555
self,
56-
q_in: Union[mpq.Queue, queue.Queue],
56+
q_in: Union[mp.Queue, queue.Queue],
5757
task: Task,
5858
next_task: Task,
5959
manager: SyncManager,

tests/test_sync.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,6 @@ def test_invalid_branch_result():
7373
assert isinstance(e, TypeError)
7474
else:
7575
raise AssertionError
76-
77-
def test_invalid_branch():
78-
try:
79-
p = task(f1, join=True) | task(f2, branch=True) > consumer
80-
p(1)
81-
except Exception as e:
82-
assert isinstance(e, RuntimeError)
83-
else:
84-
raise AssertionError
8576

8677
def test_threaded_error_handling():
8778
try:

0 commit comments

Comments
 (0)