Skip to content

Commit 5a19b36

Browse files
committed
change concurrency param to workers for multiprocessing
1 parent 4c59b9b commit 5a19b36

File tree

5 files changed

+47
-33
lines changed

5 files changed

+47
-33
lines changed

src/pyper/_core/async_helper/stage.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818

1919
class AsyncProducer:
2020
def __init__(self, task: Task, next_task: Task):
21-
if task.concurrency > 1:
22-
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have concurrency greater than 1")
21+
if task.workers > 1:
22+
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have more than 1 worker")
2323
if task.join:
2424
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
2525
self.task = task
2626
self.q_out = asyncio.Queue(maxsize=task.throttle)
2727

28-
self._n_consumers = 1 if next_task is None else next_task.concurrency
28+
self._n_consumers = 1 if next_task is None else next_task.workers
2929
self._enqueue = AsyncEnqueueFactory(self.q_out, self.task)
3030

3131
async def _worker(self, *args, **kwargs):
@@ -42,8 +42,8 @@ class AsyncProducerConsumer:
4242
def __init__(self, q_in: asyncio.Queue, task: Task, next_task: Task):
4343
self.q_out = asyncio.Queue(maxsize=task.throttle)
4444

45-
self._n_workers = task.concurrency
46-
self._n_consumers = 1 if next_task is None else next_task.concurrency
45+
self._n_workers = task.workers
46+
self._n_consumers = 1 if next_task is None else next_task.workers
4747
self._dequeue = AsyncDequeueFactory(q_in, task)
4848
self._enqueue = AsyncEnqueueFactory(self.q_out, task)
4949
self._workers_done = 0

src/pyper/_core/decorators.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,26 @@
1919

2020

2121
class task:
22-
"""Decorator class to initialize a `Pipeline` consisting of one task.
22+
"""Decorator class to initialize a `Pipeline` consisting of one task, from a function or callable.
2323
24-
The behaviour of each task within a Pipeline is determined by the parameters:
25-
* `join`: allows the function to take all previous results as input, instead of single results
26-
* `concurrency`: runs the functions with multiple (async or threaded) workers
27-
* `throttle`: limits the number of results the function is able to produce when all consumers are busy
28-
* `bind`: additional args and kwargs to bind to the function when defining a pipeline
24+
Args:
25+
func (callable): A positional-only param defining the task function
26+
join (bool): Allows the task to take all previous results as input, instead of single results
27+
workers (int): Defines the number of workers to run the task
28+
throttle (int): Limits the number of results the task is able to produce when all consumers are busy
29+
multiprocess (bool): Allows the task to be multiprocessed (cannot be `True` for async tasks)
30+
bind (tuple[args, kwargs]): Additional args and kwargs to bind to the task when defining a pipeline
31+
32+
Returns:
33+
Pipeline: A `Pipeline` instance consisting of one task.
34+
35+
Example:
36+
```python
37+
def f(x: int):
38+
return x + 1
39+
40+
p = task(f, workers=10, multiprocess=True)
41+
```
2942
"""
3043
@t.overload
3144
def __new__(
@@ -34,7 +47,7 @@ def __new__(
3447
/,
3548
*,
3649
join: bool = False,
37-
concurrency: int = 1,
50+
workers: int = 1,
3851
throttle: int = 0,
3952
multiprocess: bool = False,
4053
bind: _ArgsKwargs = None) -> t.Type[task]: ...
@@ -46,7 +59,7 @@ def __new__(
4659
/,
4760
*,
4861
join: bool = False,
49-
concurrency: int = 1,
62+
workers: int = 1,
5063
throttle: int = 0,
5164
multiprocess: bool = False,
5265
bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
@@ -58,7 +71,7 @@ def __new__(
5871
/,
5972
*,
6073
join: bool = False,
61-
concurrency: int = 1,
74+
workers: int = 1,
6275
throttle: int = 0,
6376
multiprocess: bool = False,
6477
bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
@@ -70,7 +83,7 @@ def __new__(
7083
/,
7184
*,
7285
join: bool = False,
73-
concurrency: int = 1,
86+
workers: int = 1,
7487
throttle: int = 0,
7588
multiprocess: bool = False,
7689
bind: _ArgsKwargs = None) -> Pipeline[_P, _R]: ...
@@ -82,7 +95,7 @@ def __new__(
8295
/,
8396
*,
8497
join: bool = False,
85-
concurrency: int = 1,
98+
workers: int = 1,
8699
throttle: int = 0,
87100
multiprocess: bool = False,
88101
bind: _ArgsKwargs = None) -> Pipeline[_P, _R]: ...
@@ -93,14 +106,14 @@ def __new__(
93106
/,
94107
*,
95108
join: bool = False,
96-
concurrency: int = 1,
109+
workers: int = 1,
97110
throttle: int = 0,
98111
multiprocess: bool = False,
99112
bind: _ArgsKwargs = None):
100113
# Classic decorator trick: @task() means func is None, @task without parentheses means func is passed.
101114
if func is None:
102-
return functools.partial(cls, join=join, concurrency=concurrency, throttle=throttle, multiprocess=multiprocess, bind=bind)
103-
return Pipeline([Task(func=func, join=join, concurrency=concurrency, throttle=throttle, multiprocess=multiprocess, bind=bind)])
115+
return functools.partial(cls, join=join, workers=workers, throttle=throttle, multiprocess=multiprocess, bind=bind)
116+
return Pipeline([Task(func=func, join=join, workers=workers, throttle=throttle, multiprocess=multiprocess, bind=bind)])
104117

105118
@staticmethod
106119
def bind(*args, **kwargs) -> _ArgsKwargs:
@@ -113,6 +126,7 @@ def f(x: int, y: int):
113126
114127
p = task(f, bind=task.bind(y=1))
115128
p(x=1)
129+
```
116130
"""
117131
if not args and not kwargs:
118132
return None

