Skip to content

Commit 43563c0

Browse files
Fix closeable scoped/borrowed iter (#69)
* explicitly testing closing of scoped_iter and repeated slicing * __aiter__ of a borrowed iterator no longer exposes original iterator * fixed doc typos * asyncstdlib related type annotations * turned on mypy strict checking for asynctools and _core (see #65) Co-authored-by: isra17 <[email protected]>
1 parent af6e23a commit 43563c0

File tree

5 files changed

+81
-35
lines changed

5 files changed

+81
-35
lines changed

asyncstdlib/asynctools.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,34 +27,27 @@ class _BorrowedAsyncIterator(AsyncGenerator[T, S]):
2727
Borrowed async iterator/generator, preventing to ``aclose`` the ``iterable``
2828
"""
2929

30-
# adding special methods such as `__aiter__` as `__slots__` allows to set them
30+
# adding special methods such as `__anext__` as `__slots__` allows to set them
3131
# on the instance: the interpreter expects *descriptors* not methods, and
3232
# `__slots__` are descriptors just like methods.
33-
__slots__ = "__wrapped__", "__aiter__", "__anext__", "asend", "athrow"
33+
__slots__ = "__wrapped__", "__anext__", "asend", "athrow", "_wrapper"
3434

3535
# Type checker does not understand `__slot__` definitions
36-
__aiter__: Callable[[Any], AsyncGenerator[T, S]]
3736
__anext__: Callable[[Any], Awaitable[T]]
3837
asend: Any
3938
athrow: Any
4039

4140
def __init__(self, iterator: Union[AsyncIterator[T], AsyncGenerator[T, S]]):
4241
self.__wrapped__ = iterator
43-
# iterator.__aiter__ is likely to return iterator (e.g. for async def: yield)
44-
# We wrap it in a separate async iterator/generator to hide its __aiter__.
45-
try:
46-
wrapped_iterator: AsyncGenerator[T, S] = self._wrapped_iterator(iterator)
47-
self.__anext__ = iterator.__anext__ # type: ignore
48-
self.__aiter__ = wrapped_iterator.__aiter__ # type: ignore
49-
except (AttributeError, TypeError):
50-
raise TypeError(
51-
"borrowing requires an async iterator "
52-
+ f"with __aiter__ and __anext__ method, got {type(iterator).__name__}"
53-
) from None
54-
self.__anext__ = wrapped_iterator.__anext__ # type: ignore
55-
# Our wrapper cannot pass on asend/athrow without getting much heavier.
56-
# Since interleaving anext/asend/athrow is not allowed, and the wrapper holds
57-
# no internal state other than the iterator, circumventing it should be fine.
42+
# Create an actual async generator wrapper that we can close. Otherwise,
43+
# if we pass on the original iterator methods we cannot disable them if
44+
# anyone has a reference to them.
45+
self._wrapper: AsyncGenerator[T, S] = self._wrapped_iterator(iterator)
46+
# Forward all async iterator/generator methods but __aiter__ and aclose:
47+
# An async *iterator* (e.g. `async def: yield`) must return
48+
# itself from __aiter__. If we do not shadow this then
49+
# running aiter(self).aclose closes the underlying iterator.
50+
self.__anext__ = self._wrapper.__anext__ # type: ignore
5851
if hasattr(iterator, "asend"):
5952
self.asend = iterator.asend # type: ignore
6053
if hasattr(iterator, "athrow"):
@@ -70,11 +63,14 @@ async def _wrapped_iterator(
7063
async for item in iterator:
7164
yield item
7265

73-
def __repr__(self):
66+
def __aiter__(self) -> AsyncGenerator[T, S]:
67+
return self
68+
69+
def __repr__(self) -> str:
7470
return f"<asyncstdlib.borrow of {self.__wrapped__!r} at 0x{(id(self)):x}>"
7571

76-
async def _aclose_wrapper(self):
77-
wrapper_iterator = self.__aiter__()
72+
async def _aclose_wrapper(self) -> None:
73+
wrapper_iterator = self._wrapper
7874
# allow closing the intermediate wrapper
7975
# this prevents a resource warning if the wrapper is GC'd
8076
# the underlying iterator is NOT affected by this
@@ -85,17 +81,17 @@ async def _aclose_wrapper(self):
8581
if hasattr(self, "athrow"):
8682
self.athrow = wrapper_iterator.athrow
8783

88-
def aclose(self):
84+
def aclose(self) -> Awaitable[None]:
8985
return self._aclose_wrapper()
9086

9187

9288
class _ScopedAsyncIterator(_BorrowedAsyncIterator[T, S]):
9389
__slots__ = ()
9490

95-
def __repr__(self):
91+
def __repr__(self) -> str:
9692
return f"<asyncstdlib.scoped_iter of {self.__wrapped__!r} at 0x{(id(self)):x}>"
9793

98-
async def aclose(self):
94+
async def aclose(self) -> None:
9995
pass
10096

10197

@@ -119,16 +115,16 @@ async def __aenter__(self) -> AsyncIterator[T]:
119115
borrowed_iter = self._borrowed_iter = _ScopedAsyncIterator(self._iterator)
120116
return borrowed_iter
121117

122-
async def __aexit__(self, exc_type, exc_val, exc_tb):
118+
async def __aexit__(self, *args: Any) -> bool:
123119
await self._borrowed_iter._aclose_wrapper() # type: ignore
124120
await self._iterator.aclose() # type: ignore
125121
return False
126122

127-
def __repr__(self):
123+
def __repr__(self) -> str:
128124
return f"<{self.__class__.__name__} of {self._iterator!r} at 0x{(id(self)):x}>"
129125

130126

131-
def borrow(iterator: AsyncIterator[T]) -> _BorrowedAsyncIterator[T, None]:
127+
def borrow(iterator: AsyncIterator[T]) -> AsyncIterator[T]:
132128
"""
133129
Borrow an async iterator, preventing to ``aclose`` it
134130
@@ -146,10 +142,15 @@ def borrow(iterator: AsyncIterator[T]) -> _BorrowedAsyncIterator[T, None]:
146142
.. seealso:: Use :py:func:`~.scoped_iter` to ensure an (async) iterable
147143
is eventually closed and only :term:`borrowed <borrowing>` until then.
148144
"""
145+
if not hasattr(iterator, "__anext__") or not hasattr(iterator, "__aiter__"):
146+
raise TypeError(
147+
"borrowing requires an async iterator "
148+
+ f"with __aiter__ and __anext__ method, got {type(iterator).__name__}"
149+
)
149150
return _BorrowedAsyncIterator(iterator)
150151

151152

152-
def scoped_iter(iterable: AnyIterable[T]):
153+
def scoped_iter(iterable: AnyIterable[T]) -> AsyncContextManager[AsyncIterator[T]]:
153154
"""
154155
Context manager that provides an async iterator for an (async) ``iterable``
155156
@@ -166,9 +167,9 @@ def scoped_iter(iterable: AnyIterable[T]):
166167
async def head_tail(iterable, leading=5, trailing=5):
167168
'''Provide the first ``leading`` and last ``trailing`` items'''
168169
# create async iterator valid for the entire block
169-
async with scoped_iter(iterable) as async_iter:
170+
async with a.scoped_iter(iterable) as async_iter:
170171
# ... safely pass it on without it being closed ...
171-
async for item in a.isclice(async_iter, leading):
172+
async for item in a.islice(async_iter, leading):
172173
yield item
173174
tail = deque(maxlen=trailing)
174175
# ... and use it again in the block
@@ -336,7 +337,7 @@ def sync(function: Callable[..., T]) -> Callable[..., Awaitable[T]]:
336337
...
337338

338339

339-
def sync(function: Callable) -> Callable[..., Awaitable[T]]:
340+
def sync(function: Callable[..., T]) -> Callable[..., Any]:
340341
r"""
341342
Wraps a callable to ensure its result can be ``await``\ ed
342343
@@ -372,10 +373,10 @@ async def main():
372373
return function
373374

374375
@wraps(function)
375-
async def async_wrapped(*args, **kwargs):
376+
async def async_wrapped(*args: Any, **kwargs: Any) -> T:
376377
result = function(*args, **kwargs)
377378
if isinstance(result, Awaitable):
378-
return await result
379+
return await result # type: ignore
379380
return result
380381

381382
return async_wrapped

asyncstdlib/contextlib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,11 @@ def __init__(self, enter_result=None):
179179
self.enter_result = enter_result
180180

181181
@overload
182-
def __aenter__(self: "NullContext[None]") -> None:
182+
async def __aenter__(self: "NullContext[None]") -> None:
183183
...
184184

185185
@overload
186-
def __aenter__(self: "NullContext[T]") -> T:
186+
async def __aenter__(self: "NullContext[T]") -> T:
187187
...
188188

189189
async def __aenter__(self):

pyproject.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,21 @@ no_implicit_optional = true
4747
warn_redundant_casts = true
4848
warn_unused_ignores = true
4949
warn_unreachable = true
50+
51+
[[tool.mypy.overrides]]
52+
module = [
53+
"asyncstdlib.asynctools",
54+
"asyncstdlib._core",
55+
]
56+
disallow_any_generics = true
57+
disallow_subclassing_any = true
58+
disallow_untyped_calls = true
59+
disallow_untyped_defs = true
60+
disallow_incomplete_defs = true
61+
check_untyped_defs = true
62+
disallow_untyped_decorators = true
63+
no_implicit_optional = true
64+
warn_unused_ignores = true
65+
warn_return_any = true
66+
no_implicit_reexport = true
67+
strict_equality = true

unittests/test_asynctools.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,23 @@
88
CLOSED = "closed"
99

1010

11+
@sync
12+
async def test_scoped_iter_iterclose():
13+
"""A `scoped_iter` cannot be closed via its public interface"""
14+
async_iterable, iterable = asyncify(range(10)), iter(range(10))
15+
async with a.scoped_iter(async_iterable) as a1:
16+
assert await a.anext(a1) == next(iterable)
17+
# closing a scoped iterator is a no-op
18+
await a1.aclose()
19+
assert await a.anext(a1) == next(iterable)
20+
# explicitly test #68
21+
await a.iter(a1).aclose()
22+
assert await a.anext(a1) == next(iterable)
23+
assert await a.list(async_iterable) == list(iterable)
24+
assert await a.anext(a1, CLOSED) == CLOSED
25+
assert await a.anext(async_iterable, CLOSED) == CLOSED
26+
27+
1128
@sync
1229
async def test_nested_lifetime():
1330
async_iterable, iterable = asyncify(range(10)), iter(range(10))

unittests/test_itertools.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,16 @@ async def test_islice_exact(slicing):
170170
)
171171

172172

173+
@sync
174+
async def test_islice_scoped_iter():
175+
"""multiple `isclice` on borrowed iterator are consecutive"""
176+
async_iterable, iterable = asyncify(range(10)), iter(range(10))
177+
async with a.scoped_iter(async_iterable) as a1:
178+
assert await a.list(a.islice(a1, 5)) == list(itertools.islice(iterable, 5))
179+
assert await a.list(a.islice(a1, 5)) == list(itertools.islice(iterable, 5))
180+
assert await a.list(a.islice(a1, 5)) == list(itertools.islice(iterable, 5))
181+
182+
173183
starmap_cases = [
174184
(lambda x, y: x + y, [(1, 2), (3, 4)]),
175185
(lambda *args: sum(args), [range(i) for i in range(1, 10)]),

0 commit comments

Comments
 (0)