1515
1616_P = ParamSpec ('P' )
1717_R = t .TypeVar ('R' )
18- _ArgsKwargs : t .TypeAlias = t .Optional [t .Tuple [t .Tuple [t .Any ], t .Dict [str , t .Any ]]]
18+ _Default = t .TypeVar ('T' , bound = t .NoReturn ) # Matches to no type hints
19+ ArgsKwargs : t .TypeAlias = t .Optional [t .Tuple [t .Tuple [t .Any ], t .Dict [str , t .Any ]]]
1920
2021
2122class task :
@@ -28,19 +29,57 @@ class task:
2829 workers (int): Defines the number of workers to run the task
2930 throttle (int): Limits the number of results the task is able to produce when all consumers are busy
3031 multiprocess (bool): Allows the task to be multiprocessed (cannot be `True` for async tasks)
31- bind (tuple[args, kwargs ]): Additional args and kwargs to bind to the task when defining a pipeline
32+ bind (tuple[tuple, dict ]): Additional args and kwargs to bind to the task when defining a pipeline
3233
3334 Returns:
34- Pipeline: A `Pipeline` instance consisting of one task.
35+ A `Pipeline` instance consisting of one task.
3536
36- Example:
37- ```python
38- def f(x: int):
39- return x + 1
37+ Examples:
38+ ```python
39+ def spam(x: int):
40+ return x + 1
41+
42+ p = task(spam)
43+
44+ def ham(x: int):
45+ return [x, x + 1, x + 2]
4046
41- p = task(f, workers=10, multiprocess=True)
42- ```
47+ p = task(ham, branch=True, workers=10)
48+
49+ async def eggs(x: int):
50+ yield x
51+ yield x + 1
52+ yield x + 2
53+
54+ p = task(eggs, branch=True, throttle=1)
55+ ```
4356 """
57+ @t .overload
58+ def __new__ (
59+ cls ,
60+ func : t .Callable [_P , _Default ],
61+ / ,
62+ * ,
63+ branch : bool = False ,
64+ join : bool = False ,
65+ workers : int = 1 ,
66+ throttle : int = 0 ,
67+ multiprocess : bool = False ,
68+ bind : ArgsKwargs = None ) -> Pipeline [_P , _Default ]: ...
69+
70+ @t .overload
71+ def __new__ (
72+ cls ,
73+ func : None = None ,
74+ / ,
75+ * ,
76+ branch : t .Literal [True ],
77+ join : bool = False ,
78+ workers : int = 1 ,
79+ throttle : int = 0 ,
80+ multiprocess : bool = False ,
81+ bind : ArgsKwargs = None ) -> t .Type [_branched_partial_task ]: ...
82+
4483 @t .overload
4584 def __new__ (
4685 cls ,
@@ -52,20 +91,20 @@ def __new__(
5291 workers : int = 1 ,
5392 throttle : int = 0 ,
5493 multiprocess : bool = False ,
55- bind : _ArgsKwargs = None ) -> t .Type [task ]: ...
94+ bind : ArgsKwargs = None ) -> t .Type [task ]: ...
5695
5796 @t .overload
5897 def __new__ (
5998 cls ,
6099 func : t .Callable [_P , t .Union [t .Awaitable [t .Iterable [_R ]], t .AsyncGenerator [_R ]]],
61100 / ,
62101 * ,
63- branch : True ,
102+ branch : t . Literal [ True ] ,
64103 join : bool = False ,
65104 workers : int = 1 ,
66105 throttle : int = 0 ,
67106 multiprocess : bool = False ,
68- bind : _ArgsKwargs = None ) -> AsyncPipeline [_P , _R ]: ...
107+ bind : ArgsKwargs = None ) -> AsyncPipeline [_P , _R ]: ...
69108
70109 @t .overload
71110 def __new__ (
@@ -78,20 +117,20 @@ def __new__(
78117 workers : int = 1 ,
79118 throttle : int = 0 ,
80119 multiprocess : bool = False ,
81- bind : _ArgsKwargs = None ) -> AsyncPipeline [_P , _R ]: ...
120+ bind : ArgsKwargs = None ) -> AsyncPipeline [_P , _R ]: ...
82121
83122 @t .overload
84123 def __new__ (
85124 cls ,
86125 func : t .Callable [_P , t .Iterable [_R ]],
87126 / ,
88127 * ,
89- branch : True ,
128+ branch : t . Literal [ True ] ,
90129 join : bool = False ,
91130 workers : int = 1 ,
92131 throttle : int = 0 ,
93132 multiprocess : bool = False ,
94- bind : _ArgsKwargs = None ) -> Pipeline [_P , _R ]: ...
133+ bind : ArgsKwargs = None ) -> Pipeline [_P , _R ]: ...
95134
96135 @t .overload
97136 def __new__ (
@@ -104,7 +143,7 @@ def __new__(
104143 workers : int = 1 ,
105144 throttle : int = 0 ,
106145 multiprocess : bool = False ,
107- bind : _ArgsKwargs = None ) -> Pipeline [_P , _R ]: ...
146+ bind : ArgsKwargs = None ) -> Pipeline [_P , _R ]: ...
108147
109148 def __new__ (
110149 cls ,
@@ -116,25 +155,44 @@ def __new__(
116155 workers : int = 1 ,
117156 throttle : int = 0 ,
118157 multiprocess : bool = False ,
119- bind : _ArgsKwargs = None ):
158+ bind : ArgsKwargs = None ):
120159 # Classic decorator trick: @task() means func is None, @task without parentheses means func is passed.
121160 if func is None :
122161 return functools .partial (cls , branch = branch , join = join , workers = workers , throttle = throttle , multiprocess = multiprocess , bind = bind )
123162 return Pipeline ([Task (func = func , branch = branch , join = join , workers = workers , throttle = throttle , multiprocess = multiprocess , bind = bind )])
124163
125164 @staticmethod
126- def bind (* args , ** kwargs ) -> _ArgsKwargs :
165+ def bind (* args , ** kwargs ) -> ArgsKwargs :
127166 """Bind additional `args` and `kwargs` to a task.
128167
129168 Example:
130- ```python
131- def f(x: int, y: int):
132- return x + y
169+ ```python
170+ def f(x: int, y: int):
171+ return x + y
133172
134- p = task(f, bind=task.bind(y=1))
135- p(x=1)
136- ```
173+ p = task(f, bind=task.bind(y=1))
174+ p(x=1)
175+ ```
137176 """
138177 if not args and not kwargs :
139178 return None
140179 return args , kwargs
180+
181+
182+ class _branched_partial_task :
183+ @t .overload
184+ def __new__ (cls , func : t .Callable [_P , _Default ]) -> Pipeline [_P , _Default ]: ...
185+
186+ @t .overload
187+ def __new__ (
188+ cls ,
189+ func : t .Callable [_P , t .Union [t .Awaitable [t .Iterable [_R ]], t .AsyncGenerator [_R ]]]) -> AsyncPipeline [_P , _R ]: ...
190+
191+ @t .overload
192+ def __new__ (cls , func : t .Callable [_P , t .Iterable [_R ]]) -> Pipeline [_P , _R ]: ...
193+
194+ @t .overload
195+ def __new__ (cls , func : t .Callable [_P , _R ]) -> Pipeline [_P , t .Any ]: ...
196+
197+ def __new__ (cls ):
198+ raise NotImplementedError
0 commit comments