Commit 1afc821

remove daemon support and refactor async handlers

1 parent 6c34687 · commit 1afc821

File tree: 6 files changed, +77 −118 lines changed


src/pyper/_core/async_helper/output.py
Lines changed: 16 additions & 14 deletions

@@ -1,13 +1,13 @@
 from __future__ import annotations

 import asyncio
-from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 import sys
 from typing import TYPE_CHECKING

 from .stage import AsyncProducer, AsyncProducerConsumer
+from ..util.asynchronize import ascynchronize
 from ..util.sentinel import StopSentinel
-from ..util.thread_pool import ThreadPool

 if sys.version_info < (3, 11):  # pragma: no cover
     from ..util.task_group import TaskGroup, ExceptionGroup
@@ -22,30 +22,32 @@ class AsyncPipelineOutput:
     def __init__(self, pipeline: AsyncPipeline):
         self.pipeline = pipeline

-    def _get_q_out(self, tg: TaskGroup, tp: ThreadPool, pp: ProcessPoolExecutor, *args, **kwargs) -> asyncio.Queue:
+    def _get_q_out(self, tg: TaskGroup, tp: ThreadPoolExecutor, pp: ProcessPoolExecutor, *args, **kwargs) -> asyncio.Queue:
         """Feed forward each stage to the next, returning the output queue of the final stage."""
         q_out = None
         for task, next_task in zip(self.pipeline.tasks, self.pipeline.tasks[1:] + [None]):
-            n_consumers = 1 if next_task is None else next_task.concurrency
+            task = ascynchronize(task, tp=tp, pp=pp)
             if q_out is None:
-                stage = AsyncProducer(task=self.pipeline.tasks[0], tg=tg, tp=tp, pp=pp, n_consumers=n_consumers)
-                stage.start(*args, **kwargs)
+                stage = AsyncProducer(task=task, next_task=next_task)
+                stage.start(tg, *args, **kwargs)
             else:
-                stage = AsyncProducerConsumer(q_in=q_out, task=task, tg=tg, tp=tp, pp=pp, n_consumers=n_consumers)
-                stage.start()
+                stage = AsyncProducerConsumer(q_in=q_out, task=task, next_task=next_task)
+                stage.start(tg)
             q_out = stage.q_out

         return q_out

     async def __call__(self, *args, **kwargs):
-        """Call the pipeline, taking the inputs to the first task, and returning the output from the last task."""
+        """Iterate through the pipeline, taking the inputs to the first task, and yielding each output from the last task.
+
+        Unify async, threaded, and multiprocessed work by:
+        1. using TaskGroup to execute asynchronous tasks
+        2. using ThreadPoolExecutor to execute threaded synchronous tasks
+        3. using ProcessPoolExecutor to execute multiprocessed synchronous tasks
+        """
         try:
-            # Unify async, threaded, and multiprocessed work by:
-            # 1. using TaskGroup to execute asynchronous tasks
-            # 2. using ThreadPool to execute threaded synchronous tasks
-            # 3. using ProcessPoolExecutor to execute multiprocessed synchronous tasks
             async with TaskGroup() as tg:
-                with ThreadPool() as tp, ProcessPoolExecutor() as pp:
+                with ThreadPoolExecutor() as tp, ProcessPoolExecutor() as pp:
                     q_out = self._get_q_out(tg, tp, pp, *args, **kwargs)
                     while (data := await q_out.get()) is not StopSentinel:
                         yield data
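
Note: the swap from the custom ThreadPool to concurrent.futures.ThreadPoolExecutor works because ascynchronize (now applied in _get_q_out) wraps synchronous tasks as awaitables. A minimal sketch of that wrapping pattern, using a hypothetical make_async helper rather than pyper's actual ascynchronize:

```python
# Hypothetical sketch, not pyper's implementation: offload a synchronous
# function to an executor so it can be awaited like a native coroutine.
import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor

def square(x: int) -> int:  # a plain synchronous task
    return x * x

def make_async(func, executor: Executor):
    """Return a coroutine function that runs func in the given executor."""
    async def wrapper(*args):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, func, *args)
    return wrapper

async def main():
    with ThreadPoolExecutor() as tp:
        async_square = make_async(square, tp)
        print(await async_square(4))  # prints 16

asyncio.run(main())
```

The same pattern extends to ProcessPoolExecutor for multiprocessed tasks, which is what lets a single asyncio code path drive all three execution modes.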

src/pyper/_core/async_helper/queue_io.py
Lines changed: 20 additions & 27 deletions

@@ -9,19 +9,15 @@
 from ..task import Task


-class AsyncDequeue:
+def AsyncDequeueFactory(q_in: asyncio.Queue, task: Task):
+    return _JoiningAsyncDequeue(q_in=q_in) if task.join \
+        else _SingleAsyncDequeue(q_in=q_in)
+
+
+class _AsyncDequeue:
     """Pulls data from an input queue."""
-    def __new__(self, q_in: asyncio.Queue, task: Task):
-        if task.join:
-            instance = object.__new__(_JoiningAsyncDequeue)
-        else:
-            instance = object.__new__(_SingleAsyncDequeue)
-        instance.__init__(q_in=q_in, task=task)
-        return instance
-
-    def __init__(self, q_in: asyncio.Queue, task: Task):
-        self.q_in = q_in
-        self.task = task
+    def __init__(self, q_in: asyncio.Queue):
+        self.q_in = q_in

     async def _input_stream(self):
         while (data := await self.q_in.get()) is not StopSentinel:
@@ -31,41 +27,38 @@ def __call__(self):
         raise NotImplementedError


-class _SingleAsyncDequeue(AsyncDequeue):
+class _SingleAsyncDequeue(_AsyncDequeue):
     async def __call__(self):
         async for data in self._input_stream():
             yield data


-class _JoiningAsyncDequeue(AsyncDequeue):
+class _JoiningAsyncDequeue(_AsyncDequeue):
     async def __call__(self):
         yield self._input_stream()


-class AsyncEnqueue:
-    """Puts output from a task onto an output queue."""
-    def __new__(cls, q_out: asyncio.Queue, task: Task):
-        if task.is_gen:
-            instance = object.__new__(_BranchingAsyncEnqueue)
-        else:
-            instance = object.__new__(_SingleAsyncEnqueue)
-        instance.__init__(q_out=q_out, task=task)
-        return instance
+def AsyncEnqueueFactory(q_out: asyncio.Queue, task: Task):
+    return _BranchingAsyncEnqueue(q_out=q_out, task=task) if task.is_gen \
+        else _SingleAsyncEnqueue(q_out=q_out, task=task)

+
+class _AsyncEnqueue:
+    """Puts output from a task onto an output queue."""
     def __init__(self, q_out: asyncio.Queue, task: Task):
-        self.q_out = q_out
-        self.task = task
+        self.q_out = q_out
+        self.task = task

     async def __call__(self, *args, **kwargs):
         raise NotImplementedError


-class _SingleAsyncEnqueue(AsyncEnqueue):
+class _SingleAsyncEnqueue(_AsyncEnqueue):
     async def __call__(self, *args, **kwargs):
         await self.q_out.put(await self.task.func(*args, **kwargs))


-class _BranchingAsyncEnqueue(AsyncEnqueue):
+class _BranchingAsyncEnqueue(_AsyncEnqueue):
     async def __call__(self, *args, **kwargs):
         async for output in self.task.func(*args, **kwargs):
             await self.q_out.put(output)
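
Note: this refactor swaps __new__-based dispatch for module-level factory functions. A plausible motivation (inferred, not stated in the commit message): when __new__ returns an instance of the class being constructed, Python invokes __init__ on it again, so the old pattern initialized each object twice. A standalone sketch with stand-in classes:

```python
# Stand-in classes, not pyper's code: old vs. new dispatch.

class OldDequeue:
    def __new__(cls, join):
        # Picks a subclass and calls __init__ by hand; Python then calls
        # __init__ a second time, since the result is an instance of cls.
        instance = object.__new__(_OldJoining if join else _OldSingle)
        instance.__init__(join)
        return instance

    def __init__(self, join):
        print("init runs")  # printed twice per construction
        self.join = join

class _OldSingle(OldDequeue): ...
class _OldJoining(OldDequeue): ...

OldDequeue(join=True)  # "init runs" appears twice

# New style: a plain factory dispatches explicitly and initializes once.
class _NewBase:
    def __init__(self, join):
        self.join = join

class _NewSingle(_NewBase): ...
class _NewJoining(_NewBase): ...

def NewDequeueFactory(join):
    return _NewJoining(join) if join else _NewSingle(join)

assert isinstance(NewDequeueFactory(True), _NewJoining)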
src/pyper/_core/async_helper/stage.py
Lines changed: 18 additions & 37 deletions

@@ -1,12 +1,10 @@
 from __future__ import annotations

 import asyncio
-from concurrent.futures import ProcessPoolExecutor
 import sys
 from typing import TYPE_CHECKING

-from .queue_io import AsyncDequeue, AsyncEnqueue
-from ..util.asynchronize import ascynchronize
+from .queue_io import AsyncDequeueFactory, AsyncEnqueueFactory
 from ..util.sentinel import StopSentinel

 if sys.version_info < (3, 11):  # pragma: no cover
@@ -15,67 +13,50 @@
     from asyncio import TaskGroup

 if TYPE_CHECKING:
-    from ..util.thread_pool import ThreadPool
     from ..task import Task


 class AsyncProducer:
-    def __init__(
-            self,
-            task: Task,
-            tg: TaskGroup,
-            tp: ThreadPool,
-            pp: ProcessPoolExecutor,
-            n_consumers: int):
-        self.task = ascynchronize(task, tp, pp)
+    def __init__(self, task: Task, next_task: Task):
         if task.concurrency > 1:
             raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have concurrency greater than 1")
         if task.join:
             raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
-        self.tg = tg
-        self.n_consumers = n_consumers
+        self.task = task
         self.q_out = asyncio.Queue(maxsize=task.throttle)

-        self._enqueue = AsyncEnqueue(self.q_out, self.task)
+        self._n_consumers = 1 if next_task is None else next_task.concurrency
+        self._enqueue = AsyncEnqueueFactory(self.q_out, self.task)

     async def _worker(self, *args, **kwargs):
         await self._enqueue(*args, **kwargs)

-        for _ in range(self.n_consumers):
+        for _ in range(self._n_consumers):
             await self.q_out.put(StopSentinel)

-    def start(self, *args, **kwargs):
-        self.tg.create_task(self._worker(*args, **kwargs))
+    def start(self, tg: TaskGroup, /, *args, **kwargs):
+        tg.create_task(self._worker(*args, **kwargs))


 class AsyncProducerConsumer:
-    def __init__(
-            self,
-            q_in: asyncio.Queue,
-            task: Task,
-            tg: TaskGroup,
-            tp: ThreadPool,
-            pp: ProcessPoolExecutor,
-            n_consumers: int):
-        self.q_in = q_in
-        self.task = ascynchronize(task, tp, pp)
-        self.tg = tg
-        self.n_consumers = n_consumers
+    def __init__(self, q_in: asyncio.Queue, task: Task, next_task: Task):
         self.q_out = asyncio.Queue(maxsize=task.throttle)

+        self._n_workers = task.concurrency
+        self._n_consumers = 1 if next_task is None else next_task.concurrency
+        self._dequeue = AsyncDequeueFactory(q_in, task)
+        self._enqueue = AsyncEnqueueFactory(self.q_out, task)
         self._workers_done = 0
-        self._dequeue = AsyncDequeue(self.q_in, self.task)
-        self._enqueue = AsyncEnqueue(self.q_out, self.task)

     async def _worker(self):
         async for output in self._dequeue():
             await self._enqueue(output)

         self._workers_done += 1
-        if self._workers_done == self.task.concurrency:
-            for _ in range(self.n_consumers):
+        if self._workers_done == self._n_workers:
+            for _ in range(self._n_consumers):
                 await self.q_out.put(StopSentinel)

-    def start(self):
-        for _ in range(self.task.concurrency):
-            self.tg.create_task(self._worker())
+    def start(self, tg: TaskGroup, /):
+        for _ in range(self._n_workers):
+            tg.create_task(self._worker())
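
Note: start(tg, ...) now threads the TaskGroup through explicitly instead of storing it on the stage. The shutdown protocol itself is unchanged: each stage enqueues one StopSentinel per downstream worker, so every consumer sees exactly one sentinel. A minimal runnable sketch of that fan-out idea in plain asyncio (not pyper's API; asyncio.TaskGroup requires Python 3.11+):

```python
# Minimal sketch of sentinel fan-out, independent of pyper.
import asyncio

STOP = object()  # stand-in for StopSentinel

async def producer(q_out: asyncio.Queue, n_consumers: int):
    for i in range(3):
        await q_out.put(i)
    for _ in range(n_consumers):  # one sentinel per downstream worker
        await q_out.put(STOP)

async def worker(name: str, q_in: asyncio.Queue):
    while (data := await q_in.get()) is not STOP:
        print(name, "got", data)

async def main():
    q = asyncio.Queue()
    async with asyncio.TaskGroup() as tg:
        tg.create_task(producer(q, n_consumers=2))
        tg.create_task(worker("w1", q))
        tg.create_task(worker("w2", q))

asyncio.run(main())
```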

src/pyper/_core/decorators.py
Lines changed: 13 additions & 12 deletions

@@ -19,14 +19,12 @@


 class task:
-    """Decorator class to transform a function into a `Task` object, and then initialize a `Pipeline` with this task.
-    A Pipeline initialized in this way consists of one Task, and can be piped into other Pipelines.
+    """Decorator class to initialize a `Pipeline` consisting of one task.

     The behaviour of each task within a Pipeline is determined by the parameters:
     * `join`: allows the function to take all previous results as input, instead of single results
     * `concurrency`: runs the functions with multiple (async or threaded) workers
     * `throttle`: limits the number of results the function is able to produce when all consumers are busy
-    * `daemon`: determines whether threaded workers are daemon threads (cannot be True for async tasks)
     * `bind`: additional args and kwargs to bind to the function when defining a pipeline
     """
     @t.overload
@@ -38,7 +36,6 @@ def __new__(
             join: bool = False,
             concurrency: int = 1,
             throttle: int = 0,
-            daemon: bool = False,
             multiprocess: bool = False,
             bind: _ArgsKwargs = None) -> t.Type[task]: ...

@@ -51,7 +48,6 @@ def __new__(
             join: bool = False,
             concurrency: int = 1,
             throttle: int = 0,
-            daemon: bool = False,
             multiprocess: bool = False,
             bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...

@@ -64,7 +60,6 @@ def __new__(
             join: bool = False,
             concurrency: int = 1,
             throttle: int = 0,
-            daemon: bool = False,
             multiprocess: bool = False,
             bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...

@@ -77,7 +72,6 @@ def __new__(
             join: bool = False,
             concurrency: int = 1,
             throttle: int = 0,
-            daemon: bool = False,
             multiprocess: bool = False,
             bind: _ArgsKwargs = None) -> Pipeline[_P, _R]: ...

@@ -90,7 +84,6 @@ def __new__(
             join: bool = False,
             concurrency: int = 1,
             throttle: int = 0,
-            daemon: bool = False,
             multiprocess: bool = False,
             bind: _ArgsKwargs = None) -> Pipeline[_P, _R]: ...

@@ -102,17 +95,25 @@ def __new__(
             join: bool = False,
             concurrency: int = 1,
             throttle: int = 0,
-            daemon: bool = False,
             multiprocess: bool = False,
             bind: _ArgsKwargs = None):
         # Classic decorator trick: @task() means func is None, @task without parentheses means func is passed.
         if func is None:
-            return functools.partial(cls, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, multiprocess=multiprocess, bind=bind)
-        return Pipeline([Task(func=func, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, multiprocess=multiprocess, bind=bind)])
+            return functools.partial(cls, join=join, concurrency=concurrency, throttle=throttle, multiprocess=multiprocess, bind=bind)
+        return Pipeline([Task(func=func, join=join, concurrency=concurrency, throttle=throttle, multiprocess=multiprocess, bind=bind)])

     @staticmethod
     def bind(*args, **kwargs) -> _ArgsKwargs:
-        """Utility method, to be used with `functools.partial`."""
+        """Bind additional `args` and `kwargs` to a task.
+
+        Example:
+        ```python
+        def f(x: int, y: int):
+            return x + y
+
+        p = task(f, bind=task.bind(y=1))
+        p(x=1)
+        """
         if not args and not kwargs:
             return None
         return args, kwargs
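
Note: the "classic decorator trick" named in the comment above is what lets both @task and @task(...) work. A toy version of the trick in isolation (not pyper's implementation):

```python
# Toy decorator showing the trick: @task and @task(...) both work.
import functools

def task(func=None, *, concurrency=1):
    if func is None:
        # Called as @task(...): no function yet, so return a partial
        # that will receive it on the second call.
        return functools.partial(task, concurrency=concurrency)
    # Called as @task (no parentheses): func is passed directly.
    func.concurrency = concurrency
    return func

@task
def f(x): return x

@task(concurrency=2)
def g(x): return x

assert f.concurrency == 1 and g.concurrency == 2
```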

src/pyper/_core/task.py
Lines changed: 0 additions & 7 deletions

@@ -15,7 +15,6 @@ class Task:
         "join",
         "concurrency",
         "throttle",
-        "daemon",
         "multiprocess"
     )

@@ -25,7 +24,6 @@ def __init__(
             join: bool = False,
             concurrency: int = 1,
             throttle: int = 0,
-            daemon: bool = False,
             multiprocess: bool = False,
             bind: Optional[Tuple[Tuple, Dict]] = None):
         if not isinstance(concurrency, int):
@@ -38,8 +36,6 @@
             raise ValueError("throttle cannot be less than 0")
         if not callable(func):
             raise TypeError("A task must be a callable object")
-        if daemon and multiprocess:
-            raise ValueError("daemon and multiprocess cannot both be True")

         self.is_gen = inspect.isgeneratorfunction(func) \
             or inspect.isasyncgenfunction(func) \
@@ -50,14 +46,11 @@
             or inspect.iscoroutinefunction(func.__call__) \
             or inspect.isasyncgenfunction(func.__call__)

-        if self.is_async and daemon:
-            raise ValueError("daemon cannot be True for an async task")
         if self.is_async and multiprocess:
             raise ValueError("multiprocess cannot be True for an async task")

         self.func = func if bind is None else functools.partial(func, *bind[0], **bind[1])
         self.join = join
         self.concurrency = concurrency
         self.throttle = throttle
-        self.daemon = daemon
         self.multiprocess = multiprocess
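
Note: with daemon gone, the only remaining cross-mode constraint is that an async task cannot be multiprocessed. The is_async detection visible in the context lines can be paraphrased as a standalone helper (a sketch mirroring those inspect calls; the helper name is illustrative):

```python
# Standalone paraphrase of the is_async detection in Task.__init__:
# a task counts as async if its function (or its __call__) is a
# coroutine function or an async generator function.
import inspect

def is_async(func) -> bool:
    return inspect.iscoroutinefunction(func) \
        or inspect.isasyncgenfunction(func) \
        or inspect.iscoroutinefunction(getattr(func, "__call__", None)) \
        or inspect.isasyncgenfunction(getattr(func, "__call__", None))

async def coro(): ...
async def agen(): yield 1

class CallableObj:
    async def __call__(self): ...

assert is_async(coro) and is_async(agen) and is_async(CallableObj())
assert not is_async(len)
```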
