|
| 1 | +from typing import AsyncIterator |
1 | 2 | import itertools |
2 | 3 | import sys |
3 | 4 | import platform |
@@ -341,7 +342,7 @@ async def test_tee(): |
341 | 342 |
|
342 | 343 | @sync |
343 | 344 | async def test_tee_concurrent_locked(): |
344 | | - """Test that properly uses a lock for synchronisation""" |
| 345 | + """Test that tee properly uses a lock for synchronisation""" |
345 | 346 | items = [1, 2, 3, -5, 12, 78, -1, 111] |
346 | 347 |
|
347 | 348 | async def iter_values(): |
@@ -393,6 +394,41 @@ async def test_peer(peer_tee): |
393 | 394 | await test_peer(this) |
394 | 395 |
|
395 | 396 |
|
| 397 | +@pytest.mark.parametrize("size", [2, 3, 5, 9, 12]) |
| 398 | +@sync |
| 399 | +async def test_tee_concurrent_ordering(size: int): |
| 400 | + """Test that tee respects concurrent ordering for all peers""" |
| 401 | + |
| 402 | + class ConcurrentInvertedIterable: |
| 403 | + """Helper that concurrently iterates with earlier items taking longer""" |
| 404 | + |
| 405 | + def __init__(self, count: int) -> None: |
| 406 | + self.count = count |
| 407 | + self._counter = itertools.count() |
| 408 | + |
| 409 | + def __aiter__(self): |
| 410 | + return self |
| 411 | + |
| 412 | + async def __anext__(self): |
| 413 | + value = next(self._counter) |
| 414 | + if value >= self.count: |
| 415 | + raise StopAsyncIteration() |
| 416 | + await Switch(self.count - value) |
| 417 | + return value |
| 418 | + |
| 419 | + async def test_peer(peer_tee: AsyncIterator[int]): |
| 420 | + # consume items from the tee with a delay so that slower items can arrive |
| 421 | + seen_items: list[int] = [] |
| 422 | + async for item in peer_tee: |
| 423 | + seen_items.append(item) |
| 424 | + await Switch() |
| 425 | + assert seen_items == expected_items |
| 426 | + |
| 427 | + expected_items = list(range(size)[::-1]) |
| 428 | + peers = a.tee(ConcurrentInvertedIterable(size), n=size) |
| 429 | + await Schedule(*map(test_peer, peers)) |
| 430 | + |
| 431 | + |
396 | 432 | @sync |
397 | 433 | async def test_pairwise(): |
398 | 434 | assert await a.list(a.pairwise(range(5))) == [(0, 1), (1, 2), (2, 3), (3, 4)] |
|
0 commit comments