
Commit a7db925

added tests for context manager
1 parent 72140e2 commit a7db925

File tree: 1 file changed, +292 −1 lines changed

packages/service-library/tests/redis/test_semaphore_decorator.py

Lines changed: 292 additions & 1 deletion
@@ -8,6 +8,7 @@
 import asyncio
 import datetime
 import logging
+from contextlib import asynccontextmanager
 from typing import Literal
 
 import pytest
@@ -21,7 +22,10 @@
     DistributedSemaphore,
     SemaphoreAcquisitionError,
 )
-from servicelib.redis._semaphore_decorator import with_limited_concurrency
+from servicelib.redis._semaphore_decorator import (
+    with_limited_concurrency,
+    with_limited_concurrency_cm,
+)
 
 pytest_simcore_core_services_selection = [
     "redis",
@@ -389,3 +393,290 @@ async def limited_function() -> None:
 
     # Should never exceed the large capacity
     assert max_concurrent <= large_capacity
+
+
+async def test_context_manager_basic_functionality(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+):
+    call_count = 0
+
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+    )
+    async def limited_context_manager():
+        nonlocal call_count
+        call_count += 1
+        yield call_count
+
+    # Multiple concurrent context managers
+    async def use_context_manager():
+        async with limited_context_manager() as value:
+            await asyncio.sleep(0.1)
+            return value
+
+    tasks = [asyncio.create_task(use_context_manager()) for _ in range(3)]
+    results = await asyncio.gather(*tasks)
+
+    # All should complete successfully
+    assert len(results) == 3
+    assert all(isinstance(r, int) for r in results)
+
+
+async def test_context_manager_capacity_enforcement(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+):
+    concurrent_count = 0
+    max_concurrent = 0
+
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=2,
+    )
+    async def limited_context_manager():
+        nonlocal concurrent_count, max_concurrent
+        concurrent_count += 1
+        max_concurrent = max(max_concurrent, concurrent_count)
+        try:
+            yield
+            await asyncio.sleep(0.1)
+        finally:
+            concurrent_count -= 1
+
+    async def use_context_manager():
+        async with limited_context_manager():
+            await asyncio.sleep(0.1)
+
+    # Start 5 concurrent context managers
+    tasks = [asyncio.create_task(use_context_manager()) for _ in range(5)]
+    await asyncio.gather(*tasks)
+
+    # Should never exceed capacity of 2
+    assert max_concurrent <= 2
+
+
+async def test_context_manager_exception_handling(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+):
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+    )
+    async def failing_context_manager():
+        yield
+        raise RuntimeError("Test exception")
+
+    with pytest.raises(RuntimeError, match="Test exception"):
+        async with failing_context_manager():
+            pass
+
+    # Semaphore should be released even after exception
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+    )
+    async def success_context_manager():
+        yield "success"
+
+    async with success_context_manager() as result:
+        assert result == "success"
+
+
+async def test_context_manager_auto_renewal(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+    semaphore_capacity: int,
+    short_ttl: datetime.timedelta,
+):
+    work_started = asyncio.Event()
+    work_completed = asyncio.Event()
+
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=semaphore_capacity,
+        ttl=short_ttl,
+    )
+    async def long_running_context_manager():
+        work_started.set()
+        yield "data"
+        # Wait longer than TTL to ensure renewal works
+        await asyncio.sleep(short_ttl.total_seconds() * 2)
+        work_completed.set()
+
+    async def use_long_running_cm():
+        async with long_running_context_manager() as data:
+            assert data == "data"
+            # Keep context manager active for longer than TTL
+            await asyncio.sleep(short_ttl.total_seconds() * 1.5)
+
+    task = asyncio.create_task(use_long_running_cm())
+    await work_started.wait()
+
+    # Check that semaphore is being held
+    temp_semaphore = DistributedSemaphore(
+        redis_client=redis_client_sdk,
+        key=semaphore_name,
+        capacity=semaphore_capacity,
+        ttl=short_ttl,
+    )
+    assert await temp_semaphore.get_current_count() == 1
+    assert await temp_semaphore.get_available_count() == semaphore_capacity - 1
+
+    # Wait for work to complete
+    await task
+    assert work_completed.is_set()
+
+    # After completion, semaphore should be released
+    assert await temp_semaphore.get_current_count() == 0
+    assert await temp_semaphore.get_available_count() == semaphore_capacity
+
+
+async def test_context_manager_with_callable_parameters(
+    redis_client_sdk: RedisClientSDK,
+):
+    executed_keys = []
+
+    def get_redis_client(*args, **kwargs):
+        return redis_client_sdk
+
+    def get_key(user_id: str, resource: str) -> str:
+        return f"{user_id}-{resource}"
+
+    def get_capacity(user_id: str, resource: str) -> int:
+        return 2
+
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        get_redis_client,
+        key=get_key,
+        capacity=get_capacity,
+    )
+    async def process_user_resource_cm(user_id: str, resource: str):
+        executed_keys.append(f"{user_id}-{resource}")
+        yield f"processed-{user_id}-{resource}"
+        await asyncio.sleep(0.05)
+
+    async def use_cm(user_id: str, resource: str):
+        async with process_user_resource_cm(user_id, resource) as result:
+            return result
+
+    # Test with different parameters
+    results = await asyncio.gather(
+        use_cm("user1", "wallet1"),
+        use_cm("user1", "wallet2"),
+        use_cm("user2", "wallet1"),
+    )
+
+    assert len(executed_keys) == 3
+    assert "user1-wallet1" in executed_keys
+    assert "user1-wallet2" in executed_keys
+    assert "user2-wallet1" in executed_keys
+
+    assert len(results) == 3
+    assert "processed-user1-wallet1" in results
+    assert "processed-user1-wallet2" in results
+    assert "processed-user2-wallet1" in results
+
+
+async def test_context_manager_non_blocking_behavior(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+):
+    started_event = asyncio.Event()
+
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+        blocking=True,
+        blocking_timeout=datetime.timedelta(seconds=0.1),
+    )
+    async def limited_context_manager():
+        started_event.set()
+        yield
+        await asyncio.sleep(2)
+
+    # Start first context manager that will hold the semaphore
+    async def long_running_cm():
+        async with limited_context_manager():
+            await asyncio.sleep(2)
+
+    task1 = asyncio.create_task(long_running_cm())
+    await started_event.wait()  # Wait until semaphore is actually acquired
+
+    # Second context manager should timeout and raise an exception
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+        blocking=True,
+        blocking_timeout=datetime.timedelta(seconds=0.1),
+    )
+    async def timeout_context_manager():
+        yield
+
+    with pytest.raises(SemaphoreAcquisitionError):
+        async with timeout_context_manager():
+            pass
+
+    await task1
+
+
+async def test_context_manager_lose_semaphore_raises(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+    semaphore_capacity: int,
+    short_ttl: datetime.timedelta,
+):
+    work_started = asyncio.Event()
+
+    @asynccontextmanager
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=semaphore_capacity,
+        ttl=short_ttl,
+    )
+    async def context_manager_that_should_fail():
+        work_started.set()
+        yield "data"
+        # Wait long enough for renewal to be attempted multiple times
+        await asyncio.sleep(short_ttl.total_seconds() * 4)
+
+    async def use_failing_cm():
+        async with context_manager_that_should_fail() as data:
+            assert data == "data"
+            # Keep context active while semaphore will be lost
+            await asyncio.sleep(short_ttl.total_seconds() * 3)
+
+    task = asyncio.create_task(use_failing_cm())
+    await work_started.wait()
+
+    # Wait for the first renewal interval to pass
+    renewal_interval = short_ttl / 3
+    await asyncio.sleep(renewal_interval.total_seconds() + 1)
+
+    # Find and delete all holder keys for this semaphore
+    holder_keys = await redis_client_sdk.redis.keys(
+        f"{SEMAPHORE_HOLDER_KEY_PREFIX}{semaphore_name}:*"
+    )
+    assert holder_keys, "Holder keys should exist before deletion"
+    await redis_client_sdk.redis.delete(*holder_keys)
+
+    with pytest.raises(SemaphoreLostError):  # Expected from the ExceptionGroup handling
+        await task
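
Note on the pattern these tests exercise: with_limited_concurrency_cm is applied to an async generator function and contextlib.asynccontextmanager is stacked on top, so the distributed semaphore is acquired when the async with block is entered, renewed while it runs, and released on exit. A minimal sketch of that usage, reusing only parameters that appear in the diff (key, capacity, blocking, blocking_timeout); the get_redis_client helper and the concrete values are illustrative assumptions, not part of this commit:

    import datetime
    from contextlib import asynccontextmanager

    from servicelib.redis._semaphore_decorator import with_limited_concurrency_cm


    def get_redis_client(*args, **kwargs):
        # hypothetical helper: however the application obtains its RedisClientSDK
        ...


    @asynccontextmanager
    @with_limited_concurrency_cm(
        get_redis_client,  # a client instance, or a callable returning one (as in the callable-parameters test)
        key="my-resource",  # a fixed key, or a callable computed from the call arguments
        capacity=2,  # at most two concurrent holders across processes
        blocking=True,
        blocking_timeout=datetime.timedelta(seconds=5),  # illustrative value
    )
    async def guarded_resource():
        # the semaphore is held (and periodically renewed) for the whole block
        yield "my-resource"


    async def do_work() -> None:
        async with guarded_resource() as resource:
            ...  # capacity-limited work on `resource` happens here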

0 commit comments