src/pyper/_core/sync_helper/stage.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def __init__(
2121
next_task: Task,
2222
q_err: Union[mp.Queue, queue.Queue],
2323
shutdown_event: Union[MpEvent, threading.Event]):
24-
if task.concurrency > 1:
25-
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have concurrency greater than 1")
24+
if task.workers > 1:
25+
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have more than 1 worker")
2626
if task.join:
2727
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
2828
self.q_out = mp.Queue(maxsize=task.throttle) \
@@ -31,8 +31,8 @@ def __init__(
3131

3232
self._q_err = q_err
3333
self._shutdown_event = shutdown_event
34-
self._n_workers = task.concurrency
35-
self._n_consumers = 1 if next_task is None else next_task.concurrency
34+
self._n_workers = task.workers
35+
self._n_consumers = 1 if next_task is None else next_task.workers
3636
self._enqueue = EnqueueFactory(self.q_out, task)
3737

3838
def _worker(self, *args, **kwargs):
@@ -65,8 +65,8 @@ def __init__(
6565

6666
self._q_err = q_err
6767
self._shutdown_event = shutdown_event
68-
self._n_workers = task.concurrency
69-
self._n_consumers = 1 if next_task is None else next_task.concurrency
68+
self._n_workers = task.workers
69+
self._n_consumers = 1 if next_task is None else next_task.workers
7070
self._dequeue = DequeueFactory(q_in, task)
7171
self._enqueue = EnqueueFactory(self.q_out, task)
7272

src/pyper/_core/task.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class Task:
1313
"is_async",
1414
"func",
1515
"join",
16-
"concurrency",
16+
"workers",
1717
"throttle",
1818
"multiprocess"
1919
)
@@ -22,14 +22,14 @@ def __init__(
2222
self,
2323
func: Callable,
2424
join: bool = False,
25-
concurrency: int = 1,
25+
workers: int = 1,
2626
throttle: int = 0,
2727
multiprocess: bool = False,
2828
bind: Optional[Tuple[Tuple, Dict]] = None):
29-
if not isinstance(concurrency, int):
30-
raise TypeError("concurrency must be an integer")
31-
if concurrency < 1:
32-
raise ValueError("concurrency cannot be less than 1")
29+
if not isinstance(workers, int):
30+
raise TypeError("workers must be an integer")
31+
if workers < 1:
32+
raise ValueError("workers cannot be less than 1")
3333
if not isinstance(throttle, int):
3434
raise TypeError("throttle must be an integer")
3535
if throttle < 0:
@@ -51,6 +51,6 @@ def __init__(
5151

5252
self.func = func if bind is None else functools.partial(func, *bind[0], **bind[1])
5353
self.join = join
54-
self.concurrency = concurrency
54+
self.workers = workers
5555
self.throttle = throttle
5656
self.multiprocess = multiprocess

src/pyper/_core/util/asynchronize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ async def wrapper(*args, **kwargs):
3232
return Task(
3333
func=wrapper,
3434
join=task.join,
35-
concurrency=task.concurrency,
35+
workers=task.workers,
3636
throttle=task.throttle
3737
)

0 commit comments

Comments
 (0)