@@ -27,34 +27,27 @@ class _BorrowedAsyncIterator(AsyncGenerator[T, S]):
2727 Borrowed async iterator/generator, preventing to ``aclose`` the ``iterable``
2828 """
2929
30- # adding special methods such as `__aiter__ ` as `__slots__` allows to set them
30+ # adding special methods such as `__anext__ ` as `__slots__` allows to set them
3131 # on the instance: the interpreter expects *descriptors* not methods, and
3232 # `__slots__` are descriptors just like methods.
33- __slots__ = "__wrapped__" , "__aiter__ " , "__anext__ " , "asend " , "athrow "
33+ __slots__ = "__wrapped__" , "__anext__ " , "asend " , "athrow " , "_wrapper "
3434
3535 # Type checker does not understand `__slot__` definitions
36- __aiter__ : Callable [[Any ], AsyncGenerator [T , S ]]
3736 __anext__ : Callable [[Any ], Awaitable [T ]]
3837 asend : Any
3938 athrow : Any
4039
4140 def __init__ (self , iterator : Union [AsyncIterator [T ], AsyncGenerator [T , S ]]):
4241 self .__wrapped__ = iterator
43- # iterator.__aiter__ is likely to return iterator (e.g. for async def: yield)
44- # We wrap it in a separate async iterator/generator to hide its __aiter__.
45- try :
46- wrapped_iterator : AsyncGenerator [T , S ] = self ._wrapped_iterator (iterator )
47- self .__anext__ = iterator .__anext__ # type: ignore
48- self .__aiter__ = wrapped_iterator .__aiter__ # type: ignore
49- except (AttributeError , TypeError ):
50- raise TypeError (
51- "borrowing requires an async iterator "
52- + f"with __aiter__ and __anext__ method, got { type (iterator ).__name__ } "
53- ) from None
54- self .__anext__ = wrapped_iterator .__anext__ # type: ignore
55- # Our wrapper cannot pass on asend/athrow without getting much heavier.
56- # Since interleaving anext/asend/athrow is not allowed, and the wrapper holds
57- # no internal state other than the iterator, circumventing it should be fine.
42+ # Create an actual async generator wrapper that we can close. Otherwise,
43+ # if we pass on the original iterator methods we cannot disable them if
44+ # anyone has a reference to them.
45+ self ._wrapper : AsyncGenerator [T , S ] = self ._wrapped_iterator (iterator )
46+ # Forward all async iterator/generator methods but __aiter__ and aclose:
47+ # An async *iterator* (e.g. `async def: yield`) must return
48+ # itself from __aiter__. If we do not shadow this then
49+ # running aiter(self).aclose closes the underlying iterator.
50+ self .__anext__ = self ._wrapper .__anext__ # type: ignore
5851 if hasattr (iterator , "asend" ):
5952 self .asend = iterator .asend # type: ignore
6053 if hasattr (iterator , "athrow" ):
@@ -70,11 +63,14 @@ async def _wrapped_iterator(
7063 async for item in iterator :
7164 yield item
7265
73- def __repr__ (self ):
66+ def __aiter__ (self ) -> AsyncGenerator [T , S ]:
67+ return self
68+
69+ def __repr__ (self ) -> str :
7470 return f"<asyncstdlib.borrow of { self .__wrapped__ !r} at 0x{ (id (self )):x} >"
7571
76- async def _aclose_wrapper (self ):
77- wrapper_iterator = self .__aiter__ ()
72+ async def _aclose_wrapper (self ) -> None :
73+ wrapper_iterator = self ._wrapper
7874 # allow closing the intermediate wrapper
7975 # this prevents a resource warning if the wrapper is GC'd
8076 # the underlying iterator is NOT affected by this
@@ -85,17 +81,17 @@ async def _aclose_wrapper(self):
8581 if hasattr (self , "athrow" ):
8682 self .athrow = wrapper_iterator .athrow
8783
88- def aclose (self ):
84+ def aclose (self ) -> Awaitable [ None ] :
8985 return self ._aclose_wrapper ()
9086
9187
9288class _ScopedAsyncIterator (_BorrowedAsyncIterator [T , S ]):
9389 __slots__ = ()
9490
95- def __repr__ (self ):
91+ def __repr__ (self ) -> str :
9692 return f"<asyncstdlib.scoped_iter of { self .__wrapped__ !r} at 0x{ (id (self )):x} >"
9793
98- async def aclose (self ):
94+ async def aclose (self ) -> None :
9995 pass
10096
10197
@@ -119,16 +115,16 @@ async def __aenter__(self) -> AsyncIterator[T]:
119115 borrowed_iter = self ._borrowed_iter = _ScopedAsyncIterator (self ._iterator )
120116 return borrowed_iter
121117
122- async def __aexit__ (self , exc_type , exc_val , exc_tb ) :
118+ async def __aexit__ (self , * args : Any ) -> bool :
123119 await self ._borrowed_iter ._aclose_wrapper () # type: ignore
124120 await self ._iterator .aclose () # type: ignore
125121 return False
126122
127- def __repr__ (self ):
123+ def __repr__ (self ) -> str :
128124 return f"<{ self .__class__ .__name__ } of { self ._iterator !r} at 0x{ (id (self )):x} >"
129125
130126
131- def borrow (iterator : AsyncIterator [T ]) -> _BorrowedAsyncIterator [ T , None ]:
127+ def borrow (iterator : AsyncIterator [T ]) -> AsyncIterator [ T ]:
132128 """
133129 Borrow an async iterator, preventing to ``aclose`` it
134130
@@ -146,10 +142,15 @@ def borrow(iterator: AsyncIterator[T]) -> _BorrowedAsyncIterator[T, None]:
146142 .. seealso:: Use :py:func:`~.scoped_iter` to ensure an (async) iterable
147143 is eventually closed and only :term:`borrowed <borrowing>` until then.
148144 """
145+ if not hasattr (iterator , "__anext__" ) or not hasattr (iterator , "__aiter__" ):
146+ raise TypeError (
147+ "borrowing requires an async iterator "
148+ + f"with __aiter__ and __anext__ method, got { type (iterator ).__name__ } "
149+ )
149150 return _BorrowedAsyncIterator (iterator )
150151
151152
152- def scoped_iter (iterable : AnyIterable [T ]):
153+ def scoped_iter (iterable : AnyIterable [T ]) -> AsyncContextManager [ AsyncIterator [ T ]] :
153154 """
154155 Context manager that provides an async iterator for an (async) ``iterable``
155156
@@ -166,9 +167,9 @@ def scoped_iter(iterable: AnyIterable[T]):
166167 async def head_tail(iterable, leading=5, trailing=5):
167168 '''Provide the first ``leading`` and last ``trailing`` items'''
168169 # create async iterator valid for the entire block
169- async with scoped_iter(iterable) as async_iter:
170+ async with a. scoped_iter(iterable) as async_iter:
170171 # ... safely pass it on without it being closed ...
171- async for item in a.isclice (async_iter, leading):
172+ async for item in a.islice (async_iter, leading):
172173 yield item
173174 tail = deque(maxlen=trailing)
174175 # ... and use it again in the block
@@ -336,7 +337,7 @@ def sync(function: Callable[..., T]) -> Callable[..., Awaitable[T]]:
336337 ...
337338
338339
339- def sync (function : Callable ) -> Callable [..., Awaitable [ T ] ]:
340+ def sync (function : Callable [..., T ] ) -> Callable [..., Any ]:
340341 r"""
341342 Wraps a callable to ensure its result can be ``await``\ ed
342343
@@ -372,10 +373,10 @@ async def main():
372373 return function
373374
374375 @wraps (function )
375- async def async_wrapped (* args , ** kwargs ) :
376+ async def async_wrapped (* args : Any , ** kwargs : Any ) -> T :
376377 result = function (* args , ** kwargs )
377378 if isinstance (result , Awaitable ):
378- return await result
379+ return await result # type: ignore
379380 return result
380381
381382 return async_wrapped
0 commit comments