Skip to content

Commit 2e441e2

Browse files
committed
implement explicit branch parameter (BREAKING CHANGE)
1 parent 5a19b36 commit 2e441e2

File tree

5 files changed

+47 / -21 lines changed

5 files changed

+47 / -21 lines changed

src/pyper/_core/async_helper/queue_io.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncIterable, Iterable
34
from typing import TYPE_CHECKING
45

56
from ..util.sentinel import StopSentinel
@@ -39,7 +40,7 @@ async def __call__(self):
3940

4041

4142
def AsyncEnqueueFactory(q_out: asyncio.Queue, task: Task):
42-
return _BranchingAsyncEnqueue(q_out=q_out, task=task) if task.is_gen \
43+
return _BranchingAsyncEnqueue(q_out=q_out, task=task) if task.branch \
4344
else _SingleAsyncEnqueue(q_out=q_out, task=task)
4445

4546

@@ -60,5 +61,13 @@ async def __call__(self, *args, **kwargs):
6061

6162
class _BranchingAsyncEnqueue(_AsyncEnqueue):
6263
async def __call__(self, *args, **kwargs):
63-
async for output in self.task.func(*args, **kwargs):
64-
await self.q_out.put(output)
64+
result = self.task.func(*args, **kwargs)
65+
if isinstance(result, AsyncIterable):
66+
async for output in result:
67+
await self.q_out.put(output)
68+
elif isinstance(data := await result, Iterable):
69+
for output in data:
70+
await self.q_out.put(output)
71+
else:
72+
raise TypeError(f"got object of type {type(data)} from branching task {self.task.func} which could not be iterated over"
73+
" (the task should be a generator, or return an iterable)")

src/pyper/_core/decorators.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,11 @@
1919

2020

