Skip to content

Commit 8829a50

Browse files
committed
fixes
1 parent 02cea1b commit 8829a50

File tree

3 files changed

+52
-36
lines changed

3 files changed

+52
-36
lines changed

packages/service-library/src/servicelib/redis/_semaphore.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,22 +101,30 @@ class DistributedSemaphore(BaseModel):
101101
] = DEFAULT_FACTORY
102102

103103
# Private state attributes (not part of the model)
104-
_acquire_script: ClassVar[AsyncScript]
105-
_count_script: ClassVar[AsyncScript]
106-
_release_script: ClassVar[AsyncScript]
107-
_renew_script: ClassVar[AsyncScript]
104+
acquire_script: ClassVar[AsyncScript | None] = None
105+
count_script: ClassVar[AsyncScript | None] = None
106+
release_script: ClassVar[AsyncScript | None] = None
107+
renew_script: ClassVar[AsyncScript | None] = None
108108

109109
@classmethod
110-
def _register_scripts(cls, redis_client) -> None:
111-
if cls._acquire_script is None:
112-
cls._acquire_script = redis_client.register_script(ACQUIRE_SEMAPHORE_SCRIPT)
113-
cls._count_script = redis_client.register_script(COUNT_SEMAPHORE_SCRIPT)
114-
cls._release_script = redis_client.register_script(RELEASE_SEMAPHORE_SCRIPT)
115-
cls._renew_script = redis_client.register_script(RENEW_SEMAPHORE_SCRIPT)
110+
def _register_scripts(cls, redis_client: RedisClientSDK) -> None:
111+
if cls.acquire_script is None:
112+
cls.acquire_script = redis_client.redis.register_script(
113+
ACQUIRE_SEMAPHORE_SCRIPT
114+
)
115+
cls.count_script = redis_client.redis.register_script(
116+
COUNT_SEMAPHORE_SCRIPT
117+
)
118+
cls.release_script = redis_client.redis.register_script(
119+
RELEASE_SEMAPHORE_SCRIPT
120+
)
121+
cls.renew_script = redis_client.redis.register_script(
122+
RENEW_SEMAPHORE_SCRIPT
123+
)
116124

117125
def __init__(self, **data) -> None:
118126
super().__init__(**data)
119-
type(self)._register_scripts(self.redis_client) # noqa: SLF001
127+
self.__class__._register_scripts(self.redis_client) # noqa: SLF001
120128

121129
@computed_field # type: ignore[prop-decorator]
122130
@property
@@ -196,7 +204,7 @@ async def release(self) -> None:
196204
ttl_seconds = int(self.ttl.total_seconds())
197205

