|
19 | 19 | from ._constants import ( |
20 | 20 | DEFAULT_SEMAPHORE_TTL, |
21 | 21 | DEFAULT_SOCKET_TIMEOUT, |
22 | | - SEMAPHORE_HOLDER_KEY_PREFIX, |
23 | 22 | SEMAPHORE_KEY_PREFIX, |
24 | 23 | ) |
25 | 24 | from ._errors import ( |
@@ -150,13 +149,13 @@ def holders_key(self) -> str: |
150 | 149 | @property |
151 | 150 | def holder_key(self) -> str: |
152 | 151 | """Redis key for this instance's holder entry.""" |
153 | | - return f"{SEMAPHORE_HOLDER_KEY_PREFIX}{self.key}:{self.instance_id}" |
| 152 | + return f"{SEMAPHORE_KEY_PREFIX}{self.key}:holders_:{self.instance_id}" |
154 | 153 |
|
155 | 154 | @computed_field |
156 | 155 | @property |
157 | 156 | def holder_prefix(self) -> str: |
158 | 157 | """Prefix for holder keys (used in cleanup).""" |
159 | | - return f"{SEMAPHORE_HOLDER_KEY_PREFIX}{self.key}:" |
| 158 | + return f"{SEMAPHORE_KEY_PREFIX}{self.key}:holders_:" |
160 | 159 |
|
161 | 160 | # Additional validation |
162 | 161 | @field_validator("ttl") |
@@ -188,6 +187,14 @@ async def acquire(self) -> bool: |
188 | 187 | SemaphoreAcquisitionError: If acquisition fails and blocking=True |
189 | 188 | """ |
190 | 189 |
|
| 190 | + if await self.is_acquired(): |
| 191 | + _logger.debug( |
| 192 | + "Semaphore '%s' already acquired by this instance (instance: %s)", |
| 193 | + self.key, |
| 194 | + self.instance_id, |
| 195 | + ) |
| 196 | + return True |
| 197 | + |
191 | 198 | ttl_seconds = int(self.ttl.total_seconds()) |
192 | 199 | blocking_timeout_seconds = 0.001 |
193 | 200 | if self.blocking: |
@@ -363,26 +370,26 @@ async def reacquire(self) -> None: |
363 | 370 |
|
364 | 371 | raise SemaphoreLostError(name=self.key, instance_id=self.instance_id) |
365 | 372 |
|
| 373 | + async def is_acquired(self) -> bool: |
| 374 | + """Check if the semaphore is currently acquired by this instance.""" |
| 375 | + return ( |
| 376 | + await handle_redis_returns_union_types( |
| 377 | + self.redis_client.redis.exists(self.holder_key) |
| 378 | + ) |
| 379 | + == 1 |
| 380 | + ) |
| 381 | + |
366 | 382 | async def get_current_count(self) -> int: |
367 | 383 | """Get the current number of semaphore holders""" |
368 | | - |
369 | | - cls = type(self) |
370 | | - assert cls.count_script is not None # nosec |
371 | | - result = await cls.count_script( # pylint: disable=not-callable |
372 | | - keys=[self.holders_key, self.tokens_key], |
373 | | - args=[self.capacity], |
374 | | - client=self.redis_client.redis, |
| 384 | + return await handle_redis_returns_union_types( |
| 385 | + self.redis_client.redis.scard(self.holders_key) |
375 | 386 | ) |
376 | 387 |
|
377 | | - assert isinstance(result, list) # nosec |
378 | | - current_holders, available_tokens, capacity = result |
379 | | - |
380 | | - return int(current_holders) |
381 | | - |
382 | 388 | async def get_available_count(self) -> int: |
383 | 389 | """Get the number of available semaphore slots""" |
384 | | - current_count = await self.get_current_count() |
385 | | - return max(0, self.capacity - current_count) |
| 390 | + return await handle_redis_returns_union_types( |
| 391 | + self.redis_client.redis.llen(self.tokens_key) |
| 392 | + ) |
386 | 393 |
|
387 | 394 | # Context manager support |
388 | 395 | async def __aenter__(self) -> "DistributedSemaphore": |
|
0 commit comments