11from __future__ import annotations
22
33import inspect
4- from typing import Callable , List , TYPE_CHECKING
4+ import typing as t
55
66from .async_helper .output import AsyncPipelineOutput
77from .sync_helper .output import PipelineOutput
88
9- if TYPE_CHECKING :
9+ if t . TYPE_CHECKING :
1010 from .task import Task
1111
1212
13- class Pipeline :
13+ _P = t .ParamSpec ('P' )
14+ _R = t .TypeVar ('R' )
15+ _P_Other = t .ParamSpec ("P_Other" )
16+ _R_Other = t .TypeVar ("R_Other" )
17+
18+
19+ class Pipeline (t .Generic [_P , _R ]):
1420 """A sequence of at least 1 Tasks.
1521
1622 Two pipelines can be piped into another via:
@@ -21,59 +27,83 @@ class Pipeline:
2127 ```
2228 """
2329
24- def __new__ (cls , tasks : List [Task ]):
30+ def __new__ (cls , tasks : t . List [Task ]):
2531 if any (task .is_async for task in tasks ):
2632 instance = object .__new__ (AsyncPipeline )
2733 else :
2834 instance = object .__new__ (cls )
2935 instance .__init__ (tasks = tasks )
3036 return instance
3137
32- def __init__ (self , tasks : List [Task ]):
38+ def __init__ (self , tasks : t . List [Task ]):
3339 self .tasks = tasks
3440
35- def __call__ (self , * args , ** kwargs ) :
41+ def __call__ (self , * args : _P . args , ** kwargs : _P . kwargs ) -> t . Generator [ _R ] :
3642 """Return the pipeline output."""
3743 output = PipelineOutput (self )
3844 return output (* args , ** kwargs )
45+
46+ @t .overload
47+ def pipe (self : AsyncPipeline [_P , _R ], other : AsyncPipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
48+
49+ @t .overload
50+ def pipe (self : AsyncPipeline [_P , _R ], other : Pipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
3951
40- def pipe (self , other ) -> Pipeline :
52+ @t .overload
53+ def pipe (self , other : AsyncPipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
54+
55+ @t .overload
56+ def pipe (self , other : Pipeline [_P_Other , _R_Other ]) -> Pipeline [_P , _R_Other ]: ...
57+
58+ def pipe (self , other : Pipeline ):
4159 """Connect two pipelines, returning a new Pipeline."""
4260 if not isinstance (other , Pipeline ):
4361 raise TypeError (f"{ other } of type { type (other )} cannot be piped into a Pipeline" )
4462 return Pipeline (self .tasks + other .tasks )
4563
46- def __or__ (self , other : Pipeline ) -> Pipeline :
47- """Allow the syntax `pipeline1 | pipeline2`."""
64+ @t .overload
65+ def __or__ (self : AsyncPipeline [_P , _R ], other : AsyncPipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
66+
67+ @t .overload
68+ def __or__ (self : AsyncPipeline [_P , _R ], other : Pipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
69+
70+ @t .overload
71+ def __or__ (self , other : AsyncPipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
72+
73+ @t .overload
74+ def __or__ (self , other : Pipeline [_P_Other , _R_Other ]) -> Pipeline [_P , _R_Other ]: ...
75+
76+ def __or__ (self , other : Pipeline ):
77+ """Connect two pipelines, returning a new Pipeline."""
4878 return self .pipe (other )
4979
50- def consume (self , other : Callable ) -> Callable :
80+ def consume (self , other : t . Callable [..., _R_Other ] ) -> t . Callable [ _P , _R_Other ] :
5181 """Connect the pipeline to a consumer function (a callable that takes the pipeline output as input)."""
5282 if callable (other ):
53- def consumer (* args , ** kwargs ) :
83+ def consumer (* args : _P . args , ** kwargs : _P . kwargs ) -> _R_Other :
5484 return other (self (* args , ** kwargs ))
5585 return consumer
5686 raise TypeError (f"{ other } must be a callable that takes a generator" )
5787
58- def __gt__ (self , other : Callable ) -> Callable :
59- """Allow the syntax ` pipeline > consumer` ."""
88+ def __gt__ (self , other : t . Callable [..., _R_Other ] ) -> t . Callable [ _P , _R_Other ] :
89+ """Connect the pipeline to a consumer function (a callable that takes the pipeline output as input) ."""
6090 return self .consume (other )
6191
6292 def __repr__ (self ):
6393 return f"{ self .__class__ .__name__ } { [task .func for task in self .tasks ]} "
6494
6595
66- class AsyncPipeline (Pipeline ):
67- def __call__ (self , * args , ** kwargs ) :
96+ class AsyncPipeline (Pipeline [ _P , _R ] ):
97+ def __call__ (self , * args : _P . args , ** kwargs : _P . kwargs ) -> t . AsyncGenerator [ _R ] :
6898 """Return the pipeline output."""
6999 output = AsyncPipelineOutput (self )
70100 return output (* args , ** kwargs )
71-
72- def consume (self , other : Callable ) -> Callable :
101+
102+ def consume (self , other : t . Callable [..., _R_Other ] ) -> t . Callable [ _P , _R_Other ] :
73103 """Connect the pipeline to a consumer function (a callable that takes the pipeline output as input)."""
74104 if callable (other ) and \
75105 (inspect .iscoroutinefunction (other ) or inspect .iscoroutinefunction (other .__call__ )):
76- async def consumer (* args , ** kwargs ) :
106+ async def consumer (* args : _P . args , ** kwargs : _P . kwargs ) -> _R_Other :
77107 return await other (self (* args , ** kwargs ))
78108 return consumer
79109 raise TypeError (f"{ other } must be an async callable that takes an async generator" )
0 commit comments