2121
class task:
22-
"""Decorator class to initialize a `Pipeline` consisting of one task, from a function or callable.
22+
"""Decorator class to initialize a `Pipeline` consisting of one task.
2323
2424
Args:
25-
func (callable): A positional-only param defining the task function
25+
func (callable): A positional-only param defining the task function (can be omitted when using `@task`)
26+
branch (bool): Allows the task to submit multiple outputs
2627
join (bool): Allows the task to take all previous results as input, instead of single results
2728
workers (int): Defines the number of workers to run the task
2829
throttle (int): Limits the number of results the task is able to produce when all consumers are busy
@@ -32,7 +33,7 @@ class task:
3233
Returns:
3334
Pipeline: A `Pipeline` instance consisting of one task.
3435
35-
Example:
36+
Example:
3637
```python
3738
def f(x: int):
3839
return x + 1
@@ -46,6 +47,7 @@ def __new__(
4647
func: None = None,
4748
/,
4849
*,
50+
branch: bool = False,
4951
join: bool = False,
5052
workers: int = 1,
5153
throttle: int = 0,
@@ -55,21 +57,23 @@ def __new__(
5557
@t.overload
5658
def __new__(
5759
cls,
58-
func: t.Callable[_P, t.Awaitable[_R]],
60+
func: t.Callable[_P, t.Union[t.Awaitable[t.Iterable[_R]], t.AsyncGenerator[_R]]],
5961
/,
6062
*,
63+
branch: True,
6164
join: bool = False,
6265
workers: int = 1,
6366
throttle: int = 0,
6467
multiprocess: bool = False,
6568
bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
66-
69+
6770
@t.overload
6871
def __new__(
6972
cls,
70-
func: t.Callable[_P, t.AsyncGenerator[_R]],
73+
func: t.Callable[_P, t.Awaitable[_R]],
7174
/,
7275
*,
76+
branch: bool = False,
7377
join: bool = False,
7478
workers: int = 1,
7579
throttle: int = 0,
@@ -79,9 +83,10 @@ def __new__(
7983
@t.overload
8084
def __new__(
8185
cls,
82-
func: t.Callable[_P, t.Generator[_R]],
86+
func: t.Callable[_P, t.Iterable[_R]],
8387
/,
8488
*,
89+
branch: True,
8590
join: bool = False,
8691
workers: int = 1,
8792
throttle: int = 0,
@@ -94,6 +99,7 @@ def __new__(
9499
func: t.Callable[_P, _R],
95100
/,
96101
*,
102+
branch: bool = False,
97103
join: bool = False,
98104
workers: int = 1,
99105
throttle: int = 0,
@@ -105,15 +111,16 @@ def __new__(
105111
func: t.Optional[t.Callable] = None,
106112
/,
107113
*,
114+
branch: bool = False,
108115
join: bool = False,
109116
workers: int = 1,
110117
throttle: int = 0,
111118
multiprocess: bool = False,
112119
bind: _ArgsKwargs = None):
113120
# Classic decorator trick: @task() means func is None, @task without parentheses means func is passed.
114121
if func is None:
115-
return functools.partial(cls, join=join, workers=workers, throttle=throttle, multiprocess=multiprocess, bind=bind)
116-
return Pipeline([Task(func=func, join=join, workers=workers, throttle=throttle, multiprocess=multiprocess, bind=bind)])
122+
return functools.partial(cls, branch=branch, join=join, workers=workers, throttle=throttle, multiprocess=multiprocess, bind=bind)
123+
return Pipeline([Task(func=func, branch=branch, join=join, workers=workers, throttle=throttle, multiprocess=multiprocess, bind=bind)])
117124

118125
@staticmethod
119126
def bind(*args, **kwargs) -> _ArgsKwargs:
@@ -131,4 +138,3 @@ def f(x: int, y: int):
131138
if not args and not kwargs:
132139
return None
133140
return args, kwargs
134-

src/pyper/_core/sync_helper/queue_io.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import Iterable
34
from typing import TYPE_CHECKING, Union
45

56
from ..util.sentinel import StopSentinel
@@ -40,7 +41,7 @@ def __call__(self):
4041

4142

4243
def EnqueueFactory(q_out: Union[mp.Queue, queue.Queue], task: Task):
43-
return _BranchingEnqueue(q_out=q_out, task=task) if task.is_gen \
44+
return _BranchingEnqueue(q_out=q_out, task=task) if task.branch \
4445
else _SingleEnqueue(q_out=q_out, task=task)
4546

4647

@@ -61,5 +62,11 @@ def __call__(self, *args, **kwargs):
6162

6263
class _BranchingEnqueue(_Enqueue):
6364
def __call__(self, *args, **kwargs):
64-
for output in self.task.func(*args, **kwargs):
65-
self.q_out.put(output)
65+
result = self.task.func(*args, **kwargs)
66+
if isinstance(result, Iterable):
67+
for output in result:
68+
self.q_out.put(output)
69+
else:
70+
raise TypeError(
71+
f"got object of type {type(result)} from branching task {self.task.func} which could not be iterated over."
72+
" (the task should be a generator, or return an iterable)")

src/pyper/_core/task.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,18 +9,20 @@ class Task:
99
"""The representation of a function within a Pipeline."""
1010

1111
__slots__ = (
12-
"is_gen",
13-
"is_async",
1412
"func",
13+
"branch",
1514
"join",
1615
"workers",
1716
"throttle",
18-
"multiprocess"
17+
"multiprocess",
18+
"is_async",
19+
"is_gen"
1920
)
2021

2122
def __init__(
2223
self,
2324
func: Callable,
25+
branch: bool = False,
2426
join: bool = False,
2527
workers: int = 1,
2628
throttle: int = 0,
@@ -35,7 +37,7 @@ def __init__(
3537
if throttle < 0:
3638
raise ValueError("throttle cannot be less than 0")
3739
if not callable(func):
38-
raise TypeError("A task must be a callable object")
40+
raise TypeError("A task function must be a callable object")
3941

4042
self.is_gen = inspect.isgeneratorfunction(func) \
4143
or inspect.isasyncgenfunction(func) \
@@ -50,6 +52,7 @@ def __init__(
5052
raise ValueError("multiprocess cannot be True for an async task")
5153

5254
self.func = func if bind is None else functools.partial(func, *bind[0], **bind[1])
55+
self.branch = branch
5356
self.join = join
5457
self.workers = workers
5558
self.throttle = throttle

src/pyper/_core/util/asynchronize.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ def ascynchronize(task: Task, tp: ThreadPoolExecutor, pp: ProcessPoolExecutor) -
1717
if task.is_async:
1818
return task
1919

20-
if task.is_gen:
20+
if task.is_gen and task.branch:
2121
@functools.wraps(task.func)
2222
async def wrapper(*args, **kwargs):
2323
for output in task.func(*args, **kwargs):
@@ -31,6 +31,7 @@ async def wrapper(*args, **kwargs):
3131
return await loop.run_in_executor(executor=executor, func=f)
3232
return Task(
3333
func=wrapper,
34+
branch=task.branch,
3435
join=task.join,
3536
workers=task.workers,
3637
throttle=task.throttle

0 commit comments

Comments
 (0)