99from types import TracebackType
1010from typing import (
1111 Any , AsyncIterator , Awaitable , Callable , Deque , Generator , NoReturn ,
12- Optional , Type , TypeVar , Union ,
12+ Optional , Type , TypeVar , Union , Generic , ParamSpec ,
1313)
1414from weakref import finalize
1515
1919
2020T = TypeVar ("T" )
2121R = TypeVar ("R" )
22+ P = ParamSpec ("P" )
2223
23- GenType = Generator [T , R , None ]
24+ GenType = Generator [T , None , None ]
2425FuncType = Callable [[], GenType ]
2526
2627
@@ -144,7 +145,7 @@ class IteratorWrapperStatistic(Statistic):
144145 enqueued : int
145146
146147
147- class IteratorWrapper (AsyncIterator , EventLoopMixin ):
148+ class IteratorWrapper (Generic [ P , T ], AsyncIterator , EventLoopMixin ):
148149 __slots__ = (
149150 "__channel" ,
150151 "__close_event" ,
@@ -155,9 +156,11 @@ class IteratorWrapper(AsyncIterator, EventLoopMixin):
155156 ) + EventLoopMixin .__slots__
156157
157158 def __init__ (
158- self , gen_func : FuncType ,
159+ self ,
160+ gen_func : Callable [P , Generator [T , None , None ]],
159161 loop : Optional [asyncio .AbstractEventLoop ] = None ,
160- max_size : int = 0 , executor : Optional [Executor ] = None ,
162+ max_size : int = 0 ,
163+ executor : Optional [Executor ] = None ,
161164 statistic_name : Optional [str ] = None ,
162165 ):
163166
@@ -227,11 +230,9 @@ async def wait_closed(self) -> None:
227230 await asyncio .gather (self .__gen_task , return_exceptions = True )
228231
229232 def _run (self ) -> Any :
230- return self .loop .run_in_executor (
231- self .executor , self ._in_thread ,
232- )
233+ return self .loop .run_in_executor (self .executor , self ._in_thread )
233234
234- def __aiter__ (self ) -> AsyncIterator [Any ]:
235+ def __aiter__ (self ) -> AsyncIterator [T ]:
235236 if not self .loop .is_running ():
236237 raise RuntimeError ("Event loop is not running" )
237238
@@ -242,7 +243,7 @@ def __aiter__(self) -> AsyncIterator[Any]:
242243 self .__gen_task = gen_task
243244 return IteratorProxy (self , self .close )
244245
245- async def __anext__ (self ) -> Awaitable [ T ] :
246+ async def __anext__ (self ) -> T :
246247 try :
247248 item , is_exc = await self .__channel .get ()
248249 except ChannelClosed :
@@ -269,13 +270,13 @@ async def __aexit__(
269270 await self .close ()
270271
271272
272- class IteratorProxy (AsyncIterator ):
273+ class IteratorProxy (Generic [ T ], AsyncIterator ):
273274 def __init__ (
274- self , iterator : AsyncIterator ,
275+ self , iterator : AsyncIterator [ T ] ,
275276 finalizer : Callable [[], Any ],
276277 ):
277278 self .__iterator = iterator
278279 finalize (self , finalizer )
279280
280- def __anext__ (self ) -> Awaitable [Any ]:
281+ def __anext__ (self ) -> Awaitable [T ]:
281282 return self .__iterator .__anext__ ()
0 commit comments