Skip to content

Commit 03d4fa9

Browse files
committed
fixing code
1 parent 8050c13 commit 03d4fa9

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

packages/service-library/src/servicelib/redis/_semaphore.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from ._constants import (
2020
DEFAULT_SEMAPHORE_TTL,
2121
DEFAULT_SOCKET_TIMEOUT,
22-
SEMAPHORE_HOLDER_KEY_PREFIX,
2322
SEMAPHORE_KEY_PREFIX,
2423
)
2524
from ._errors import (
@@ -150,13 +149,13 @@ def holders_key(self) -> str:
150149
@property
151150
def holder_key(self) -> str:
152151
"""Redis key for this instance's holder entry."""
153-
return f"{SEMAPHORE_HOLDER_KEY_PREFIX}{self.key}:{self.instance_id}"
152+
return f"{SEMAPHORE_KEY_PREFIX}{self.key}:holders_:{self.instance_id}"
154153

155154
@computed_field
156155
@property
157156
def holder_prefix(self) -> str:
158157
"""Prefix for holder keys (used in cleanup)."""
159-
return f"{SEMAPHORE_HOLDER_KEY_PREFIX}{self.key}:"
158+
return f"{SEMAPHORE_KEY_PREFIX}{self.key}:holders_:"
160159

161160
# Additional validation
162161
@field_validator("ttl")
@@ -188,6 +187,14 @@ async def acquire(self) -> bool:
188187
SemaphoreAcquisitionError: If acquisition fails and blocking=True
189188
"""
190189

190+
if await self.is_acquired():
191+
_logger.debug(
192+
"Semaphore '%s' already acquired by this instance (instance: %s)",
193+
self.key,
194+
self.instance_id,
195+
)
196+
return True
197+
191198
ttl_seconds = int(self.ttl.total_seconds())
192199
blocking_timeout_seconds = 0.001
193200
if self.blocking:
@@ -363,26 +370,26 @@ async def reacquire(self) -> None:
363370

364371
raise SemaphoreLostError(name=self.key, instance_id=self.instance_id)
365372

373+
async def is_acquired(self) -> bool:
374+
"""Check if the semaphore is currently acquired by this instance."""
375+
return (
376+
await handle_redis_returns_union_types(
377+
self.redis_client.redis.exists(self.holder_key)
378+
)
379+
== 1
380+
)
381+
366382
async def get_current_count(self) -> int:
367383
"""Get the current number of semaphore holders"""
368-
369-
cls = type(self)
370-
assert cls.count_script is not None # nosec
371-
result = await cls.count_script( # pylint: disable=not-callable
372-
keys=[self.holders_key, self.tokens_key],
373-
args=[self.capacity],
374-
client=self.redis_client.redis,
384+
return await handle_redis_returns_union_types(
385+
self.redis_client.redis.scard(self.holders_key)
375386
)
376387

377-
assert isinstance(result, list) # nosec
378-
current_holders, available_tokens, capacity = result
379-
380-
return int(current_holders)
381-
382388
async def get_available_count(self) -> int:
383389
"""Get the number of available semaphore slots"""
384-
current_count = await self.get_current_count()
385-
return max(0, self.capacity - current_count)
390+
return await handle_redis_returns_union_types(
391+
self.redis_client.redis.llen(self.tokens_key)
392+
)
386393

387394
# Context manager support
388395
async def __aenter__(self) -> "DistributedSemaphore":

packages/service-library/tests/redis/test_semaphore.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ async def test_semaphore_acquire_release_single(
102102
semaphore_capacity: int,
103103
):
104104
semaphore = DistributedSemaphore(
105-
redis_client=redis_client_sdk, key=semaphore_name, capacity=semaphore_capacity
105+
redis_client=redis_client_sdk,
106+
key=semaphore_name,
107+
capacity=semaphore_capacity,
108+
ttl=datetime.timedelta(seconds=60),
106109
)
107110

108111
# Initially not acquired
@@ -216,7 +219,10 @@ async def test_semaphore_blocking_timeout(
216219

217220
# First semaphore acquires
218221
async with DistributedSemaphore(
219-
redis_client=redis_client_sdk, key=semaphore_name, capacity=capacity
222+
redis_client=redis_client_sdk,
223+
key=semaphore_name,
224+
capacity=capacity,
225+
ttl=datetime.timedelta(seconds=60),
220226
):
221227
# Second semaphore should timeout
222228
semaphore2 = DistributedSemaphore(

0 commit comments

Comments
 (0)