Skip to content

Commit 5b689c0

Browse files
committed
fix type hints for branched tasks and unhinted functions
1 parent 73098b7 commit 5b689c0

File tree

1 file changed

+82
-24
lines changed

1 file changed

+82
-24
lines changed

src/pyper/_core/decorators.py

Lines changed: 82 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
_P = ParamSpec('P')
1717
_R = t.TypeVar('R')
18-
_ArgsKwargs: t.TypeAlias = t.Optional[t.Tuple[t.Tuple[t.Any], t.Dict[str, t.Any]]]
18+
_Default = t.TypeVar('T', bound=t.NoReturn) # Matches to no type hints
19+
ArgsKwargs: t.TypeAlias = t.Optional[t.Tuple[t.Tuple[t.Any], t.Dict[str, t.Any]]]
1920

2021

2122
class task:
@@ -28,19 +29,57 @@ class task:
2829
workers (int): Defines the number of workers to run the task
2930
throttle (int): Limits the number of results the task is able to produce when all consumers are busy
3031
multiprocess (bool): Allows the task to be multiprocessed (cannot be `True` for async tasks)
31-
bind (tuple[args, kwargs]): Additional args and kwargs to bind to the task when defining a pipeline
32+
bind (tuple[tuple, dict]): Additional args and kwargs to bind to the task when defining a pipeline
3233
3334
Returns:
34-
Pipeline: A `Pipeline` instance consisting of one task.
35+
A `Pipeline` instance consisting of one task.
3536
36-
Example:
37-
```python
38-
def f(x: int):
39-
return x + 1
37+
Examples:
38+
```python
39+
def spam(x: int):
40+
return x + 1
41+
42+
p = task(spam)
43+
44+
def ham(x: int):
45+
return [x, x + 1, x + 2]
4046
41-
p = task(f, workers=10, multiprocess=True)
42-
```
47+
p = task(ham, branch=True, workers=10)
48+
49+
async def eggs(x: int):
50+
yield x
51+
yield x + 1
52+
yield x + 2
53+
54+
p = task(eggs, branch=True, throttle=1)
55+
```
4356
"""
57+
@t.overload
58+
def __new__(
59+
cls,
60+
func: t.Callable[_P, _Default],
61+
/,
62+
*,
63+
branch: bool = False,
64+
join: bool = False,
65+
workers: int = 1,
66+
throttle: int = 0,
67+
multiprocess: bool = False,
68+
bind: ArgsKwargs = None) -> Pipeline[_P, _Default]: ...
69+
70+
@t.overload
71+
def __new__(
72+
cls,
73+
func: None = None,
74+
/,
75+
*,
76+
branch: t.Literal[True],
77+
join: bool = False,
78+
workers: int = 1,
79+
throttle: int = 0,
80+
multiprocess: bool = False,
81+
bind: ArgsKwargs = None) -> t.Type[_branched_partial_task]: ...
82+
4483
@t.overload
4584
def __new__(
4685
cls,
@@ -52,20 +91,20 @@ def __new__(
5291
workers: int = 1,
5392
throttle: int = 0,
5493
multiprocess: bool = False,
55-
bind: _ArgsKwargs = None) -> t.Type[task]: ...
94+
bind: ArgsKwargs = None) -> t.Type[task]: ...
5695

5796
@t.overload
5897
def __new__(
5998
cls,
6099
func: t.Callable[_P, t.Union[t.Awaitable[t.Iterable[_R]], t.AsyncGenerator[_R]]],
61100
/,
62101
*,
63-
branch: True,
102+
branch: t.Literal[True],
64103
join: bool = False,
65104
workers: int = 1,
66105
throttle: int = 0,
67106
multiprocess: bool = False,
68-
bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
107+
bind: ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
69108

70109
@t.overload
71110
def __new__(
@@ -78,20 +117,20 @@ def __new__(
78117
workers: int = 1,
79118
throttle: int = 0,
80119
multiprocess: bool = False,
81-
bind: _ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
120+
bind: ArgsKwargs = None) -> AsyncPipeline[_P, _R]: ...
82121

83122
@t.overload
84123
def __new__(
85124
cls,
86125
func: t.Callable[_P, t.Iterable[_R]],
87126
/,
88127
*,
89-
branch: True,
128+
branch: t.Literal[True],
90129
join: bool = False,
91130
workers: int = 1,
92131
throttle: int = 0,
93132
multiprocess: bool = False,
94-
bind: _ArgsKwargs = None) -> Pipeline[_P, _R]: ...
133+
bind: ArgsKwargs = None) -> Pipeline[_P, _R]: ...
95134

96135
@t.overload
97136
def __new__(
@@ -104,7 +143,7 @@ def __new__(
104143
workers: int = 1,
105144
throttle: int = 0,
106145
multiprocess: bool = False,
107-
bind: _ArgsKwargs = None) -> Pipeline[_P, _R]: ...
146+
bind: ArgsKwargs = None) -> Pipeline[_P, _R]: ...
108147

109148
def __new__(
110149
cls,
@@ -116,25 +155,44 @@ def __new__(
116155
workers: int = 1,
117156
throttle: int = 0,
118157
multiprocess: bool = False,
119-
bind: _ArgsKwargs = None):
158+
bind: ArgsKwargs = None):
120159
# Classic decorator trick: @task() means func is None, @task without parentheses means func is passed.
121160
if func is None:
122161
return functools.partial(cls, branch=branch, join=join, workers=workers, throttle=throttle, multiprocess=multiprocess, bind=bind)
123162
return Pipeline([Task(func=func, branch=branch, join=join, workers=workers, throttle=throttle, multiprocess=multiprocess, bind=bind)])
124163

125164
@staticmethod
126-
def bind(*args, **kwargs) -> _ArgsKwargs:
165+
def bind(*args, **kwargs) -> ArgsKwargs:
127166
"""Bind additional `args` and `kwargs` to a task.
128167
129168
Example:
130-
```python
131-
def f(x: int, y: int):
132-
return x + y
169+
```python
170+
def f(x: int, y: int):
171+
return x + y
133172
134-
p = task(f, bind=task.bind(y=1))
135-
p(x=1)
136-
```
173+
p = task(f, bind=task.bind(y=1))
174+
p(x=1)
175+
```
137176
"""
138177
if not args and not kwargs:
139178
return None
140179
return args, kwargs
180+
181+
182+
class _branched_partial_task:
183+
@t.overload
184+
def __new__(cls, func: t.Callable[_P, _Default]) -> Pipeline[_P, _Default]: ...
185+
186+
@t.overload
187+
def __new__(
188+
cls,
189+
func: t.Callable[_P, t.Union[t.Awaitable[t.Iterable[_R]], t.AsyncGenerator[_R]]]) -> AsyncPipeline[_P, _R]: ...
190+
191+
@t.overload
192+
def __new__(cls, func: t.Callable[_P, t.Iterable[_R]]) -> Pipeline[_P, _R]: ...
193+
194+
@t.overload
195+
def __new__(cls, func: t.Callable[_P, _R]) -> Pipeline[_P, t.Any]: ...
196+
197+
def __new__(cls):
198+
raise NotImplementedError

0 commit comments

Comments
 (0)