Skip to content

Commit 4247fd5

Browse files
Add strong async neutral iteration (#80)
* added any_iter * added unittests for any_iter * exporting any_iter at top-level
1 parent 4559cfc commit 4247fd5

File tree

4 files changed

+89
-2
lines changed

4 files changed

+89
-2
lines changed

asyncstdlib/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
zip_longest,
3434
groupby,
3535
)
36-
from .asynctools import borrow, scoped_iter, await_each, apply, sync
36+
from .asynctools import borrow, scoped_iter, await_each, any_iter, apply, sync
3737
from .heapq import merge, nlargest, nsmallest
3838

3939
__version__ = "3.10.2"
@@ -82,6 +82,7 @@
8282
"borrow",
8383
"scoped_iter",
8484
"await_each",
85+
"any_iter",
8586
"apply",
8687
"sync",
8788
# heapq

asyncstdlib/asynctools.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,51 @@ async def async_wrapped(*args: Any, **kwargs: Any) -> T:
380380
return result
381381

382382
return async_wrapped
383+
384+
385+
async def any_iter(
386+
__iter: Union[
387+
Awaitable[AnyIterable[Awaitable[T]]],
388+
Awaitable[AnyIterable[T]],
389+
AnyIterable[Awaitable[T]],
390+
AnyIterable[T],
391+
]
392+
) -> AsyncIterator[T]:
393+
"""
394+
Provide an async iterator for various forms of "asynchronous iterable"
395+
396+
Useful to uniformly handle async iterables, awaitable iterables, iterables of
397+
awaitables, and similar in an ``async for`` loop. Among other things, this
398+
matches all forms of ``async def`` functions providing iterables.
399+
400+
.. code-block:: python3
401+
402+
import random
403+
import asyncstdlib as a
404+
405+
# AsyncIterator[T]
406+
async def async_iter(n):
407+
for i in range(n):
408+
yield i
409+
410+
# Awaitable[Iterator[T]]
411+
async def await_iter(n):
412+
return [*range(n)]
413+
414+
some_iter = random.choice([async_iter, await_iter, range])
415+
async for item in a.any_iter(some_iter(4)):
416+
print(item)
417+
418+
This function must eagerly resolve each "async layer" before checking if
419+
the next layer is as expected. This incurs a performance penalty and
420+
non-iterables may be left unusable by this.
421+
Prefer :py:func:`~.builtins.iter` to test for iterables with :term:`EAFP`
422+
and for performance when only simple iterables need handling.
423+
"""
424+
iterable = __iter if not isinstance(__iter, Awaitable) else await __iter
425+
if isinstance(iterable, AsyncIterable):
426+
async for item in iterable:
427+
yield item if not isinstance(item, Awaitable) else await item
428+
else:
429+
for item in iterable:
430+
yield item if not isinstance(item, Awaitable) else await item

docs/source/api/asynctools.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ Async transforming
2727

2828
.. versionadded:: 3.9.3
2929

30+
.. autofunction:: any_iter(iter: (await) (async) iter (await) T)
31+
:async-for: :T
32+
33+
.. versionadded:: 3.10.3
34+
3035
.. autofunction:: await_each(awaitables: iter await T)
3136
:async-for: :T
3237

unittests/test_asynctools.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ async def check_4(x: int, y: int, z: int) -> int:
230230

231231
t1 = await a.sync(check_3)(x=100)
232232
t2 = await a.sync(check_4)(x=5, y=5, z=10)
233-
t3 = await a.sync(lambda x: x ** 3)(x=5)
233+
t3 = await a.sync(lambda x: x**3)(x=5)
234234

235235
with pytest.raises(TypeError):
236236
a.sync("string")(10)
@@ -252,3 +252,36 @@ async def coro():
252252
return coro()
253253

254254
assert await nocoro_async(5) == 5
255+
256+
257+
async def await_iter(n: int):
258+
return [*range(n)]
259+
260+
261+
async def async_iter(n: int):
262+
for i in range(n):
263+
yield i
264+
265+
266+
async def await_value(i):
267+
return i
268+
269+
270+
async def await_iter_await(n: int):
271+
return [await_value(i) for i in range(n)]
272+
273+
274+
async def await_async_iter_await(n: int):
275+
for i in range(n):
276+
yield await_value(i)
277+
278+
279+
@pytest.mark.parametrize("n", [0, 1, 12])
280+
@pytest.mark.parametrize(
281+
"any_iterable_t",
282+
[range, await_iter, async_iter, await_iter_await, await_async_iter_await],
283+
)
284+
@sync
285+
async def test_any_iter(n, any_iterable_t):
286+
iterable = any_iterable_t(n)
287+
assert [item async for item in a.any_iter(iterable)] == [*range(n)]

0 commit comments

Comments
 (0)