198206
# Execute the release Lua script atomically
199-
result = await type(self)._release_script( # noqa: SLF001
207+
result = await type(self).release_script( # noqa: SLF001
200208
keys=(
201209
self.semaphore_key,
202210
self.holder_key,
@@ -231,7 +239,7 @@ async def _try_acquire(self) -> bool:
231239
ttl_seconds = int(self.ttl.total_seconds())
232240

233241
# Execute the Lua script atomically
234-
result = await type(self)._acquire_script( # noqa: SLF001
242+
result = await type(self).acquire_script( # noqa: SLF001
235243
keys=(self.semaphore_key, self.holder_key),
236244
args=(self.instance_id, str(self.capacity), str(ttl_seconds)),
237245
client=self.redis_client.redis,
@@ -274,7 +282,7 @@ async def reacquire(self) -> None:
274282
ttl_seconds = int(self.ttl.total_seconds())
275283

276284
# Execute the renewal Lua script atomically
277-
result = await type(self)._renew_script( # noqa: SLF001
285+
result = await type(self).renew_script( # noqa: SLF001
278286
keys=(self.semaphore_key, self.holder_key),
279287
args=(
280288
self.instance_id,
@@ -322,7 +330,7 @@ async def get_current_count(self) -> int:
322330
ttl_seconds = int(self.ttl.total_seconds())
323331

324332
# Execute the count Lua script atomically
325-
result = await type(self)._count_script( # noqa: SLF001
333+
result = await type(self).count_script( # noqa: SLF001
326334
keys=(self.semaphore_key,),
327335
args=(str(ttl_seconds),),
328336
client=self.redis_client.redis,

packages/service-library/src/servicelib/redis/_semaphore_decorator.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import socket
66
from collections.abc import AsyncIterator, Callable, Coroutine
77
from contextlib import asynccontextmanager
8-
from typing import Any, ParamSpec, TypeVar
8+
from typing import Any, AsyncContextManager, ParamSpec, TypeVar
99

1010
from common_library.async_tools import cancel_wait_task
1111

@@ -210,7 +210,10 @@ def with_limited_concurrency_cm(
210210
ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
211211
blocking: bool = True,
212212
blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT,
213-
) -> Callable[[Callable[P, AsyncIterator[R]]], Callable[P, AsyncIterator[R]]]:
213+
) -> Callable[
214+
[Callable[P, AsyncContextManager[R]]],
215+
Callable[P, AsyncContextManager[R]],
216+
]:
214217
"""
215218
Decorator to limit concurrent execution of async context managers using a distributed semaphore.
216219
@@ -244,9 +247,10 @@ async def get_cluster_client():
244247
"""
245248

246249
def _decorator(
247-
cm_func: Callable[P, AsyncIterator[R]],
248-
) -> Callable[P, AsyncIterator[R]]:
250+
cm_func: Callable[P, AsyncContextManager[R]],
251+
) -> Callable[P, AsyncContextManager[R]]:
249252
@functools.wraps(cm_func)
253+
@asynccontextmanager
250254
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
251255
semaphore, semaphore_key = _create_semaphore(
252256
redis_client,
@@ -259,11 +263,13 @@ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
259263
kwargs=kwargs,
260264
)
261265

262-
async with _managed_semaphore_execution(
263-
semaphore, semaphore_key, ttl, f"context_manager_{cm_func.__name__}"
266+
async with (
267+
_managed_semaphore_execution(
268+
semaphore, semaphore_key, ttl, f"context_manager_{cm_func.__name__}"
269+
),
270+
cm_func(*args, **kwargs) as value,
264271
):
265-
async for value in cm_func(*args, **kwargs):
266-
yield value
272+
yield value
267273

268274
return _wrapper
269275

packages/service-library/tests/redis/test_semaphore_decorator.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -401,19 +401,19 @@ async def test_context_manager_basic_functionality(
401401
):
402402
call_count = 0
403403

404-
@asynccontextmanager
405404
@with_limited_concurrency_cm(
406405
redis_client_sdk,
407406
key=semaphore_name,
408407
capacity=1,
409408
)
409+
@asynccontextmanager
410410
async def limited_context_manager():
411411
nonlocal call_count
412412
call_count += 1
413413
yield call_count
414414

415415
# Multiple concurrent context managers
416-
async def use_context_manager():
416+
async def use_context_manager() -> int:
417417
async with limited_context_manager() as value:
418418
await asyncio.sleep(0.1)
419419
return value
@@ -433,12 +433,12 @@ async def test_context_manager_capacity_enforcement(
433433
concurrent_count = 0
434434
max_concurrent = 0
435435

436-
@asynccontextmanager
437436
@with_limited_concurrency_cm(
438437
redis_client_sdk,
439438
key=semaphore_name,
440439
capacity=2,
441440
)
441+
@asynccontextmanager
442442
async def limited_context_manager():
443443
nonlocal concurrent_count, max_concurrent
444444
concurrent_count += 1
@@ -449,12 +449,12 @@ async def limited_context_manager():
449449
finally:
450450
concurrent_count -= 1
451451

452-
async def use_context_manager():
452+
async def use_context_manager() -> None:
453453
async with limited_context_manager():
454454
await asyncio.sleep(0.1)
455455

456-
# Start 5 concurrent context managers
457-
tasks = [asyncio.create_task(use_context_manager()) for _ in range(5)]
456+
# Start concurrent context managers
457+
tasks = [asyncio.create_task(use_context_manager()) for _ in range(20)]
458458
await asyncio.gather(*tasks)
459459

460460
# Should never exceed capacity of 2
@@ -465,12 +465,12 @@ async def test_context_manager_exception_handling(
465465
redis_client_sdk: RedisClientSDK,
466466
semaphore_name: str,
467467
):
468-
@asynccontextmanager
469468
@with_limited_concurrency_cm(
470469
redis_client_sdk,
471470
key=semaphore_name,
472471
capacity=1,
473472
)
473+
@asynccontextmanager
474474
async def failing_context_manager():
475475
yield
476476
raise RuntimeError("Test exception")
@@ -480,12 +480,13 @@ async def failing_context_manager():
480480
pass
481481

482482
# Semaphore should be released even after exception
483-
@asynccontextmanager
483+
484484
@with_limited_concurrency_cm(
485485
redis_client_sdk,
486486
key=semaphore_name,
487487
capacity=1,
488488
)
489+
@asynccontextmanager
489490
async def success_context_manager():
490491
yield "success"
491492

@@ -502,13 +503,13 @@ async def test_context_manager_auto_renewal(
502503
work_started = asyncio.Event()
503504
work_completed = asyncio.Event()
504505

505-
@asynccontextmanager
506506
@with_limited_concurrency_cm(
507507
redis_client_sdk,
508508
key=semaphore_name,
509509
capacity=semaphore_capacity,
510510
ttl=short_ttl,
511511
)
512+
@asynccontextmanager
512513
async def long_running_context_manager():
513514
work_started.set()
514515
yield "data"
@@ -558,12 +559,12 @@ def get_key(user_id: str, resource: str) -> str:
558559
def get_capacity(user_id: str, resource: str) -> int:
559560
return 2
560561

561-
@asynccontextmanager
562562
@with_limited_concurrency_cm(
563563
get_redis_client,
564564
key=get_key,
565565
capacity=get_capacity,
566566
)
567+
@asynccontextmanager
567568
async def process_user_resource_cm(user_id: str, resource: str):
568569
executed_keys.append(f"{user_id}-{resource}")
569570
yield f"processed-{user_id}-{resource}"
@@ -597,14 +598,14 @@ async def test_context_manager_non_blocking_behavior(
597598
):
598599
started_event = asyncio.Event()
599600

600-
@asynccontextmanager
601601
@with_limited_concurrency_cm(
602602
redis_client_sdk,
603603
key=semaphore_name,
604604
capacity=1,
605605
blocking=True,
606606
blocking_timeout=datetime.timedelta(seconds=0.1),
607607
)
608+
@asynccontextmanager
608609
async def limited_context_manager():
609610
started_event.set()
610611
yield
@@ -619,14 +620,15 @@ async def long_running_cm():
619620
await started_event.wait() # Wait until semaphore is actually acquired
620621

621622
# Second context manager should timeout and raise an exception
622-
@asynccontextmanager
623+
623624
@with_limited_concurrency_cm(
624625
redis_client_sdk,
625626
key=semaphore_name,
626627
capacity=1,
627628
blocking=True,
628629
blocking_timeout=datetime.timedelta(seconds=0.1),
629630
)
631+
@asynccontextmanager
630632
async def timeout_context_manager():
631633
yield
632634

@@ -645,13 +647,13 @@ async def test_context_manager_lose_semaphore_raises(
645647
):
646648
work_started = asyncio.Event()
647649

648-
@asynccontextmanager
649650
@with_limited_concurrency_cm(
650651
redis_client_sdk,
651652
key=semaphore_name,
652653
capacity=semaphore_capacity,
653654
ttl=short_ttl,
654655
)
656+
@asynccontextmanager
655657
async def context_manager_that_should_fail():
656658
work_started.set()
657659
yield "data"

0 commit comments

Comments
 (0)