88 Union ,
99 Callable ,
1010 Optional ,
11- Deque ,
1211 Generic ,
1312 Iterable ,
1413 Iterator ,
1716 overload ,
1817 AsyncGenerator ,
1918)
20- from collections import deque
19+ from typing_extensions import TypeAlias
2120
2221from ._typing import ACloseable , R , T , AnyIterable , ADD
2322from ._utility import public_module
3231 enumerate as aenumerate ,
3332 iter as aiter ,
3433)
34+ from itertools import count as _count
3535
3636S = TypeVar ("S" )
3737T_co = TypeVar ("T_co" , covariant = True )
@@ -346,45 +346,52 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
346346 return None
347347
348348
349- async def tee_peer (
350- iterator : AsyncIterator [T ],
351- # the buffer specific to this peer
352- buffer : Deque [T ],
353- # the buffers of all peers, including our own
354- peers : List [Deque [T ]],
355- lock : AsyncContextManager [Any ],
356- ) -> AsyncGenerator [T , None ]:
357- """An individual iterator of a :py:func:`~.tee`"""
358- try :
359- while True :
360- if not buffer :
361- async with lock :
362- # Another peer produced an item while we were waiting for the lock.
363- # Proceed with the next loop iteration to yield the item.
364- if buffer :
365- continue
366- try :
367- item = await iterator .__anext__ ()
368- except StopAsyncIteration :
369- break
370- else :
371- # Append to all buffers, including our own. We'll fetch our
372- # item from the buffer again, instead of yielding it directly.
373- # This ensures the proper item ordering if any of our peers
374- # are fetching items concurrently. They may have buffered their
375- # item already.
376- for peer_buffer in peers :
377- peer_buffer .append (item )
378- yield buffer .popleft ()
379- finally :
380- # this peer is done – remove its buffer
381- for idx , peer_buffer in enumerate (peers ): # pragma: no branch
382- if peer_buffer is buffer :
383- peers .pop (idx )
384- break
385- # if we are the last peer, try and close the iterator
386- if not peers and isinstance (iterator , ACloseable ):
387- await iterator .aclose ()
349+ _get_tee_index = _count ().__next__
350+
351+
352+ Node : TypeAlias = "list[T | Node[T]]"
353+
354+
355+ class TeePeer (Generic [T ]):
356+ def __init__ (
357+ self ,
358+ iterator : AsyncIterator [T ],
359+ buffer : "Node[T]" ,
360+ lock : AsyncContextManager [Any ],
361+ tee_peers : "set[int]" ,
362+ ) -> None :
363+ self .iterator = iterator
364+ self .lock = lock
365+ self .buffer : Node [T ] = buffer
366+ self .tee_peers = tee_peers
367+ self .tee_idx = _get_tee_index ()
368+ self .tee_peers .add (self .tee_idx )
369+
370+ def __aiter__ (self ):
371+ return self
372+
373+ async def __anext__ (self ) -> T :
374+ # the buffer is a singly-linked list as [value, [value, [...]]] | []
375+ next_node = self .buffer
376+ value : T
377+ # for any most advanced TeePeer, the node is just []
378+ # fetch the next value so we can mutate the node to [value, [...]]
379+ if not next_node :
380+ async with self .lock :
381+ # Check if another peer produced an item while we were waiting for the lock
382+ if not next_node :
383+ next_node [:] = await self .iterator .__anext__ (), []
384+ # for any other TeePeer, the node is already some [value, [...]]
385+ value , self .buffer = next_node # type: ignore
386+ return value
387+
388+ async def aclose (self ) -> None :
389+ self .tee_peers .discard (self .tee_idx )
390+ if not self .tee_peers and isinstance (self .iterator , ACloseable ):
391+ await self .iterator .aclose ()
392+
393+ def __del__ (self ) -> None :
394+ self .tee_peers .discard (self .tee_idx )
388395
389396
390397@public_module (__name__ , "tee" )
@@ -426,7 +433,7 @@ async def derivative(sensor_data):
426433 and access is automatically synchronised.
427434 """
428435
429- __slots__ = ("_iterator" , "_buffers " , "_children" )
436+ __slots__ = ("_iterator" , "_buffer " , "_children" )
430437
431438 def __init__ (
432439 self ,
@@ -436,15 +443,16 @@ def __init__(
436443 lock : Optional [AsyncContextManager [Any ]] = None ,
437444 ):
438445 self ._iterator = aiter (iterable )
439- self ._buffers : List [Deque [T ]] = [deque () for _ in range (n )]
446+ self ._buffer : Node [T ] = []
447+ peers : set [int ] = set ()
440448 self ._children = tuple (
441- tee_peer (
442- iterator = self ._iterator ,
443- buffer = buffer ,
444- peers = self . _buffers ,
445- lock = lock if lock is not None else NoLock () ,
449+ TeePeer (
450+ self ._iterator ,
451+ self . _buffer ,
452+ lock if lock is not None else NoLock () ,
453+ peers ,
446454 )
447- for buffer in self . _buffers
455+ for _ in range ( n )
448456 )
449457
450458 def __len__ (self ) -> int :
0 commit comments