| 
8 | 8 | import asyncio  | 
9 | 9 | import datetime  | 
10 | 10 | import logging  | 
 | 11 | +from contextlib import asynccontextmanager  | 
11 | 12 | from typing import Literal  | 
12 | 13 | 
 
  | 
13 | 14 | import pytest  | 
 | 
21 | 22 |     DistributedSemaphore,  | 
22 | 23 |     SemaphoreAcquisitionError,  | 
23 | 24 | )  | 
24 |  | -from servicelib.redis._semaphore_decorator import with_limited_concurrency  | 
 | 25 | +from servicelib.redis._semaphore_decorator import (  | 
 | 26 | +    with_limited_concurrency,  | 
 | 27 | +    with_limited_concurrency_cm,  | 
 | 28 | +)  | 
25 | 29 | 
 
  | 
26 | 30 | pytest_simcore_core_services_selection = [  | 
27 | 31 |     "redis",  | 
@@ -389,3 +393,290 @@ async def limited_function() -> None:  | 
389 | 393 | 
 
  | 
390 | 394 |     # Should never exceed the large capacity  | 
391 | 395 |     assert max_concurrent <= large_capacity  | 
 | 396 | + | 
 | 397 | + | 
 | 398 | +async def test_context_manager_basic_functionality(  | 
 | 399 | +    redis_client_sdk: RedisClientSDK,  | 
 | 400 | +    semaphore_name: str,  | 
 | 401 | +):  | 
 | 402 | +    call_count = 0  | 
 | 403 | + | 
 | 404 | +    @asynccontextmanager  | 
 | 405 | +    @with_limited_concurrency_cm(  | 
 | 406 | +        redis_client_sdk,  | 
 | 407 | +        key=semaphore_name,  | 
 | 408 | +        capacity=1,  | 
 | 409 | +    )  | 
 | 410 | +    async def limited_context_manager():  | 
 | 411 | +        nonlocal call_count  | 
 | 412 | +        call_count += 1  | 
 | 413 | +        yield call_count  | 
 | 414 | + | 
 | 415 | +    # Multiple concurrent context managers  | 
 | 416 | +    async def use_context_manager():  | 
 | 417 | +        async with limited_context_manager() as value:  | 
 | 418 | +            await asyncio.sleep(0.1)  | 
 | 419 | +            return value  | 
 | 420 | + | 
 | 421 | +    tasks = [asyncio.create_task(use_context_manager()) for _ in range(3)]  | 
 | 422 | +    results = await asyncio.gather(*tasks)  | 
 | 423 | + | 
 | 424 | +    # All should complete successfully  | 
 | 425 | +    assert len(results) == 3  | 
 | 426 | +    assert all(isinstance(r, int) for r in results)  | 
 | 427 | + | 
 | 428 | + | 
 | 429 | +async def test_context_manager_capacity_enforcement(  | 
 | 430 | +    redis_client_sdk: RedisClientSDK,  | 
 | 431 | +    semaphore_name: str,  | 
 | 432 | +):  | 
 | 433 | +    concurrent_count = 0  | 
 | 434 | +    max_concurrent = 0  | 
 | 435 | + | 
 | 436 | +    @asynccontextmanager  | 
 | 437 | +    @with_limited_concurrency_cm(  | 
 | 438 | +        redis_client_sdk,  | 
 | 439 | +        key=semaphore_name,  | 
 | 440 | +        capacity=2,  | 
 | 441 | +    )  | 
 | 442 | +    async def limited_context_manager():  | 
 | 443 | +        nonlocal concurrent_count, max_concurrent  | 
 | 444 | +        concurrent_count += 1  | 
 | 445 | +        max_concurrent = max(max_concurrent, concurrent_count)  | 
 | 446 | +        try:  | 
 | 447 | +            yield  | 
 | 448 | +            await asyncio.sleep(0.1)  | 
 | 449 | +        finally:  | 
 | 450 | +            concurrent_count -= 1  | 
 | 451 | + | 
 | 452 | +    async def use_context_manager():  | 
 | 453 | +        async with limited_context_manager():  | 
 | 454 | +            await asyncio.sleep(0.1)  | 
 | 455 | + | 
 | 456 | +    # Start 5 concurrent context managers  | 
 | 457 | +    tasks = [asyncio.create_task(use_context_manager()) for _ in range(5)]  | 
 | 458 | +    await asyncio.gather(*tasks)  | 
 | 459 | + | 
 | 460 | +    # Should never exceed capacity of 2  | 
 | 461 | +    assert max_concurrent <= 2  | 
 | 462 | + | 
 | 463 | + | 
 | 464 | +async def test_context_manager_exception_handling(  | 
 | 465 | +    redis_client_sdk: RedisClientSDK,  | 
 | 466 | +    semaphore_name: str,  | 
 | 467 | +):  | 
 | 468 | +    @asynccontextmanager  | 
 | 469 | +    @with_limited_concurrency_cm(  | 
 | 470 | +        redis_client_sdk,  | 
 | 471 | +        key=semaphore_name,  | 
 | 472 | +        capacity=1,  | 
 | 473 | +    )  | 
 | 474 | +    async def failing_context_manager():  | 
 | 475 | +        yield  | 
 | 476 | +        raise RuntimeError("Test exception")  | 
 | 477 | + | 
 | 478 | +    with pytest.raises(RuntimeError, match="Test exception"):  | 
 | 479 | +        async with failing_context_manager():  | 
 | 480 | +            pass  | 
 | 481 | + | 
 | 482 | +    # Semaphore should be released even after exception  | 
 | 483 | +    @asynccontextmanager  | 
 | 484 | +    @with_limited_concurrency_cm(  | 
 | 485 | +        redis_client_sdk,  | 
 | 486 | +        key=semaphore_name,  | 
 | 487 | +        capacity=1,  | 
 | 488 | +    )  | 
 | 489 | +    async def success_context_manager():  | 
 | 490 | +        yield "success"  | 
 | 491 | + | 
 | 492 | +    async with success_context_manager() as result:  | 
 | 493 | +        assert result == "success"  | 
 | 494 | + | 
 | 495 | + | 
 | 496 | +async def test_context_manager_auto_renewal(  | 
 | 497 | +    redis_client_sdk: RedisClientSDK,  | 
 | 498 | +    semaphore_name: str,  | 
 | 499 | +    semaphore_capacity: int,  | 
 | 500 | +    short_ttl: datetime.timedelta,  | 
 | 501 | +):  | 
 | 502 | +    work_started = asyncio.Event()  | 
 | 503 | +    work_completed = asyncio.Event()  | 
 | 504 | + | 
 | 505 | +    @asynccontextmanager  | 
 | 506 | +    @with_limited_concurrency_cm(  | 
 | 507 | +        redis_client_sdk,  | 
 | 508 | +        key=semaphore_name,  | 
 | 509 | +        capacity=semaphore_capacity,  | 
 | 510 | +        ttl=short_ttl,  | 
 | 511 | +    )  | 
 | 512 | +    async def long_running_context_manager():  | 
 | 513 | +        work_started.set()  | 
 | 514 | +        yield "data"  | 
 | 515 | +        # Wait longer than TTL to ensure renewal works  | 
 | 516 | +        await asyncio.sleep(short_ttl.total_seconds() * 2)  | 
 | 517 | +        work_completed.set()  | 
 | 518 | + | 
 | 519 | +    async def use_long_running_cm():  | 
 | 520 | +        async with long_running_context_manager() as data:  | 
 | 521 | +            assert data == "data"  | 
 | 522 | +            # Keep context manager active for longer than TTL  | 
 | 523 | +            await asyncio.sleep(short_ttl.total_seconds() * 1.5)  | 
 | 524 | + | 
 | 525 | +    task = asyncio.create_task(use_long_running_cm())  | 
 | 526 | +    await work_started.wait()  | 
 | 527 | + | 
 | 528 | +    # Check that semaphore is being held  | 
 | 529 | +    temp_semaphore = DistributedSemaphore(  | 
 | 530 | +        redis_client=redis_client_sdk,  | 
 | 531 | +        key=semaphore_name,  | 
 | 532 | +        capacity=semaphore_capacity,  | 
 | 533 | +        ttl=short_ttl,  | 
 | 534 | +    )  | 
 | 535 | +    assert await temp_semaphore.get_current_count() == 1  | 
 | 536 | +    assert await temp_semaphore.get_available_count() == semaphore_capacity - 1  | 
 | 537 | + | 
 | 538 | +    # Wait for work to complete  | 
 | 539 | +    await task  | 
 | 540 | +    assert work_completed.is_set()  | 
 | 541 | + | 
 | 542 | +    # After completion, semaphore should be released  | 
 | 543 | +    assert await temp_semaphore.get_current_count() == 0  | 
 | 544 | +    assert await temp_semaphore.get_available_count() == semaphore_capacity  | 
 | 545 | + | 
 | 546 | + | 
 | 547 | +async def test_context_manager_with_callable_parameters(  | 
 | 548 | +    redis_client_sdk: RedisClientSDK,  | 
 | 549 | +):  | 
 | 550 | +    executed_keys = []  | 
 | 551 | + | 
 | 552 | +    def get_redis_client(*args, **kwargs):  | 
 | 553 | +        return redis_client_sdk  | 
 | 554 | + | 
 | 555 | +    def get_key(user_id: str, resource: str) -> str:  | 
 | 556 | +        return f"{user_id}-{resource}"  | 
 | 557 | + | 
 | 558 | +    def get_capacity(user_id: str, resource: str) -> int:  | 
 | 559 | +        return 2  | 
 | 560 | + | 
 | 561 | +    @asynccontextmanager  | 
 | 562 | +    @with_limited_concurrency_cm(  | 
 | 563 | +        get_redis_client,  | 
 | 564 | +        key=get_key,  | 
 | 565 | +        capacity=get_capacity,  | 
 | 566 | +    )  | 
 | 567 | +    async def process_user_resource_cm(user_id: str, resource: str):  | 
 | 568 | +        executed_keys.append(f"{user_id}-{resource}")  | 
 | 569 | +        yield f"processed-{user_id}-{resource}"  | 
 | 570 | +        await asyncio.sleep(0.05)  | 
 | 571 | + | 
 | 572 | +    async def use_cm(user_id: str, resource: str):  | 
 | 573 | +        async with process_user_resource_cm(user_id, resource) as result:  | 
 | 574 | +            return result  | 
 | 575 | + | 
 | 576 | +    # Test with different parameters  | 
 | 577 | +    results = await asyncio.gather(  | 
 | 578 | +        use_cm("user1", "wallet1"),  | 
 | 579 | +        use_cm("user1", "wallet2"),  | 
 | 580 | +        use_cm("user2", "wallet1"),  | 
 | 581 | +    )  | 
 | 582 | + | 
 | 583 | +    assert len(executed_keys) == 3  | 
 | 584 | +    assert "user1-wallet1" in executed_keys  | 
 | 585 | +    assert "user1-wallet2" in executed_keys  | 
 | 586 | +    assert "user2-wallet1" in executed_keys  | 
 | 587 | + | 
 | 588 | +    assert len(results) == 3  | 
 | 589 | +    assert "processed-user1-wallet1" in results  | 
 | 590 | +    assert "processed-user1-wallet2" in results  | 
 | 591 | +    assert "processed-user2-wallet1" in results  | 
 | 592 | + | 
 | 593 | + | 
 | 594 | +async def test_context_manager_non_blocking_behavior(  | 
 | 595 | +    redis_client_sdk: RedisClientSDK,  | 
 | 596 | +    semaphore_name: str,  | 
 | 597 | +):  | 
 | 598 | +    started_event = asyncio.Event()  | 
 | 599 | + | 
 | 600 | +    @asynccontextmanager  | 
 | 601 | +    @with_limited_concurrency_cm(  | 
 | 602 | +        redis_client_sdk,  | 
 | 603 | +        key=semaphore_name,  | 
 | 604 | +        capacity=1,  | 
 | 605 | +        blocking=True,  | 
 | 606 | +        blocking_timeout=datetime.timedelta(seconds=0.1),  | 
 | 607 | +    )  | 
 | 608 | +    async def limited_context_manager():  | 
 | 609 | +        started_event.set()  | 
 | 610 | +        yield  | 
 | 611 | +        await asyncio.sleep(2)  | 
 | 612 | + | 
 | 613 | +    # Start first context manager that will hold the semaphore  | 
 | 614 | +    async def long_running_cm():  | 
 | 615 | +        async with limited_context_manager():  | 
 | 616 | +            await asyncio.sleep(2)  | 
 | 617 | + | 
 | 618 | +    task1 = asyncio.create_task(long_running_cm())  | 
 | 619 | +    await started_event.wait()  # Wait until semaphore is actually acquired  | 
 | 620 | + | 
 | 621 | +    # Second context manager should timeout and raise an exception  | 
 | 622 | +    @asynccontextmanager  | 
 | 623 | +    @with_limited_concurrency_cm(  | 
 | 624 | +        redis_client_sdk,  | 
 | 625 | +        key=semaphore_name,  | 
 | 626 | +        capacity=1,  | 
 | 627 | +        blocking=True,  | 
 | 628 | +        blocking_timeout=datetime.timedelta(seconds=0.1),  | 
 | 629 | +    )  | 
 | 630 | +    async def timeout_context_manager():  | 
 | 631 | +        yield  | 
 | 632 | + | 
 | 633 | +    with pytest.raises(SemaphoreAcquisitionError):  | 
 | 634 | +        async with timeout_context_manager():  | 
 | 635 | +            pass  | 
 | 636 | + | 
 | 637 | +    await task1  | 
 | 638 | + | 
 | 639 | + | 
 | 640 | +async def test_context_manager_lose_semaphore_raises(  | 
 | 641 | +    redis_client_sdk: RedisClientSDK,  | 
 | 642 | +    semaphore_name: str,  | 
 | 643 | +    semaphore_capacity: int,  | 
 | 644 | +    short_ttl: datetime.timedelta,  | 
 | 645 | +):  | 
 | 646 | +    work_started = asyncio.Event()  | 
 | 647 | + | 
 | 648 | +    @asynccontextmanager  | 
 | 649 | +    @with_limited_concurrency_cm(  | 
 | 650 | +        redis_client_sdk,  | 
 | 651 | +        key=semaphore_name,  | 
 | 652 | +        capacity=semaphore_capacity,  | 
 | 653 | +        ttl=short_ttl,  | 
 | 654 | +    )  | 
 | 655 | +    async def context_manager_that_should_fail():  | 
 | 656 | +        work_started.set()  | 
 | 657 | +        yield "data"  | 
 | 658 | +        # Wait long enough for renewal to be attempted multiple times  | 
 | 659 | +        await asyncio.sleep(short_ttl.total_seconds() * 4)  | 
 | 660 | + | 
 | 661 | +    async def use_failing_cm():  | 
 | 662 | +        async with context_manager_that_should_fail() as data:  | 
 | 663 | +            assert data == "data"  | 
 | 664 | +            # Keep context active while semaphore will be lost  | 
 | 665 | +            await asyncio.sleep(short_ttl.total_seconds() * 3)  | 
 | 666 | + | 
 | 667 | +    task = asyncio.create_task(use_failing_cm())  | 
 | 668 | +    await work_started.wait()  | 
 | 669 | + | 
 | 670 | +    # Wait for the first renewal interval to pass  | 
 | 671 | +    renewal_interval = short_ttl / 3  | 
 | 672 | +    await asyncio.sleep(renewal_interval.total_seconds() + 1)  | 
 | 673 | + | 
 | 674 | +    # Find and delete all holder keys for this semaphore  | 
 | 675 | +    holder_keys = await redis_client_sdk.redis.keys(  | 
 | 676 | +        f"{SEMAPHORE_HOLDER_KEY_PREFIX}{semaphore_name}:*"  | 
 | 677 | +    )  | 
 | 678 | +    assert holder_keys, "Holder keys should exist before deletion"  | 
 | 679 | +    await redis_client_sdk.redis.delete(*holder_keys)  | 
 | 680 | + | 
 | 681 | +    with pytest.raises(SemaphoreLostError):  # Expected from the ExceptionGroup handling  | 
 | 682 | +        await task  | 
0 commit comments