Skip to content

Commit fda2df8

Browse files
Tee locking (#87)
* optional lock-protection for tee
* test loop supports locks
* testing tee with and without locks
* skip concurrency test for old Python versions
* documented lock parameter
* added annotations for slicing a tee
* update actions versions
1 parent eeb76f2 commit fda2df8

File tree

7 files changed

+142
-26
lines changed

7 files changed

+142
-26
lines changed

.github/workflows/python-publish.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,9 @@ jobs:
1818
runs-on: ubuntu-latest
1919

2020
steps:
21-
- uses: actions/checkout@v2
21+
- uses: actions/checkout@v3
2222
- name: Set up Python
23-
uses: actions/setup-python@v2
23+
uses: actions/setup-python@v3
2424
with:
2525
python-version: '3.x'
2626
- name: Install dependencies

.github/workflows/unittests.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,9 @@ jobs:
1010
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.6', 'pypy-3.7']
1111

1212
steps:
13-
- uses: actions/checkout@v2
13+
- uses: actions/checkout@v3
1414
- name: Set up Python ${{ matrix.python-version }}
15-
uses: actions/setup-python@v2
15+
uses: actions/setup-python@v3
1616
with:
1717
python-version: ${{ matrix.python-version }}
1818
- name: Install dependencies

.github/workflows/verification.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,9 @@ jobs:
66
build:
77
runs-on: ubuntu-latest
88
steps:
9-
- uses: actions/checkout@v2
9+
- uses: actions/checkout@v3
1010
- name: Set up Python
11-
uses: actions/setup-python@v2
11+
uses: actions/setup-python@v3
1212
with:
1313
python-version: '3.9'
1414
- name: Install dependencies

asyncstdlib/itertools.py

Lines changed: 54 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
)
1818
from collections import deque
1919

