Skip to content

Commit 935d8cb

Browse files
committed
update tests, along with small refactors
1 parent 2e441e2 commit 935d8cb

File tree

11 files changed

+132
-91
lines changed

11 files changed

+132
-91
lines changed

src/pyper/_core/async_helper/queue_io.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import AsyncIterable, Iterable
3+
from collections.abc import Iterable
44
from typing import TYPE_CHECKING
55

66
from ..util.sentinel import StopSentinel
@@ -61,13 +61,12 @@ async def __call__(self, *args, **kwargs):
6161

6262
class _BranchingAsyncEnqueue(_AsyncEnqueue):
6363
async def __call__(self, *args, **kwargs):
64-
result = self.task.func(*args, **kwargs)
65-
if isinstance(result, AsyncIterable):
66-
async for output in result:
64+
if self.task.is_gen:
65+
async for output in self.task.func(*args, **kwargs):
6766
await self.q_out.put(output)
68-
elif isinstance(data := await result, Iterable):
69-
for output in data:
67+
elif isinstance(result := await self.task.func(*args, **kwargs), Iterable):
68+
for output in result:
7069
await self.q_out.put(output)
7170
else:
72-
raise TypeError(f"got object of type {type(data)} from branching task {self.task.func} which could not be iterated over"
71+
raise TypeError(f"got object of type {type(result)} from branching task {self.task.func} which could not be iterated over"
7372
" (the task should be a generator, or return an iterable)")

src/pyper/_core/sync_helper/output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __call__(self, *args, **kwargs):
4545
except queue.Empty:
4646
tp.raise_error_if_exists()
4747
pp.raise_error_if_exists()
48-
except (KeyboardInterrupt, SystemExit):
48+
except (KeyboardInterrupt, SystemExit): # pragma: no cover
4949
tp.shutdown_event.set()
5050
pp.shutdown_event.set()
5151
raise

src/pyper/_core/sync_helper/queue_io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ def __call__(self, *args, **kwargs):
6262

6363
class _BranchingEnqueue(_Enqueue):
6464
def __call__(self, *args, **kwargs):
65-
result = self.task.func(*args, **kwargs)
66-
if isinstance(result, Iterable):
65+
if isinstance(result := self.task.func(*args, **kwargs), Iterable):
6766
for output in result:
6867
self.q_out.put(output)
6968
else:

src/pyper/_core/sync_helper/stage.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import multiprocessing as mp
44
import queue
5-
import threading
65
from typing import TYPE_CHECKING, Union
76

87
from .queue_io import DequeueFactory, EnqueueFactory
98
from ..util.sentinel import StopSentinel
9+
from ..util.value import ThreadingValue
1010

1111
if TYPE_CHECKING:
12-
from multiprocessing.synchronize import Event as MpEvent
12+
from multiprocessing.synchronize import Event as ProcessEvent
13+
from threading import Event as ThreadEvent
1314
from ..util.worker_pool import WorkerPool
1415
from ..task import Task
1516

@@ -20,7 +21,7 @@ def __init__(
2021
task: Task,
2122
next_task: Task,
2223
q_err: Union[mp.Queue, queue.Queue],
23-
shutdown_event: Union[MpEvent, threading.Event]):
24+
shutdown_event: Union[ProcessEvent, ThreadEvent]):
2425
if task.workers > 1:
2526
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have more than 1 worker")
2627
if task.join:
@@ -56,7 +57,7 @@ def __init__(
5657
task: Task,
5758
next_task: Task,
5859
q_err: Union[mp.Queue, queue.Queue],
59-
shutdown_event: Union[MpEvent, threading.Event]):
60+
shutdown_event: Union[ProcessEvent, ThreadEvent]):
6061
# The output queue is shared between this task and the next. We optimize here by using queue.Queue wherever possible
6162
# and only using multiprocess.Queue when the current task or the next task are multiprocessed
6263
self.q_out = mp.Queue(maxsize=task.throttle) \
@@ -69,23 +70,12 @@ def __init__(
6970
self._n_consumers = 1 if next_task is None else next_task.workers
7071
self._dequeue = DequeueFactory(q_in, task)
7172
self._enqueue = EnqueueFactory(self.q_out, task)
72-
73-
self._multiprocess = task.multiprocess
74-
if self._multiprocess:
75-
self._workers_done = mp.Value('i', 0)
76-
self._lock = self._workers_done.get_lock()
77-
else:
78-
self._workers_done = 0
79-
self._lock = threading.Lock()
73+
self._workers_done = mp.Value('i', 0) if task.multiprocess else ThreadingValue(0)
8074

8175
def _increment_workers_done(self):
82-
with self._lock:
83-
if self._multiprocess:
84-
self._workers_done.value += 1
85-
return self._workers_done.value
86-
else:
87-
self._workers_done += 1
88-
return self._workers_done
76+
with self._workers_done.get_lock():
77+
self._workers_done.value += 1
78+
return self._workers_done.value
8979

9080
def _finish(self):
9181
if self._increment_workers_done() == self._n_workers:

src/pyper/_core/util/value.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from __future__ import annotations
2+
3+
import threading
4+
from typing import Any
5+
6+
7+
class ThreadingValue:
8+
"""Utility class to help manage thread based access to a value.
9+
10+
The `get_lock` method mimics the API of `multiprocessing.Value`
11+
"""
12+
def __init__(self, value: Any):
13+
self.value = value
14+
self._lock = threading.Lock()
15+
16+
def get_lock(self):
17+
return self._lock

src/pyper/_core/util/worker_pool.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class WorkerPool:
1616
worker_type = None
1717

1818
def __init__(self):
19-
self.error_queue = mp.Queue(1) if self.worker_type is mp.Process else queue.Queue(1)
19+
self.error_queue = mp.Queue() if self.worker_type is mp.Process else queue.Queue()
2020
self.shutdown_event = mp.Event() if self.worker_type is mp.Process else threading.Event()
2121

2222
self._workers: List[Union[mp.Process, threading.Thread]] = []
@@ -35,9 +35,6 @@ def has_error(self):
3535
def get_error(self) -> Exception:
3636
return self.error_queue.get()
3737

38-
def put_error(self, e: Exception):
39-
self.error_queue.put(e)
40-
4138
def raise_error_if_exists(self):
4239
if self.has_error:
4340
raise self.get_error() from None

tests/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ COPY .git .git
1313

1414
RUN chmod +x tests/entrypoint.sh
1515

16-
ENTRYPOINT ["tests/entrypoint.sh"]
16+
ENTRYPOINT ["/bin/bash", "tests/entrypoint.sh"]

tests/test_async.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pyper import task
22
import pytest
33

4+
class TestError(Exception): ...
45

56
def f1(data):
67
return data
@@ -9,7 +10,10 @@ def f2(data):
910
yield data
1011

1112
def f3(data):
12-
raise RuntimeError
13+
raise TestError
14+
15+
def f4(data):
16+
return [data]
1317

1418
async def af1(data):
1519
return data
@@ -18,7 +22,7 @@ async def af2(data):
1822
yield data
1923

2024
async def af3(data):
21-
raise RuntimeError
25+
raise TestError
2226

2327
async def af4(data):
2428
async for row in data:
@@ -31,24 +35,29 @@ async def consumer(data):
3135
return total
3236

3337
@pytest.mark.asyncio
34-
async def test_pipeline():
35-
p = task(f1) | task(f2)
36-
assert p(1).__next__() == 1
38+
async def test_aiterable_branched_pipeline():
39+
p = task(af1) | task(f2, branch=True)
40+
assert await p(1).__anext__() == 1
41+
42+
@pytest.mark.asyncio
43+
async def test_iterable_branched_pipeline():
44+
p = task(af1) | task(f4, branch=True)
45+
assert await p(1).__anext__() == 1
3746

3847
@pytest.mark.asyncio
3948
async def test_joined_pipeline():
40-
p = task(af1) | task(af2) | task(af4, join=True)
49+
p = task(af1) | task(af2, branch=True) | task(af4, branch=True, join=True)
4150
assert await p(1).__anext__() == 1
4251

4352
@pytest.mark.asyncio
4453
async def test_consumer():
45-
p = task(af1) | task(af2) > consumer
54+
p = task(af1) | task(af2, branch=True) > consumer
4655
assert await p(1) == 1
4756

4857
@pytest.mark.asyncio
49-
async def test_invalid_first_stage_concurrency():
58+
async def test_invalid_first_stage_workers():
5059
try:
51-
p = task(af1, concurrency=2) | task(af2) > consumer
60+
p = task(af1, workers=2) | task(af2, branch=True) > consumer
5261
await p(1)
5362
except Exception as e:
5463
assert isinstance(e, RuntimeError)
@@ -58,35 +67,48 @@ async def test_invalid_first_stage_concurrency():
5867
@pytest.mark.asyncio
5968
async def test_invalid_first_stage_join():
6069
try:
61-
p = task(af1, join=True) | task(af2) > consumer
70+
p = task(af1, join=True) | task(af2, branch=True) > consumer
6271
await p(1)
6372
except Exception as e:
6473
assert isinstance(e, RuntimeError)
6574
else:
6675
raise AssertionError
67-
76+
6877
@pytest.mark.asyncio
69-
async def test_error_handling():
78+
async def test_invalid_branch_result():
7079
try:
71-
p = task(af1) | task(af2) | task(af3) > consumer
80+
p = task(af1, branch=True) > consumer
7281
await p(1)
7382
except Exception as e:
74-
assert isinstance(e, RuntimeError)
83+
assert isinstance(e, TypeError)
7584
else:
7685
raise AssertionError
77-
78-
@pytest.mark.asyncio
79-
async def test_unified_pipeline():
80-
p = task(af1) | task(f1) | task(af2) | task(f2) > consumer
81-
assert await p(1) == 1
8286

83-
@pytest.mark.asyncio
84-
async def test_error_handling_in_daemon():
87+
async def _try_catch_error(pipeline):
8588
try:
86-
p = task(af1) | task(af2) | task(f3, daemon=True) > consumer
89+
p = task(af1) | pipeline > consumer
8790
await p(1)
8891
except Exception as e:
89-
assert isinstance(e, RuntimeError)
92+
return isinstance(e, TestError)
9093
else:
91-
raise AssertionError
92-
94+
return False
95+
96+
@pytest.mark.asyncio
97+
async def test_async_error_handling():
98+
p = task(af3)
99+
assert await _try_catch_error(p)
100+
101+
@pytest.mark.asyncio
102+
async def test_threaded_error_handling():
103+
p = task(f3, workers=2)
104+
assert await _try_catch_error(p)
105+
106+
@pytest.mark.asyncio
107+
async def test_multiprocessed_error_handling():
108+
p = task(f3, workers=2, multiprocess=True)
109+
assert await _try_catch_error(p)
110+
111+
@pytest.mark.asyncio
112+
async def test_unified_pipeline():
113+
p = task(af1) | task(f1) | task(f2, branch=True, multiprocess=True) > consumer
114+
assert await p(1) == 1

tests/test_sync.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import time
22
from pyper import task
33

4+
class TestError(Exception): ...
45

56
def f1(data):
67
return data
@@ -18,37 +19,37 @@ def f4(a1, a2, a3, data, k1, k2):
1819
def f5(data):
1920
# Make queue monitor timeout on main thread
2021
time.sleep(0.2)
21-
raise RuntimeError
22+
raise TestError
2223

2324
def consumer(data):
2425
total = 0
2526
for i in data:
2627
total += i
2728
return total
2829

29-
def test_pipeline():
30-
p = task(f1) | task(f2)
30+
def test_branched_pipeline():
31+
p = task(f1) | task(f2, branch=True)
3132
assert p(1).__next__() == 1
3233

3334
def test_joined_pipeline():
34-
p = task(f1) | task(f2) | task(f3, join=True)
35+
p = task(f1) | task(f2, branch=True) | task(f3, branch=True, join=True)
3536
assert p(1).__next__() == 1
3637

3738
def test_bind():
3839
p = task(f1) | task(f4, bind=task.bind(1, 1, 1, k1=1, k2=2))
3940
assert p(1).__next__() == 1
4041

4142
def test_redundant_bind_ok():
42-
p = task(f1) | task(f2, bind=task.bind())
43+
p = task(f1) | task(f2, branch=True, bind=task.bind())
4344
assert p(1).__next__() == 1
4445

4546
def test_consumer():
46-
p = task(f1) | task(f2) > consumer
47+
p = task(f1) | task(f2, branch=True) > consumer
4748
assert p(1) == 1
4849

49-
def test_invalid_first_stage_concurrency():
50+
def test_invalid_first_stage_workers():
5051
try:
51-
p = task(f1, concurrency=2) | task(f2) > consumer
52+
p = task(f1, workers=2) | task(f2) > consumer
5253
p(1)
5354
except Exception as e:
5455
assert isinstance(e, RuntimeError)
@@ -57,29 +58,45 @@ def test_invalid_first_stage_concurrency():
5758

5859
def test_invalid_first_stage_join():
5960
try:
60-
p = task(f1, join=True) | task(f2) > consumer
61+
p = task(f1, join=True) | task(f2, branch=True) > consumer
6162
p(1)
6263
except Exception as e:
6364
assert isinstance(e, RuntimeError)
6465
else:
6566
raise AssertionError
6667

67-
def test_error_handling():
68+
def test_invalid_branch_result():
6869
try:
69-
p = task(f1) | task(f2) | task(f5) > consumer
70+
p = task(f1, branch=True) > consumer
71+
p(1)
72+
except Exception as e:
73+
assert isinstance(e, TypeError)
74+
else:
75+
raise AssertionError
76+
77+
def test_invalid_branch():
78+
try:
79+
p = task(f1, join=True) | task(f2, branch=True) > consumer
7080
p(1)
7181
except Exception as e:
72-
print(e)
7382
assert isinstance(e, RuntimeError)
7483
else:
7584
raise AssertionError
7685

77-
def test_error_handling_in_daemon():
86+
def test_threaded_error_handling():
7887
try:
79-
p = task(f5, daemon=True) | task(f2) > consumer
88+
p = task(f1) | task(f5, workers=2) > consumer
8089
p(1)
8190
except Exception as e:
82-
print(e)
83-
assert isinstance(e, RuntimeError)
91+
assert isinstance(e, TestError)
92+
else:
93+
raise AssertionError
94+
95+
def test_multiprocessed_error_handling():
96+
try:
97+
p = task(f1) | task(f5, workers=2, multiprocess=True) > consumer
98+
p(1)
99+
except Exception as e:
100+
assert isinstance(e, TestError)
84101
else:
85102
raise AssertionError

0 commit comments

Comments
 (0)