Skip to content

Commit e60704f

Browse files
draft for shared tee buffer
1 parent 56a90d7 commit e60704f

File tree

1 file changed

+57
-49
lines changed

1 file changed

+57
-49
lines changed

asyncstdlib/itertools.py

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Union,
99
Callable,
1010
Optional,
11-
Deque,
1211
Generic,
1312
Iterable,
1413
Iterator,
@@ -17,7 +16,7 @@
1716
overload,
1817
AsyncGenerator,
1918
)
20-
from collections import deque
19+
from typing_extensions import TypeAlias
2120

2221
from ._typing import ACloseable, R, T, AnyIterable, ADD
2322
from ._utility import public_module
@@ -32,6 +31,7 @@
3231
enumerate as aenumerate,
3332
iter as aiter,
3433
)
34+
from itertools import count as _count
3535

3636
S = TypeVar("S")
3737
T_co = TypeVar("T_co", covariant=True)
@@ -346,45 +346,52 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
346346
return None
347347

348348

349-
async def tee_peer(
    iterator: AsyncIterator[T],
    # the buffer specific to this peer
    buffer: Deque[T],
    # the buffers of all peers, including our own
    peers: List[Deque[T]],
    lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
    """
    An individual iterator of a :py:func:`~.tee`

    Yields items from ``buffer`` in order. Whenever ``buffer`` is empty, the
    next item is fetched from ``iterator`` while holding ``lock`` and appended
    to every buffer in ``peers``. When this generator finishes or is closed,
    it removes ``buffer`` from ``peers``; if no peers remain and ``iterator``
    is an :py:class:`ACloseable`, the iterator is closed as well.
    """
    try:
        while True:
            if not buffer:
                async with lock:
                    # A peer may have filled our buffer while we waited for
                    # the lock; restart the loop to yield that item instead.
                    if buffer:
                        continue
                    try:
                        fetched = await iterator.__anext__()
                    except StopAsyncIteration:
                        break
                    # Hand the item to every peer, ourselves included. We read
                    # it back from our own buffer below rather than yielding it
                    # directly, so ordering stays correct even if peers have
                    # buffered items for us concurrently.
                    for queue in peers:
                        queue.append(fetched)
            yield buffer.popleft()
    finally:
        # This peer is done - drop its buffer. The match is by identity, since
        # buffers with equal content are still distinct peers.
        for position, queue in enumerate(peers):  # pragma: no branch
            if queue is buffer:
                del peers[position]
                break
        # The last peer to finish is responsible for closing the iterator.
        if not peers and isinstance(iterator, ACloseable):
            await iterator.aclose()
349+
# Hands out a fresh integer on every call; used to identify each TeePeer
_get_tee_index = _count().__next__


# A cell of the shared singly-linked buffer:
# ``[]`` while still unfilled, ``[value, next_cell]`` once a value was fetched
Node: TypeAlias = "list[T | Node[T]]"


class TeePeer(Generic[T]):
    """
    An individual iterator of a :py:func:`~.tee`

    All peers of one ``tee`` share a single linked list of ``Node`` cells and
    each peer only remembers the cell it has to read next. The peer that
    reaches the unfilled tail cell fetches the next item from ``iterator``
    (while holding ``lock``) and fills the cell in place, which makes the item
    visible to every peer still pointing at or before that cell.

    :param iterator: the shared source iterator
    :param buffer: the cell this peer reads next
    :param lock: async context manager guarding access to ``iterator``
    :param tee_peers: shared set of the ids of all peers that are still open
    """

    def __init__(
        self,
        iterator: AsyncIterator[T],
        buffer: "Node[T]",
        lock: AsyncContextManager[Any],
        tee_peers: "set[int]",
    ) -> None:
        self.iterator = iterator
        self.lock = lock
        self.buffer: Node[T] = buffer
        self.tee_peers = tee_peers
        self.tee_idx = _get_tee_index()
        self.tee_peers.add(self.tee_idx)
        # set by aclose(); a closed peer must not touch the shared iterator
        self._closed = False

    def __aiter__(self) -> "TeePeer[T]":
        return self

    async def __anext__(self) -> T:
        """
        Return the next item for this peer

        :raises StopAsyncIteration: if this peer was closed via :py:meth:`aclose`,
            or propagated from the source iterator when it is exhausted
        """
        # Once closed, behave like a finished iterator instead of advancing
        # (or reviving) the shared source behind the other peers' backs.
        if self._closed:
            raise StopAsyncIteration
        # the buffer is a singly-linked list as [value, [value, [...]]] | []
        next_node = self.buffer
        value: T
        # for any most advanced TeePeer, the node is just []
        # fetch the next value so we can mutate the node to [value, [...]]
        if not next_node:
            async with self.lock:
                # Another peer may have filled the node while we were waiting
                # for the lock; only fetch if it is still empty.
                if not next_node:
                    next_node[:] = await self.iterator.__anext__(), []
        # for any other TeePeer, the node is already some [value, [...]]
        value, self.buffer = next_node  # type: ignore
        return value

    async def aclose(self) -> None:
        """
        Close this peer and, if it was the last open peer, the source iterator

        After closing, :py:meth:`__anext__` raises :py:exc:`StopAsyncIteration`.
        """
        self._closed = True
        # Let go of our position in the shared list so that items produced
        # for the remaining peers are not kept alive by this closed peer.
        self.buffer = []
        self.tee_peers.discard(self.tee_idx)
        if not self.tee_peers and isinstance(self.iterator, ACloseable):
            await self.iterator.aclose()

    def __del__(self) -> None:
        # Finalizers cannot await, so only deregister this peer here;
        # the source iterator is not closed from this path.
        self.tee_peers.discard(self.tee_idx)
388395

389396

390397
@public_module(__name__, "tee")
@@ -426,7 +433,7 @@ async def derivative(sensor_data):
426433
and access is automatically synchronised.
427434
"""
428435

429-
__slots__ = ("_iterator", "_buffers", "_children")
436+
__slots__ = ("_iterator", "_buffer", "_children")
430437

431438
def __init__(
432439
self,
@@ -436,15 +443,16 @@ def __init__(
436443
lock: Optional[AsyncContextManager[Any]] = None,
437444
):
438445
self._iterator = aiter(iterable)
439-
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
446+
self._buffer: Node[T] = []
447+
peers: set[int] = set()
440448
self._children = tuple(
441-
tee_peer(
442-
iterator=self._iterator,
443-
buffer=buffer,
444-
peers=self._buffers,
445-
lock=lock if lock is not None else NoLock(),
449+
TeePeer(
450+
self._iterator,
451+
self._buffer,
452+
lock if lock is not None else NoLock(),
453+
peers,
446454
)
447-
for buffer in self._buffers
455+
for _ in range(n)
448456
)
449457

450458
def __len__(self) -> int:

0 commit comments

Comments
 (0)