|
17 | 17 | ) |
18 | 18 | from collections import deque |
19 | 19 |
|
20 | | -from ._typing import T, R, T1, T2, T3, T4, T5, AnyIterable, ADD |
| 20 | +from ._typing import T, R, T1, T2, T3, T4, T5, AnyIterable, ADD, AsyncContextManager |
21 | 21 | from ._utility import public_module |
22 | 22 | from ._core import ( |
23 | 23 | ScopedIter, |
@@ -294,34 +294,53 @@ async def takewhile( |
294 | 294 | break |
295 | 295 |
|
296 | 296 |
|
| 297 | +class NoLock: |
| 298 | + """Dummy lock that provides the proper interface but no protection""" |
| 299 | + |
| 300 | + async def __aenter__(self) -> None: |
| 301 | + pass |
| 302 | + |
| 303 | + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: |
| 304 | + return False |
| 305 | + |
| 306 | + |
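Aside (not part of the diff): `NoLock` only needs to satisfy the `async with` protocol, which makes it a drop-in stand-in for a real lock. A minimal sketch, assuming an `asyncio` application and the `NoLock` class defined above:

```python
import asyncio


async def demo() -> None:
    # Both objects implement __aenter__/__aexit__, so tee_peer can treat them
    # interchangeably: NoLock adds no synchronisation, asyncio.Lock serialises.
    for lock in (NoLock(), asyncio.Lock()):
        async with lock:
            pass  # the critical section of tee_peer would run here


asyncio.run(demo())
```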
297 | 307 | async def tee_peer( |
298 | 308 | iterator: AsyncIterator[T], |
| 309 | + # the buffer specific to this peer |
299 | 310 | buffer: Deque[T], |
| 311 | + # the buffers of all peers, including our own |
300 | 312 | peers: List[Deque[T]], |
| 313 | + lock: AsyncContextManager[Any], |
301 | 314 | ) -> AsyncGenerator[T, None]: |
302 | 315 | """An individual iterator of a :py:func:`~.tee`""" |
303 | 316 | try: |
304 | 317 | while True: |
305 | 318 | if not buffer: |
306 | | - try: |
307 | | - item = await iterator.__anext__() |
308 | | - except StopAsyncIteration: |
309 | | - break |
310 | | - else: |
311 | | - # Append to all buffers, including our own. We'll fetch our |
312 | | - # item from the buffer again, instead of yielding it directly. |
313 | | - # This ensures the proper item ordering if any of our peers |
314 | | - # are fetching items concurrently. They may have buffered their |
315 | | - # item already. |
316 | | - for peer_buffer in peers: |
317 | | - peer_buffer.append(item) |
| 319 | + async with lock: |
| 320 | + # Another peer produced an item while we were waiting for the lock. |
| 321 | + # Proceed with the next loop iteration to yield the item. |
| 322 | + if buffer: |
| 323 | + continue |
| 324 | + try: |
| 325 | + item = await iterator.__anext__() |
| 326 | + except StopAsyncIteration: |
| 327 | + break |
| 328 | + else: |
| 329 | + # Append to all buffers, including our own. We'll fetch our |
| 330 | + # item from the buffer again, instead of yielding it directly. |
| 331 | + # This ensures the proper item ordering if any of our peers |
| 332 | + # are fetching items concurrently. They may have buffered their |
| 333 | + # item already. |
| 334 | + for peer_buffer in peers: |
| 335 | + peer_buffer.append(item) |
318 | 336 | yield buffer.popleft() |
319 | 337 | finally: |
320 | 338 | # this peer is done – remove its buffer |
321 | 339 | for idx, peer_buffer in enumerate(peers): # pragma: no branch |
322 | 340 | if peer_buffer is buffer: |
323 | 341 | peers.pop(idx) |
324 | 342 | break |
| 343 | + # if we are the last peer, try and close the iterator |
325 | 344 | if not peers and hasattr(iterator, "aclose"): |
326 | 345 | await iterator.aclose() # type: ignore |
327 | 346 |
|
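The `async with lock:` block above follows a double-checked pattern: the buffer is tested again once the lock is held, because another peer may have produced an item while this peer was waiting. A stripped-down sketch of that pattern, ignoring the multiple per-peer buffers of the real code (all names here are illustrative):

```python
import asyncio
from typing import AsyncIterator, Deque, TypeVar

T = TypeVar("T")


async def fetch_next(
    source: AsyncIterator[T], buffer: Deque[T], lock: asyncio.Lock
) -> T:
    """Take the next item from the shared buffer, refilling it under the lock if empty."""
    if not buffer:  # cheap check without holding the lock
        async with lock:
            if not buffer:  # re-check: a peer may have refilled it meanwhile
                buffer.append(await source.__anext__())
    return buffer.popleft()
```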
@@ -355,29 +374,49 @@ async def derivative(sensor_data): |
355 | 374 | If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not* |
356 | 375 | provide these items. Also, ``tee`` must internally buffer each item until the |
357 | 376 | last iterator has yielded it; if the most and least advanced iterators differ
358 | | - by most data, using a :py:class:`list` is faster (but not lazy). |
| 377 | + by most data, using a :py:class:`list` is more efficient (but not lazy). |
359 | 378 |
|
360 | 379 | If the underlying iterable is concurrency safe (``anext`` may be awaited |
361 | 380 | concurrently) the resulting iterators are concurrency safe as well. Otherwise, |
362 | 381 | the iterators are safe if there is only ever one single "most advanced" iterator. |
| 382 | + To enforce sequential use of ``anext``, provide a ``lock`` |
| 383 | + - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application - |
| 384 | + and access is automatically synchronised. |
363 | 385 | """ |
364 | 386 |
|
365 | | - def __init__(self, iterable: AnyIterable[T], n: int = 2): |
| 387 | + def __init__( |
| 388 | + self, |
| 389 | + iterable: AnyIterable[T], |
| 390 | + n: int = 2, |
| 391 | + *, |
| 392 | + lock: Optional[AsyncContextManager[Any]] = None, |
| 393 | + ): |
366 | 394 | self._iterator = aiter(iterable) |
367 | 395 | self._buffers: List[Deque[T]] = [deque() for _ in range(n)] |
368 | 396 | self._children = tuple( |
369 | 397 | tee_peer( |
370 | 398 | iterator=self._iterator, |
371 | 399 | buffer=buffer, |
372 | 400 | peers=self._buffers, |
| 401 | + lock=lock if lock is not None else NoLock(), |
373 | 402 | ) |
374 | 403 | for buffer in self._buffers |
375 | 404 | ) |
376 | 405 |
|
377 | 406 | def __len__(self) -> int: |
378 | 407 | return len(self._children) |
379 | 408 |
|
| 409 | + @overload |
380 | 410 | def __getitem__(self, item: int) -> AsyncIterator[T]: |
| 411 | + ... |
| 412 | + |
| 413 | + @overload |
| 414 | + def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: |
| 415 | + ... |
| 416 | + |
| 417 | + def __getitem__( |
| 418 | + self, item: Union[int, slice] |
| 419 | + ) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]: |
381 | 420 | return self._children[item] |
382 | 421 |
|
383 | 422 | def __iter__(self) -> Iterator[AnyIterable[T]]: |
|
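Taken together, these changes let callers opt into synchronisation via the new `lock` keyword. A usage sketch, under the assumption that this module is the one importable as `asyncstdlib` (only the `lock` keyword itself comes from this diff):

```python
import asyncio

import asyncstdlib as a


async def source():
    # a plain async generator: NOT safe for concurrent __anext__ calls
    for i in range(5):
        await asyncio.sleep(0)
        yield i


async def collect(iterator):
    return [item async for item in iterator]


async def main() -> None:
    # A shared asyncio.Lock serialises __anext__ on the source, so both
    # children can be consumed from concurrently running tasks.
    first, second = a.tee(source(), n=2, lock=asyncio.Lock())
    print(await asyncio.gather(collect(first), collect(second)))
    # [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]


asyncio.run(main())
```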