Skip to content

Commit b730600

Browse files
concurrent tee iterators are safe and ordered
1 parent ce5d591 commit b730600

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

asyncstdlib/itertools.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,11 +380,25 @@ async def __anext__(self) -> T:
380380
async with self._lock:
381381
# Check if another peer produced an item while we were waiting for the lock
382382
if not next_node:
383-
next_node[:] = await self._iterator.__anext__(), []
383+
await self._extend_buffer(next_node)
384+
if not next_node:
385+
raise StopAsyncIteration()
384386
# for any other TeePeer, the node is already some [value, [...]]
385387
value, self._buffer = next_node # type: ignore
386388
return value
387389

390+
async def _extend_buffer(self, next_node: "_TeeNode[T]") -> None:
391+
"""Extend the buffer by fetching a new item from the iterable"""
392+
try:
393+
next_value = await self._iterator.__anext__()
394+
except StopAsyncIteration:
395+
return
396+
# another peer may have filled the buffer while we waited
397+
# seek the last node that needs to be filled
398+
while next_node:
399+
_, next_node = next_node # type: ignore
400+
next_node[:] = next_value, []
401+
388402
async def aclose(self) -> None:
389403
self._tee_peers.discard(self._tee_idx)
390404
if not self._tee_peers and isinstance(self._iterator, ACloseable):

unittests/test_itertools.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import AsyncIterator
12
import itertools
23
import sys
34
import platform
@@ -341,7 +342,7 @@ async def test_tee():
341342

342343
@sync
343344
async def test_tee_concurrent_locked():
344-
"""Test that properly uses a lock for synchronisation"""
345+
"""Test that tee properly uses a lock for synchronisation"""
345346
items = [1, 2, 3, -5, 12, 78, -1, 111]
346347

347348
async def iter_values():
@@ -393,6 +394,41 @@ async def test_peer(peer_tee):
393394
await test_peer(this)
394395

395396

397+
@pytest.mark.parametrize("size", [2, 3, 5, 9, 12])
398+
@sync
399+
async def test_tee_concurrent_ordering(size: int):
400+
"""Test that tee respects concurrent ordering for all peers"""
401+
402+
class ConcurrentInvertedIterable:
403+
"""Helper that concurrently iterates with earlier items taking longer"""
404+
405+
def __init__(self, count: int) -> None:
406+
self.count = count
407+
self._counter = itertools.count()
408+
409+
def __aiter__(self):
410+
return self
411+
412+
async def __anext__(self):
413+
value = next(self._counter)
414+
if value >= self.count:
415+
raise StopAsyncIteration()
416+
await Switch(self.count - value)
417+
return value
418+
419+
async def test_peer(peer_tee: AsyncIterator[int]):
420+
# consume items from the tee with a delay so that slower items can arrive
421+
seen_items: list[int] = []
422+
async for item in peer_tee:
423+
seen_items.append(item)
424+
await Switch()
425+
assert seen_items == expected_items
426+
427+
expected_items = list(range(size)[::-1])
428+
peers = a.tee(ConcurrentInvertedIterable(size), n=size)
429+
await Schedule(*map(test_peer, peers))
430+
431+
396432
@sync
397433
async def test_pairwise():
398434
assert await a.list(a.pairwise(range(5))) == [(0, 1), (1, 2), (2, 3), (3, 4)]

0 commit comments

Comments
 (0)