20-
from ._typing import T, R, T1, T2, T3, T4, T5, AnyIterable, ADD
20+
from ._typing import T, R, T1, T2, T3, T4, T5, AnyIterable, ADD, AsyncContextManager
2121
from ._utility import public_module
2222
from ._core import (
2323
ScopedIter,
@@ -294,34 +294,53 @@ async def takewhile(
294294
break
295295

296296

297+
class NoLock:
298+
"""Dummy lock that provides the proper interface but no protection"""
299+
300+
async def __aenter__(self) -> None:
301+
pass
302+
303+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
304+
return False
305+
306+
297307
async def tee_peer(
298308
iterator: AsyncIterator[T],
309+
# the buffer specific to this peer
299310
buffer: Deque[T],
311+
# the buffers of all peers, including our own
300312
peers: List[Deque[T]],
313+
lock: AsyncContextManager[Any],
301314
) -> AsyncGenerator[T, None]:
302315
"""An individual iterator of a :py:func:`~.tee`"""
303316
try:
304317
while True:
305318
if not buffer:
306-
try:
307-
item = await iterator.__anext__()
308-
except StopAsyncIteration:
309-
break
310-
else:
311-
# Append to all buffers, including our own. We'll fetch our
312-
# item from the buffer again, instead of yielding it directly.
313-
# This ensures the proper item ordering if any of our peers
314-
# are fetching items concurrently. They may have buffered their
315-
# item already.
316-
for peer_buffer in peers:
317-
peer_buffer.append(item)
319+
async with lock:
320+
# Another peer produced an item while we were waiting for the lock.
321+
# Proceed with the next loop iteration to yield the item.
322+
if buffer:
323+
continue
324+
try:
325+
item = await iterator.__anext__()
326+
except StopAsyncIteration:
327+
break
328+
else:
329+
# Append to all buffers, including our own. We'll fetch our
330+
# item from the buffer again, instead of yielding it directly.
331+
# This ensures the proper item ordering if any of our peers
332+
# are fetching items concurrently. They may have buffered their
333+
# item already.
334+
for peer_buffer in peers:
335+
peer_buffer.append(item)
318336
yield buffer.popleft()
319337
finally:
320338
# this peer is done – remove its buffer
321339
for idx, peer_buffer in enumerate(peers): # pragma: no branch
322340
if peer_buffer is buffer:
323341
peers.pop(idx)
324342
break
343+
# if we are the last peer, try and close the iterator
325344
if not peers and hasattr(iterator, "aclose"):
326345
await iterator.aclose() # type: ignore
327346

@@ -355,29 +374,49 @@ async def derivative(sensor_data):
355374
If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
356375
provide these items. Also, ``tee`` must internally buffer each item until the
357376
last iterator has yielded it; if the most and least advanced iterator differ
358-
by most data, using a :py:class:`list` is faster (but not lazy).
377+
by most of the data, using a :py:class:`list` is more efficient (but not lazy).
359378
360379
If the underlying iterable is concurrency safe (``anext`` may be awaited
361380
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
362381
the iterators are safe if there is only ever one single "most advanced" iterator.
382+
To enforce sequential use of ``anext``, provide a ``lock``
383+
- e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
384+
and access is automatically synchronised.
363385
"""
364386

365-
def __init__(self, iterable: AnyIterable[T], n: int = 2):
387+
def __init__(
388+
self,
389+
iterable: AnyIterable[T],
390+
n: int = 2,
391+
*,
392+
lock: Optional[AsyncContextManager[Any]] = None,
393+
):
366394
self._iterator = aiter(iterable)
367395
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
368396
self._children = tuple(
369397
tee_peer(
370398
iterator=self._iterator,
371399
buffer=buffer,
372400
peers=self._buffers,
401+
lock=lock if lock is not None else NoLock(),
373402
)
374403
for buffer in self._buffers
375404
)
376405

377406
def __len__(self) -> int:
378407
return len(self._children)
379408

409+
@overload
380410
def __getitem__(self, item: int) -> AsyncIterator[T]:
411+
...
412+
413+
@overload
414+
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]:
415+
...
416+
417+
def __getitem__(
418+
self, item: Union[int, slice]
419+
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
381420
return self._children[item]
382421

383422
def __iter__(self) -> Iterator[AnyIterable[T]]:

docs/source/api/itertools.rst

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,9 +65,13 @@ Iterator transforming
6565
Iterator splitting
6666
==================
6767

68-
.. autofunction:: tee(iterable: (async) iter T, n: int = 2)
68+
.. autofunction:: tee(iterable: (async) iter T, n: int = 2, [*, lock: async with Any])
6969
:for: :(async iter T, ...)
7070

71+
.. versionadded:: 3.10.5
72+
73+
The ``lock`` keyword parameter.
74+
7175
.. autofunction:: pairwise(iterable: (async) iter T)
7276
:async-for: :(T, T)
7377

unittests/test_itertools.py

Lines changed: 52 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
11
import itertools
2+
import sys
23

34
import pytest
45

56
import asyncstdlib as a
67

7-
from .utility import sync, asyncify, awaitify
8+
from .utility import sync, asyncify, awaitify, multi_sync, Schedule, Switch, Lock
89

910

1011
@sync
@@ -210,6 +211,56 @@ async def test_tee():
210211
assert await a.list(iterator) == iterable
211212

212213

214+
@multi_sync
215+
async def test_tee_concurrent_locked():
216+
"""Test that tee properly uses a lock for synchronisation"""
217+
items = [1, 2, 3, -5, 12, 78, -1, 111]
218+
219+
async def iter_values():
220+
for item in items:
221+
# switch to other tasks a few times to guarantee another runs
222+
for _ in range(5):
223+
await Switch()
224+
yield item
225+
226+
async def test_peer(peer_tee):
227+
assert await a.list(peer_tee) == items
228+
229+
head_peer, *peers = a.tee(iter_values(), n=len(items) // 2, lock=Lock())
230+
await Schedule(*map(test_peer, peers))
231+
await Switch()
232+
results = [item async for item in head_peer]
233+
assert results == items
234+
235+
236+
# see https://github.com/python/cpython/issues/74956
237+
@pytest.mark.skipif(
238+
sys.version_info < (3, 8),
239+
reason="async generators only protect against concurrent access since 3.8",
240+
)
241+
@multi_sync
242+
async def test_tee_concurrent_unlocked():
243+
"""Test that tee does not prevent concurrency without a lock"""
244+
items = list(range(12))
245+
246+
async def iter_values():
247+
for item in items:
248+
# switch to other tasks a few times to guarantee another runs
249+
for _ in range(5):
250+
await Switch()
251+
yield item
252+
253+
async def test_peer(peer_tee):
254+
assert await a.list(peer_tee) == items
255+
256+
this, peer = a.tee(iter_values(), n=2)
257+
await Schedule(test_peer(peer))
258+
await Switch()
259+
# underlying generator raises RuntimeError when `__anext__` is interleaved
260+
with pytest.raises(RuntimeError):
261+
await test_peer(this)
262+
263+
213264
@sync
214265
async def test_pairwise():
215266
assert await a.list(a.pairwise(range(5))) == [(0, 1), (1, 2), (2, 3), (3, 4)]

unittests/utility.py

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -85,12 +85,34 @@ def __await__(self):
8585
yield self
8686

8787

88+
class Lock:
89+
def __init__(self):
90+
self._owned = False
91+
self._waiting = []
92+
93+
async def __aenter__(self):
94+
if self._owned:
95+
# wait until it is our turn to take the lock
96+
token = object()
97+
self._waiting.append(token)
98+
while self._owned or self._waiting[0] is not token:
99+
await Switch()
100+
# take the lock and remove our wait claim
101+
self._owned = True
102+
self._waiting.pop(0)
103+
self._owned = True
104+
105+
async def __aexit__(self, exc_type, exc_val, exc_tb):
106+
self._owned = False
107+
108+
88109
def multi_sync(test_case: Callable[..., Coroutine]):
89110
"""
90-
Mark an ``async def`` test case to be run synchronously with chicldren
111+
Mark an ``async def`` test case to be run synchronously with children
91112
92113
This emulates a primitive "event loop" which only responds
93-
to the :py:class:`PingPong`, :py:class:`Schedule` and :py:class:`Switch`.
114+
to the :py:class:`PingPong`, :py:class:`Schedule`, :py:class:`Switch`
115+
and :py:class:`Lock`.
94116
"""
95117

96118
@wraps(test_case)
@@ -103,7 +125,7 @@ def run_sync(*args, **kwargs):
103125
event = coro.send(event)
104126
except StopIteration as e:
105127
result = e.args[0] if e.args else None
106-
assert result is None
128+
assert result is None, f"got '{result!r}' expected 'None'"
107129
else:
108130
if isinstance(event, PingPong):
109131
run_queue.appendleft((coro, event))

0 commit comments

Comments
 (0)