Skip to content

Commit 81329ab

Browse files
authored
Async cursors' for_each accepts coroutines as well (#339)
1 parent aedcbb3 commit 81329ab

File tree

3 files changed

+180
-16
lines changed

3 files changed

+180
-16
lines changed

astrapy/data/cursor.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from copy import deepcopy
2020
from decimal import Decimal
2121
from enum import Enum
22-
from typing import Any, Callable, Generic, TypeVar, cast
22+
from inspect import iscoroutinefunction
23+
from typing import Any, Awaitable, Callable, Generic, TypeVar, cast
2324

2425
from typing_extensions import override
2526

@@ -1106,7 +1107,7 @@ def for_each(
11061107
results in the cursor entering CLOSED state.
11071108
11081109
Args:
1109-
function: a callback function whose only parameter is the type returned
1110+
function: a callback function whose only parameter is of the type returned
11101111
by the cursor. This callback is invoked once per each document yielded
11111112
by the cursor. If the callback returns a `False`, the `for_each`
11121113
invocation stops early and returns without consuming further documents.
@@ -1729,14 +1730,14 @@ def _composite(document: TRAW) -> TNEW:
17291730

17301731
async def for_each(
17311732
self,
1732-
function: Callable[[T], bool | None],
1733+
function: Callable[[T], bool | None] | Callable[[T], Awaitable[bool | None]],
17331734
*,
17341735
general_method_timeout_ms: int | None = None,
17351736
timeout_ms: int | None = None,
17361737
) -> None:
17371738
"""
17381739
Consume the remaining documents in the cursor, invoking a provided callback
1739-
function on each of them.
1740+
function -- or coroutine -- on each of them.
17401741
17411742
Calling this method on a CLOSED cursor results in an error.
17421743
@@ -1751,8 +1752,9 @@ async def for_each(
17511752
adaptations to the async interface.
17521753
17531754
Args:
1754-
function: a callback function whose only parameter is the type returned
1755-
by the cursor. This callback is invoked once per each document yielded
1755+
function: a callback function, or a coroutine, whose only parameter is of
1756+
the type returned by the cursor.
1757+
This callback is invoked once per each document yielded
17561758
by the cursor. If the callback returns a `False`, the `for_each`
17571759
invocation stops early and returns without consuming further documents.
17581760
general_method_timeout_ms: a timeout, in milliseconds, for the whole
@@ -1772,8 +1774,12 @@ async def for_each(
17721774
overall_timeout_ms=copy_ovr_ms,
17731775
)
17741776
self._imprint_internal_state(_cursor)
1777+
is_coro = iscoroutinefunction(function)
17751778
async for document in _cursor:
1776-
res = function(document)
1779+
if is_coro:
1780+
res = await function(document) # type: ignore[misc]
1781+
else:
1782+
res = function(document)
17771783
if res is False:
17781784
break
17791785
_cursor._imprint_internal_state(self)
@@ -2384,7 +2390,7 @@ def for_each(
23842390
results in the cursor entering CLOSED state.
23852391
23862392
Args:
2387-
function: a callback function whose only parameter is the type returned
2393+
function: a callback function whose only parameter is of the type returned
23882394
by the cursor. This callback is invoked once per each row yielded
23892395
by the cursor. If the callback returns a `False`, the `for_each`
23902396
invocation stops early and returns without consuming further rows.
@@ -2445,8 +2451,8 @@ def for_each(
24452451
overall_timeout_ms=copy_ovr_ms,
24462452
)
24472453
self._imprint_internal_state(_cursor)
2448-
for document in _cursor:
2449-
res = function(document)
2454+
for row in _cursor:
2455+
res = function(row)
24502456
if res is False:
24512457
break
24522458
_cursor._imprint_internal_state(self)
@@ -3006,14 +3012,14 @@ def _composite(document: TRAW) -> TNEW:
30063012

30073013
async def for_each(
30083014
self,
3009-
function: Callable[[T], bool | None],
3015+
function: Callable[[T], bool | None] | Callable[[T], Awaitable[bool | None]],
30103016
*,
30113017
general_method_timeout_ms: int | None = None,
30123018
timeout_ms: int | None = None,
30133019
) -> None:
30143020
"""
30153021
Consume the remaining rows in the cursor, invoking a provided callback
3016-
function on each of them.
3022+
function -- or coroutine -- on each of them.
30173023
30183024
Calling this method on a CLOSED cursor results in an error.
30193025
@@ -3028,8 +3034,9 @@ async def for_each(
30283034
adaptations to the async interface.
30293035
30303036
Args:
3031-
function: a callback function whose only parameter is the type returned
3032-
by the cursor. This callback is invoked once per each row yielded
3037+
function: a callback function, or a coroutine, whose only parameter is of
3038+
the type returned by the cursor.
3039+
This callback is invoked once per each row yielded
30333040
by the cursor. If the callback returns a `False`, the `for_each`
30343041
invocation stops early and returns without consuming further rows.
30353042
general_method_timeout_ms: a timeout, in milliseconds, for the whole
@@ -3049,8 +3056,12 @@ async def for_each(
30493056
overall_timeout_ms=copy_ovr_ms,
30503057
)
30513058
self._imprint_internal_state(_cursor)
3052-
async for document in _cursor:
3053-
res = function(document)
3059+
is_coro = iscoroutinefunction(function)
3060+
async for row in _cursor:
3061+
if is_coro:
3062+
res = await function(row) # type: ignore[misc]
3063+
else:
3064+
res = function(row)
30543065
if res is False:
30553066
break
30563067
_cursor._imprint_internal_state(self)

tests/base/integration/collections/test_collection_cursor_async.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,20 @@ def marker0(row: dict[str, Any], acc: list[dict[str, Any]] = accum0) -> None:
314314
assert accum0 == base_rows
315315
assert fe_cur.state == FindCursorState.CLOSED
316316

317+
# full for_each, coroutine
318+
aaccum0: list[dict[str, Any]] = []
319+
320+
async def amarker0(
321+
row: dict[str, Any],
322+
acc: list[dict[str, Any]] = aaccum0,
323+
) -> None:
324+
acc += [row]
325+
326+
afe_cur = async_filled_collection.find()
327+
await afe_cur.for_each(amarker0)
328+
assert aaccum0 == base_rows
329+
assert afe_cur.state == FindCursorState.CLOSED
330+
317331
# partially-consumed for_each
318332
accum1: list[dict[str, Any]] = []
319333

@@ -327,6 +341,22 @@ def marker1(row: dict[str, Any], acc: list[dict[str, Any]] = accum1) -> None:
327341
assert accum1 == base_rows[11:]
328342
assert pfe_cur.state == FindCursorState.CLOSED
329343

344+
# partially-consumed for_each, coroutine
345+
aaccum1: list[dict[str, Any]] = []
346+
347+
async def amarker1(
348+
row: dict[str, Any],
349+
acc: list[dict[str, Any]] = aaccum1,
350+
) -> None:
351+
acc += [row]
352+
353+
apfe_cur = async_filled_collection.find()
354+
for _ in range(11):
355+
await apfe_cur.__anext__()
356+
await apfe_cur.for_each(amarker1)
357+
assert aaccum1 == base_rows[11:]
358+
assert apfe_cur.state == FindCursorState.CLOSED
359+
330360
# mapped for_each
331361
accum2: list[int] = []
332362

@@ -340,6 +370,19 @@ def marker2(val: int, acc: list[int] = accum2) -> None:
340370
assert accum2 == [mint(row) for row in base_rows[17:]]
341371
assert mfe_cur.state == FindCursorState.CLOSED
342372

373+
# mapped for_each, coroutine
374+
aaccum2: list[int] = []
375+
376+
async def amarker2(val: int, acc: list[int] = aaccum2) -> None:
377+
acc += [val]
378+
379+
amfe_cur = async_filled_collection.find().map(mint)
380+
for _ in range(17):
381+
await amfe_cur.__anext__()
382+
await amfe_cur.for_each(amarker2)
383+
assert aaccum2 == [mint(row) for row in base_rows[17:]]
384+
assert amfe_cur.state == FindCursorState.CLOSED
385+
343386
# breaking (early) for_each
344387
accum3: list[dict[str, Any]] = []
345388

@@ -354,6 +397,23 @@ def marker3(row: dict[str, Any], acc: list[dict[str, Any]] = accum3) -> bool:
354397
bfe_another = await bfe_cur.__anext__()
355398
assert bfe_another == base_rows[5]
356399

400+
# breaking (early) for_each, coroutine
401+
aaccum3: list[dict[str, Any]] = []
402+
403+
async def amarker3(
404+
row: dict[str, Any],
405+
acc: list[dict[str, Any]] = aaccum3,
406+
) -> bool:
407+
acc += [row]
408+
return len(acc) < 5
409+
410+
abfe_cur = async_filled_collection.find()
411+
await abfe_cur.for_each(amarker3)
412+
assert aaccum3 == base_rows[:5]
413+
assert abfe_cur.state == FindCursorState.STARTED
414+
abfe_another = await abfe_cur.__anext__()
415+
assert abfe_another == base_rows[5]
416+
357417
# nonbool-nonbreaking for_each
358418
accum4: list[dict[str, Any]] = []
359419

@@ -366,6 +426,21 @@ def marker4(row: dict[str, Any], acc: list[dict[str, Any]] = accum4) -> int:
366426
assert accum4 == base_rows
367427
assert nbfe_cur.state == FindCursorState.CLOSED
368428

429+
# nonbool-nonbreaking for_each, coroutine
430+
aaccum4: list[dict[str, Any]] = []
431+
432+
async def amarker4(
433+
row: dict[str, Any],
434+
acc: list[dict[str, Any]] = aaccum4,
435+
) -> int:
436+
acc += [row]
437+
return 8 if len(acc) < 5 else 0
438+
439+
anbfe_cur = async_filled_collection.find()
440+
await anbfe_cur.for_each(amarker4) # type: ignore[arg-type]
441+
assert aaccum4 == base_rows
442+
assert anbfe_cur.state == FindCursorState.CLOSED
443+
369444
@pytest.mark.describe("test of collection cursors, serdes options obeyance, async")
370445
async def test_collection_cursors_serdes_options_async(
371446
self,

tests/base/integration/tables/test_table_cursor_async.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,20 @@ def marker0(row: dict[str, Any], acc: list[dict[str, Any]] = accum0) -> None:
314314
assert accum0 == base_rows
315315
assert fe_cur.state == FindCursorState.CLOSED
316316

317+
# full for_each, coroutine
318+
aaccum0: list[dict[str, Any]] = []
319+
320+
async def amarker0(
321+
row: dict[str, Any],
322+
acc: list[dict[str, Any]] = aaccum0,
323+
) -> None:
324+
acc += [row]
325+
326+
afe_cur = filled_composite_atable.find()
327+
await afe_cur.for_each(amarker0)
328+
assert aaccum0 == base_rows
329+
assert afe_cur.state == FindCursorState.CLOSED
330+
317331
# partially-consumed for_each
318332
accum1: list[dict[str, Any]] = []
319333

@@ -327,6 +341,22 @@ def marker1(row: dict[str, Any], acc: list[dict[str, Any]] = accum1) -> None:
327341
assert accum1 == base_rows[11:]
328342
assert pfe_cur.state == FindCursorState.CLOSED
329343

344+
# partially-consumed for_each, coroutine
345+
aaccum1: list[dict[str, Any]] = []
346+
347+
async def amarker1(
348+
row: dict[str, Any],
349+
acc: list[dict[str, Any]] = aaccum1,
350+
) -> None:
351+
acc += [row]
352+
353+
apfe_cur = filled_composite_atable.find()
354+
for _ in range(11):
355+
await apfe_cur.__anext__()
356+
await apfe_cur.for_each(amarker1)
357+
assert aaccum1 == base_rows[11:]
358+
assert apfe_cur.state == FindCursorState.CLOSED
359+
330360
# mapped for_each
331361
accum2: list[int] = []
332362

@@ -340,6 +370,22 @@ def marker2(val: int, acc: list[int] = accum2) -> None:
340370
assert accum2 == [mint(row) for row in base_rows[17:]]
341371
assert mfe_cur.state == FindCursorState.CLOSED
342372

373+
# mapped for_each, coroutine
374+
aaccum2: list[int] = []
375+
376+
async def amarker2(
377+
val: int,
378+
acc: list[int] = aaccum2,
379+
) -> None:
380+
acc += [val]
381+
382+
amfe_cur = filled_composite_atable.find().map(mint)
383+
for _ in range(17):
384+
await amfe_cur.__anext__()
385+
await amfe_cur.for_each(amarker2)
386+
assert aaccum2 == [mint(row) for row in base_rows[17:]]
387+
assert amfe_cur.state == FindCursorState.CLOSED
388+
343389
# breaking (early) for_each
344390
accum3: list[dict[str, Any]] = []
345391

@@ -354,6 +400,23 @@ def marker3(row: dict[str, Any], acc: list[dict[str, Any]] = accum3) -> bool:
354400
bfe_another = await bfe_cur.__anext__()
355401
assert bfe_another == base_rows[5]
356402

403+
# breaking (early) for_each, coroutine
404+
aaccum3: list[dict[str, Any]] = []
405+
406+
async def amarker3(
407+
row: dict[str, Any],
408+
acc: list[dict[str, Any]] = aaccum3,
409+
) -> bool:
410+
acc += [row]
411+
return len(acc) < 5
412+
413+
abfe_cur = filled_composite_atable.find()
414+
await abfe_cur.for_each(amarker3)
415+
assert aaccum3 == base_rows[:5]
416+
assert abfe_cur.state == FindCursorState.STARTED
417+
abfe_another = await abfe_cur.__anext__()
418+
assert abfe_another == base_rows[5]
419+
357420
# nonbool-nonbreaking for_each
358421
accum4: list[dict[str, Any]] = []
359422

@@ -366,6 +429,21 @@ def marker4(row: dict[str, Any], acc: list[dict[str, Any]] = accum4) -> int:
366429
assert accum4 == base_rows
367430
assert nbfe_cur.state == FindCursorState.CLOSED
368431

432+
# nonbool-nonbreaking for_each, coroutine
433+
aaccum4: list[dict[str, Any]] = []
434+
435+
async def amarker4(
436+
row: dict[str, Any],
437+
acc: list[dict[str, Any]] = aaccum4,
438+
) -> int:
439+
acc += [row]
440+
return 8 if len(acc) < 5 else 0
441+
442+
anbfe_cur = filled_composite_atable.find()
443+
await anbfe_cur.for_each(amarker4) # type: ignore[arg-type]
444+
assert aaccum4 == base_rows
445+
assert anbfe_cur.state == FindCursorState.CLOSED
446+
369447
@pytest.mark.describe("test of table cursors, serdes options obeyance, async")
370448
async def test_table_cursors_serdes_options_async(
371449
self,

0 commit comments

Comments
 (0)