
Commit 1ced59d: start mp implementation

1 parent: 114c4a4
File tree: 5 files changed (+129, -90 lines)

src/pyper/_core/async_helper/output.py

Lines changed: 9 additions & 8 deletions

@@ -15,16 +15,16 @@ class AsyncPipelineOutput:
     def __init__(self, pipeline: AsyncPipeline):
         self.pipeline = pipeline
 
-    def _get_q_out(self, tg: asyncio.TaskGroup, tp: ThreadPool ,*args, **kwargs) -> asyncio.Queue:
+    def _get_q_out(self, tg: asyncio.TaskGroup, tp: ThreadPool, pp: ProcessPool, *args, **kwargs) -> asyncio.Queue:
         """Feed forward each stage to the next, returning the output queue of the final stage."""
         q_out = None
         for task, next_task in zip(self.pipeline.tasks, self.pipeline.tasks[1:] + [None]):
             n_consumers = 1 if next_task is None else next_task.concurrency
             if q_out is None:
-                stage = AsyncProducer(task=self.pipeline.tasks[0], tg=tg, tp=tp, n_consumers=n_consumers)
+                stage = AsyncProducer(task=self.pipeline.tasks[0], tg=tg, tp=tp, pp=pp, n_consumers=n_consumers)
                 stage.start(*args, **kwargs)
             else:
-                stage = AsyncProducerConsumer(q_in=q_out, task=task, tg=tg, tp=tp, n_consumers=n_consumers)
+                stage = AsyncProducerConsumer(q_in=q_out, task=task, tg=tg, tp=tp, pp=pp, n_consumers=n_consumers)
                 stage.start()
             q_out = stage.q_out

@@ -33,11 +33,12 @@ def _get_q_out(self, tg: asyncio.TaskGroup, tp: ThreadPool ,*args, **kwargs) ->
     async def __call__(self, *args, **kwargs):
         """Call the pipeline, taking the inputs to the first task, and returning the output from the last task."""
         try:
-            # We unify async and thread-based concurrency by
-            # 1. using TaskGroup to spin up asynchronous tasks
-            # 2. using ThreadPool to spin up synchronous tasks
-            async with asyncio.TaskGroup() as tg, ThreadPool() as tp:
-                q_out = self._get_q_out(tg, tp, *args, **kwargs)
+            # Unify async, threaded, and multiprocessed work by:
+            # 1. using TaskGroup to execute asynchronous tasks
+            # 2. using ThreadPool to execute threaded synchronous tasks
+            # 3. using ProcessPool to execute multiprocessed synchronous tasks
+            async with asyncio.TaskGroup() as tg, ThreadPool() as tp, ProcessPool() as pp:
+                q_out = self._get_q_out(tg, tp, pp, *args, **kwargs)
             while (data := await q_out.get()) is not StopSentinel:
                 yield data
         except ExceptionGroup as eg:
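The feed-forward loop in _get_q_out chains each stage's output queue into the next stage's input queue until a sentinel signals exhaustion. Below is a minimal, self-contained sketch of that pattern; the names (produce, transform, STOP) are illustrative stand-ins, not pyper's actual AsyncProducer/AsyncProducerConsumer classes, and asyncio.TaskGroup requires Python 3.11+.

    import asyncio

    STOP = object()  # stand-in for pyper's StopSentinel

    async def produce(q_out: asyncio.Queue, data):
        # First stage: feed raw inputs into the first queue.
        for item in data:
            await q_out.put(item)
        await q_out.put(STOP)

    async def transform(q_in: asyncio.Queue, q_out: asyncio.Queue, func):
        # Middle stage: consume the previous queue, feed the next one.
        while (item := await q_in.get()) is not STOP:
            await q_out.put(func(item))
        await q_out.put(STOP)

    async def main():
        q1, q2 = asyncio.Queue(), asyncio.Queue()
        async with asyncio.TaskGroup() as tg:
            tg.create_task(produce(q1, range(5)))
            tg.create_task(transform(q1, q2, lambda x: x * 2))
            while (data := await q2.get()) is not STOP:
                print(data)  # 0, 2, 4, 6, 8

    asyncio.run(main())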

src/pyper/_core/async_helper/stage.py

Lines changed: 15 additions & 2 deletions

@@ -13,7 +13,13 @@
 
 
 class AsyncProducer:
-    def __init__(self, task: Task, tg: asyncio.TaskGroup, tp: ThreadPool, n_consumers: int):
+    def __init__(
+            self,
+            task: Task,
+            tg: asyncio.TaskGroup,
+            tp: ThreadPool,
+            pp,
+            n_consumers: int):
         self.task = ascynchronize(task, tp)
         if task.concurrency > 1:
             raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have concurrency greater than 1")

@@ -36,7 +42,14 @@ def start(self, *args, **kwargs):
 
 
 class AsyncProducerConsumer:
-    def __init__(self, q_in: asyncio.Queue, task: Task, tg: asyncio.TaskGroup, tp: ThreadPool, n_consumers: int):
+    def __init__(
+            self,
+            q_in: asyncio.Queue,
+            task: Task,
+            tg: asyncio.TaskGroup,
+            tp: ThreadPool,
+            pp,
+            n_consumers: int):
         self.q_in = q_in
         self.task = ascynchronize(task, tp)
         self.tg = tg

src/pyper/_core/decorators.py

Lines changed: 63 additions & 63 deletions

@@ -25,84 +25,84 @@ class task:
     """
     @t.overload
    def __new__(
-        cls,
-        func: None = None,
-        /,
-        *,
-        join: bool = False,
-        concurrency: int = 1,
-        throttle: int = 0,
-        daemon: bool = False,
-        bind: _ArgsKwargs = None
-    ) -> t.Type[task]: ...
+            cls,
+            func: None = None,
+            /,
+            *,
+            join: bool = False,
+            concurrency: int = 1,
+            throttle: int = 0,
+            daemon: bool = False,
+            multiprocess: bool = False,
+            bind: _ArgsKwargs = None) -> t.Type[task]: ...
 
     @t.overload
     def __new__(
-        cls,
-        func: t.Callable[_P, t.Awaitable[_R]],
-        /,
-        *,
-        join: bool = False,
-        concurrency: int = 1,
-        throttle: int = 0,
-        daemon: bool = False,
-        bind: _ArgsKwargs = None
-    ) -> AsyncPipeline[_P, _R]: ...
+            cls,
+            func: t.Callable[_P, t.Awaitable[_R]],
+            /,
+            *,
+            join: bool = False,
+            concurrency: int = 1,
+            throttle: int = 0,
+            daemon: bool = False,
+            multiprocess: bool = False,
+            bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
 
     @t.overload
     def __new__(
-        cls,
-        func: t.Callable[_P, t.AsyncGenerator[_R]],
-        /,
-        *,
-        join: bool = False,
-        concurrency: int = 1,
-        throttle: int = 0,
-        daemon: bool = False,
-        bind: _ArgsKwargs = None
-    ) -> AsyncPipeline[_P, _R]: ...
-
+            cls,
+            func: t.Callable[_P, t.AsyncGenerator[_R]],
+            /,
+            *,
+            join: bool = False,
+            concurrency: int = 1,
+            throttle: int = 0,
+            daemon: bool = False,
+            multiprocess: bool = False,
+            bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
+
     @t.overload
     def __new__(
-        cls,
-        func: t.Callable[_P, t.Generator[_R]],
-        /,
-        *,
-        join: bool = False,
-        concurrency: int = 1,
-        throttle: int = 0,
-        daemon: bool = False,
-        bind: _ArgsKwargs = None
-    ) -> Pipeline[_P, _R]: ...
+            cls,
+            func: t.Callable[_P, t.Generator[_R]],
+            /,
+            *,
+            join: bool = False,
+            concurrency: int = 1,
+            throttle: int = 0,
+            daemon: bool = False,
+            multiprocess: bool = False,
+            bind: _ArgsKwargs = None) -> Pipeline[_P, _R]: ...
 
     @t.overload
     def __new__(
-        cls,
-        func: t.Callable[_P, _R],
-        /,
-        *,
-        join: bool = False,
-        concurrency: int = 1,
-        throttle: int = 0,
-        daemon: bool = False,
-        bind: _ArgsKwargs = None
-    ) -> Pipeline[_P, _R]: ...
+            cls,
+            func: t.Callable[_P, _R],
+            /,
+            *,
+            join: bool = False,
+            concurrency: int = 1,
+            throttle: int = 0,
+            daemon: bool = False,
+            multiprocess: bool = False,
+            bind: _ArgsKwargs = None) -> Pipeline[_P, _R]: ...
 
     def __new__(
-        cls,
-        func: t.Optional[t.Callable] = None,
-        /,
-        *,
-        join: bool = False,
-        concurrency: int = 1,
-        throttle: int = 0,
-        daemon: bool = False,
-        bind: _ArgsKwargs = None
-    ):
+            cls,
+            func: t.Optional[t.Callable] = None,
+            /,
+            *,
+            join: bool = False,
+            concurrency: int = 1,
+            throttle: int = 0,
+            daemon: bool = False,
+            multiprocess: bool = False,
+            bind: _ArgsKwargs = None):
         # Classic decorator trick: @task() means func is None, @task without parentheses means func is passed.
         if func is None:
-            return functools.partial(cls, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, bind=bind)
-        return Pipeline([Task(func=func, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, bind=bind)])
+            return functools.partial(cls, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, multiprocess=multiprocess, bind=bind)
+        return Pipeline([Task(func=func, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, multiprocess=multiprocess, bind=bind)])
 
     @staticmethod
     def bind(*args, **kwargs) -> _ArgsKwargs:
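These overloads thread the new multiprocess flag through the "classic decorator trick" noted in the body: @task receives the function directly, while @task(...) is called with func=None and returns a partial that is applied to the function next. A hedged, standalone illustration of that trick (a toy task_like function, not pyper's actual task class):

    import functools

    def task_like(func=None, /, *, multiprocess=False):
        if func is None:
            # Called as @task_like(multiprocess=True): defer until the function arrives.
            return functools.partial(task_like, multiprocess=multiprocess)
        # Called as @task_like: func was passed directly.
        func.multiprocess = multiprocess  # illustrative tagging only
        return func

    @task_like
    def double(x):
        return x * 2

    @task_like(multiprocess=True)
    def heavy(x):
        return x ** x

    print(double.multiprocess, heavy.multiprocess)  # False True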

src/pyper/_core/task.py

Lines changed: 23 additions & 9 deletions

@@ -8,17 +8,26 @@
 class Task:
     """The representation of a function within a Pipeline."""
 
-    __slots__ = "is_gen", "is_async", "func", "join", "concurrency", "throttle", "daemon"
+    __slots__ = (
+        "is_gen",
+        "is_async",
+        "func",
+        "join",
+        "concurrency",
+        "throttle",
+        "daemon",
+        "multiprocess"
+    )
 
     def __init__(
-        self,
-        func: Callable,
-        join: bool = False,
-        concurrency: int = 1,
-        throttle: int = 0,
-        daemon: bool = False,
-        bind: Optional[Tuple[Tuple, Dict]] = None
-    ):
+            self,
+            func: Callable,
+            join: bool = False,
+            concurrency: int = 1,
+            throttle: int = 0,
+            daemon: bool = False,
+            multiprocess: bool = False,
+            bind: Optional[Tuple[Tuple, Dict]] = None):
         if not isinstance(concurrency, int):
             raise TypeError("concurrency must be an integer")
         if concurrency < 1:

@@ -29,6 +38,8 @@ def __init__(
             raise ValueError("throttle cannot be less than 0")
         if not callable(func):
             raise TypeError("A task must be a callable object")
+        if daemon and multiprocess:
+            raise ValueError("daemon and multiprocess cannot both be True")
 
         self.is_gen = inspect.isgeneratorfunction(func) \
             or inspect.isasyncgenfunction(func) \

@@ -41,9 +52,12 @@ def __init__(
 
         if self.is_async and daemon:
             raise ValueError("daemon cannot be True for an async task")
+        if self.is_async and multiprocess:
+            raise ValueError("multiprocess cannot be True for an async task")
 
         self.func = func if bind is None else functools.partial(func, *bind[0], **bind[1])
         self.join = join
         self.concurrency = concurrency
         self.throttle = throttle
         self.daemon = daemon
+        self.multiprocess = multiprocess
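The new validation makes multiprocess mutually exclusive with both daemon and async callables. An illustrative check of the same rules in isolation (assumed behavior, mirroring the raises in the diff above; check_flags is a hypothetical helper, not pyper API):

    import inspect

    def check_flags(func, *, daemon=False, multiprocess=False):
        # Mirrors Task.__init__'s new constraints.
        if daemon and multiprocess:
            raise ValueError("daemon and multiprocess cannot both be True")
        if inspect.iscoroutinefunction(func) and multiprocess:
            raise ValueError("multiprocess cannot be True for an async task")

    def sync_work(x): return x + 1
    async def async_work(x): return x + 1

    check_flags(sync_work, multiprocess=True)       # ok
    try:
        check_flags(async_work, multiprocess=True)  # raises
    except ValueError as e:
        print(e)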

src/pyper/_core/util/asynchronize.py

Lines changed: 19 additions & 8 deletions

@@ -8,20 +8,32 @@
 from ..task import Task
 
 
-def ascynchronize(task: Task, tp: ThreadPool) -> Task:
-    """Unifies async and sync functions within an `AsyncPipeline`.
-
-    Synchronous functions within a `ThreadPool` are transformed into asynchronous functions via `asyncio.wrap_future`.
-    Synchronous generators are transformed into asynchronous generators.
+def ascynchronize(task: Task, tp: ThreadPool, pp) -> Task:
+    """Unifies async and sync tasks as awaitable futures.
+    1. If the task is async already, return it.
+    2. Synchronous generators are transformed into asynchronous generators.
+    3. Multiprocessed synchronous functions within a `ProcessPool` are wrapped in `asyncio.wrap_future`.
+    4. Threaded synchronous functions within a `ThreadPool` are wrapped in `asyncio.wrap_future`.
     """
     if task.is_async:
-        return task
-
+        return task
     if task.is_gen:
         @functools.wraps(task.func)
         async def wrapper(*args, **kwargs):
             for output in task.func(*args, **kwargs):
                 yield output
+    elif task.multiprocess:
+        @functools.wraps(task.func)
+        async def wrapper(*args, **kwargs):
+            future = Future()
+            def target(*args, **kwargs):
+                try:
+                    result = task.func(*args, **kwargs)
+                    future.set_result(result)
+                except Exception as e:
+                    future.set_exception(e)
+            pp.submit(target, args=args, kwargs=kwargs)
+            return await asyncio.wrap_future(future)
     else:
         @functools.wraps(task.func)
         async def wrapper(*args, **kwargs):

@@ -34,7 +46,6 @@ def target(*args, **kwargs):
                 future.set_exception(e)
         tp.submit(target, args=args, kwargs=kwargs, daemon=task.daemon)
         return await asyncio.wrap_future(future)
-
     return Task(
         func=wrapper,
         join=task.join,
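Both the threaded branch and the new multiprocessed branch rely on the same bridging technique: run the synchronous function on a pool, complete a concurrent.futures.Future with its result or exception, and await it via asyncio.wrap_future. A standard-library sketch of that idea using ProcessPoolExecutor (whose submit signature differs from the custom submit(target, args=..., kwargs=...) interface of pyper's pools shown in the diff):

    import asyncio
    from concurrent.futures import ProcessPoolExecutor

    def cpu_bound(n: int) -> int:
        # Must be a module-level function so it can be pickled to the worker process.
        return sum(i * i for i in range(n))

    async def main():
        with ProcessPoolExecutor() as pool:
            # submit() returns a concurrent.futures.Future;
            # asyncio.wrap_future turns it into an awaitable asyncio future.
            result = await asyncio.wrap_future(pool.submit(cpu_bound, 100_000))
            print(result)

    if __name__ == "__main__":
        asyncio.run(main())

loop.run_in_executor offers the same bridge in a single call; the explicit Future in the commit fits pyper's pools, which expose their own submit interface rather than the concurrent.futures one.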
