Skip to content

Commit b966a95

Browse files
Complete itertools.chain interface (#108)
* added chain.aclose method for cleanup (closes #107) * do not reconstruct chain implementation again and again * use same implementation for chain and chain.from_iterable * chain owns explicitly passed iterables
1 parent d20a48b commit b966a95

File tree

2 files changed

+86
-14
lines changed

2 files changed

+86
-14
lines changed

asyncstdlib/itertools.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -148,33 +148,54 @@ class chain(AsyncIterator[T]):
148148
The resulting iterator consecutively iterates over and yields all values from
149149
each of the ``iterables``. This is similar to converting all ``iterables`` to
150150
sequences and concatenating them, but lazily exhausts each iterable.
151+
152+
The ``chain`` assumes ownership of its ``iterables`` and closes them reliably
153+
when the ``chain`` is closed. Pass the ``iterables`` via a :py:class:`tuple` to
154+
``chain.from_iterable`` to avoid closing all iterables but those already processed.
151155
"""
152156

153-
__slots__ = ("_impl",)
157+
__slots__ = ("_iterator", "_owned_iterators")
154158

155-
def __init__(self, *iterables: AnyIterable[T]):
156-
async def impl() -> AsyncIterator[T]:
157-
for iterable in iterables:
159+
@staticmethod
160+
async def _chain_iterator(
161+
any_iterables: AnyIterable[AnyIterable[T]],
162+
) -> AsyncGenerator[T, None]:
163+
async with ScopedIter(any_iterables) as iterables:
164+
async for iterable in iterables:
158165
async with ScopedIter(iterable) as iterator:
159166
async for item in iterator:
160167
yield item
161168

162-
self._impl = impl()
169+
def __init__(
170+
self, *iterables: AnyIterable[T], _iterables: AnyIterable[AnyIterable[T]] = ()
171+
):
172+
self._iterator = self._chain_iterator(iterables or _iterables)
173+
self._owned_iterators = (
174+
iterable
175+
for iterable in iterables
176+
if isinstance(iterable, AsyncIterator) and hasattr(iterable, "aclose")
177+
)
163178

164-
@staticmethod
165-
async def from_iterable(iterable: AnyIterable[AnyIterable[T]]) -> AsyncIterator[T]:
179+
@classmethod
180+
def from_iterable(cls, iterable: AnyIterable[AnyIterable[T]]) -> "chain[T]":
166181
"""
167182
Alternate constructor for :py:func:`~.chain` that lazily exhausts
168-
iterables as well
183+
the ``iterable`` of iterables as well
184+
185+
This is suitable for chaining iterables from a lazy or infinite ``iterable``.
186+
In turn, closing the ``chain`` only closes those iterables
187+
already fetched from ``iterable``.
169188
"""
170-
async with ScopedIter(iterable) as iterables:
171-
async for sub_iterable in iterables:
172-
async with ScopedIter(sub_iterable) as iterator:
173-
async for item in iterator:
174-
yield item
189+
return cls(_iterables=iterable)
175190

176191
def __anext__(self) -> Awaitable[T]:
177-
return self._impl.__anext__()
192+
return self._iterator.__anext__()
193+
194+
async def aclose(self) -> None:
195+
for iterable in self._owned_iterators:
196+
if hasattr(iterable, "aclose"):
197+
await iterable.aclose()
198+
await self._iterator.aclose()
178199

179200

180201
async def compress(

unittests/test_itertools.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,57 @@ async def test_chain(iterables):
8787
)
8888

8989

90+
class ACloseFacade:
91+
"""Wrapper to check if an iterator has been closed"""
92+
93+
def __init__(self, iterable):
94+
self.closed = False
95+
self.__wrapped__ = iterable
96+
self._iterator = a.iter(iterable)
97+
98+
async def __anext__(self):
99+
if self.closed:
100+
raise StopAsyncIteration()
101+
return await self._iterator.__anext__()
102+
103+
def __aiter__(self):
104+
return self
105+
106+
async def aclose(self):
107+
if hasattr(self._iterator, "aclose"):
108+
await self._iterator.aclose()
109+
self.closed = True
110+
111+
112+
@pytest.mark.parametrize("iterables", chains)
113+
@sync
114+
async def test_chain_close_auto(iterables):
115+
"""Test that `chain` closes exhausted iterators"""
116+
closeable_iterables = [ACloseFacade(iterable) for iterable in iterables]
117+
assert await a.list(a.chain(*closeable_iterables)) == list(
118+
itertools.chain(*iterables)
119+
)
120+
assert all(iterable.closed for iterable in closeable_iterables)
121+
122+
123+
# insert a known filled iterable since chain closes all that are exhausted
124+
@pytest.mark.parametrize("iterables", [([1], *chain) for chain in chains])
125+
@pytest.mark.parametrize(
126+
"chain_type, must_close",
127+
[(lambda iterators: a.chain(*iterators), True), (a.chain.from_iterable, False)],
128+
)
129+
@sync
130+
async def test_chain_close_partial(iterables, chain_type, must_close):
131+
"""Test that `chain` closes owned iterators"""
132+
closeable_iterables = [ACloseFacade(iterable) for iterable in iterables]
133+
chain = chain_type(closeable_iterables)
134+
assert await a.anext(chain) == next(itertools.chain(*iterables))
135+
await chain.aclose()
136+
assert all(iterable.closed == must_close for iterable in closeable_iterables[1:])
137+
# closed chain must remain closed regardless of iterators
138+
assert await a.anext(chain, "sentinel") == "sentinel"
139+
140+
90141
compress_cases = [
91142
(range(20), [idx % 2 for idx in range(20)]),
92143
([1] * 5, [True, True, False, True, True]),

0 commit comments

Comments
 (0)