Skip to content

Commit 51e2bde

Browse files
committed
passing tests
1 parent 76fbc07 commit 51e2bde

File tree

5 files changed

+66
-92
lines changed

5 files changed

+66
-92
lines changed

packages/service-library/src/servicelib/redis/_semaphore.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from types import TracebackType
55
from typing import Annotated, ClassVar
66

7+
import redis.exceptions
78
from common_library.basic_types import DEFAULT_FACTORY
89
from pydantic import (
910
BaseModel,
@@ -13,6 +14,7 @@
1314
field_validator,
1415
)
1516
from redis.commands.core import AsyncScript
17+
from servicelib.redis._utils import handle_redis_returns_union_types
1618

1719
from ._client import RedisClientSDK
1820
from ._constants import (
@@ -29,6 +31,7 @@
2931
from ._semaphore_lua import (
3032
ACQUIRE_FAIR_SEMAPHORE_V2_SCRIPT,
3133
COUNT_FAIR_SEMAPHORE_V2_SCRIPT,
34+
REGISTER_FAIR_SEMAPHORE_SCRIPT,
3235
RELEASE_FAIR_SEMAPHORE_V2_SCRIPT,
3336
RENEW_FAIR_SEMAPHORE_V2_SCRIPT,
3437
SCRIPT_BAD_EXIT_CODE,
@@ -92,6 +95,7 @@ class DistributedSemaphore(BaseModel):
9295
] = DEFAULT_FACTORY
9396

9497
# Class and/or Private state attributes (not part of the model)
98+
register_semaphore: ClassVar[AsyncScript | None] = None
9599
acquire_script: ClassVar[AsyncScript | None] = None
96100
count_script: ClassVar[AsyncScript | None] = None
97101
release_script: ClassVar[AsyncScript | None] = None
@@ -104,6 +108,9 @@ def _register_scripts(cls, redis_client: RedisClientSDK) -> None:
104108
caches the script SHA, so this is efficient. Even if called multiple times,
105109
the script is only registered once."""
106110
if cls.acquire_script is None:
111+
cls.register_semaphore = redis_client.redis.register_script(
112+
REGISTER_FAIR_SEMAPHORE_SCRIPT
113+
)
107114
cls.acquire_script = redis_client.redis.register_script(
108115
ACQUIRE_FAIR_SEMAPHORE_V2_SCRIPT
109116
)
@@ -145,6 +152,12 @@ def holder_key(self) -> str:
145152
"""Redis key for this instance's holder entry."""
146153
return f"{SEMAPHORE_HOLDER_KEY_PREFIX}{self.key}:{self.instance_id}"
147154

155+
@computed_field
156+
@property
157+
def holder_prefix(self) -> str:
158+
"""Prefix for holder keys (used in cleanup)."""
159+
return f"{SEMAPHORE_HOLDER_KEY_PREFIX}{self.key}:"
160+
148161
# Additional validation
149162
@field_validator("ttl")
150163
@classmethod
@@ -176,22 +189,51 @@ async def acquire(self) -> bool:
176189
"""
177190

178191
ttl_seconds = int(self.ttl.total_seconds())
179-
blocking_timeout_seconds = 0.1
192+
blocking_timeout_seconds = 1
180193
if self.blocking:
181-
blocking_timeout_seconds = (
194+
blocking_timeout_seconds = int(
182195
self.blocking_timeout.total_seconds() if self.blocking_timeout else 0
183196
)
184197

185-
# Execute the Lua script atomically
198+
# Execute the Lua scripts atomically
186199
cls = type(self)
200+
assert cls.register_semaphore is not None # nosec
201+
await cls.register_semaphore(
202+
keys=[self.tokens_key, self.holders_key],
203+
args=[self.capacity, ttl_seconds],
204+
client=self.redis_client.redis,
205+
) # pylint: disable=not-callable
206+
207+
try:
208+
# this is blocking pop with timeout
209+
tokens_key_token: list[str] = await handle_redis_returns_union_types(
210+
self.redis_client.redis.brpop(
211+
[self.tokens_key], timeout=blocking_timeout_seconds
212+
)
213+
)
214+
except redis.exceptions.TimeoutError as e:
215+
_logger.debug(
216+
"Timeout acquiring semaphore '%s' (instance: %s)",
217+
self.key,
218+
self.instance_id,
219+
)
220+
if self.blocking:
221+
raise SemaphoreAcquisitionError(
222+
name=self.key, capacity=self.capacity
223+
) from e
224+
return False
225+
226+
assert len(tokens_key_token) == 2 # nosec
227+
assert tokens_key_token[0] == self.tokens_key # nosec
228+
token = tokens_key_token[1]
229+
187230
assert cls.acquire_script is not None # nosec
188231
result = await cls.acquire_script( # pylint: disable=not-callable
189-
keys=[self.tokens_key, self.holders_key, self.holder_key],
232+
keys=[self.holders_key, self.holder_key],
190233
args=[
234+
token[0],
191235
self.instance_id,
192-
self.capacity,
193236
ttl_seconds,
194-
blocking_timeout_seconds,
195237
],
196238
client=self.redis_client.redis,
197239
)
@@ -209,6 +251,7 @@ async def acquire(self) -> bool:
209251
current_count,
210252
)
211253
return True
254+
212255
if status == "timeout":
213256
if self.blocking:
214257
_logger.debug(
@@ -261,6 +304,7 @@ async def release(self) -> None:
261304
self.instance_id,
262305
current_count,
263306
)
307+
return
264308

265309
# Instance was already expired or not acquired
266310
assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec

packages/service-library/src/servicelib/redis/_semaphore_lua.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _load_script(script_name: str) -> str:
3333
COUNT_SEMAPHORE_SCRIPT: Final[str] = _load_script("count_semaphore")
3434

3535
# fair semaphore scripts (token pool based)
36+
REGISTER_FAIR_SEMAPHORE_SCRIPT: Final[str] = _load_script("register_semaphore_holder")
3637
ACQUIRE_FAIR_SEMAPHORE_V2_SCRIPT: Final[str] = _load_script("acquire_fair_semaphore_v2")
3738
RELEASE_FAIR_SEMAPHORE_V2_SCRIPT: Final[str] = _load_script("release_fair_semaphore_v2")
3839
CLEANUP_FAIR_SEMAPHORE_V2_SCRIPT: Final[str] = _load_script("cleanup_fair_semaphore_v2")

packages/service-library/src/servicelib/redis/lua/acquire_fair_semaphore_v2.lua

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,43 +11,20 @@
1111
-- exit_code: 0 if acquired, 255 if timeout/failed
1212
-- status: 'acquired' or 'timeout'
1313

14-
local tokens_key = KEYS[1]
15-
local holders_key = KEYS[2]
16-
local holder_key = KEYS[3]
14+
local holders_key = KEYS[1]
15+
local holder_key = KEYS[2]
1716

18-
local instance_id = ARGV[1]
19-
local capacity = tonumber(ARGV[2])
17+
local token = ARGV[1]
18+
local instance_id = ARGV[2]
2019
local ttl_seconds = tonumber(ARGV[3])
21-
local timeout_seconds = tonumber(ARGV[4])
22-
23-
-- Step 1: Initialize token pool if needed (first time setup)
24-
local tokens_exist = redis.call('EXISTS', tokens_key)
25-
if tokens_exist == 0 then
26-
-- Initialize with capacity number of tokens
27-
for i = 1, capacity do
28-
redis.call('LPUSH', tokens_key, 'token_' .. i)
29-
end
30-
-- Set expiry on tokens list to prevent infinite growth
31-
redis.call('EXPIRE', tokens_key, ttl_seconds * 10)
32-
end
33-
34-
-- Step 2: Try to get a token using blocking pop
35-
-- timeout_seconds = 0 means block indefinitely
36-
local token_result = redis.call('BRPOP', tokens_key, timeout_seconds)
37-
38-
if token_result == false or token_result == nil then
39-
-- Timeout occurred
40-
local current_count = redis.call('SCARD', holders_key)
41-
return {255, 'timeout', '', current_count}
42-
end
43-
44-
local token = token_result[2] -- BRPOP returns {key, value}
45-
46-
-- Step 3: Register as holder
20+
21+
22+
23+
-- Step 1: Register as holder
4724
redis.call('SADD', holders_key, instance_id)
4825
redis.call('SETEX', holder_key, ttl_seconds, token)
4926

50-
-- Step 4: Set expiry on holders set to prevent infinite growth
27+
-- Step 2: Set expiry on holders set to prevent infinite growth
5128
redis.call('EXPIRE', holders_key, ttl_seconds * 10)
5229

5330
local current_count = redis.call('SCARD', holders_key)

packages/service-library/src/servicelib/redis/lua/register_semaphore_holder.lua

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,20 @@
1212

1313
local tokens_key = KEYS[1]
1414
local holders_key = KEYS[2]
15-
local holder_key = KEYS[3]
1615

17-
local instance_id = ARGV[1]
18-
local capacity = tonumber(ARGV[2])
19-
local ttl_seconds = tonumber(ARGV[3])
20-
local token = ARGV[4]
16+
local capacity = tonumber(ARGV[1])
17+
local ttl_seconds = tonumber(ARGV[2])
2118

2219
-- Step 1: Initialize token pool if needed (first time setup)
2320
local tokens_exist = redis.call('EXISTS', tokens_key)
24-
if tokens_exist == 0 then
21+
local holders_exist = redis.call('EXISTS', holders_key)
22+
if tokens_exist == 0 and holders_exist == 0 then
2523
-- Initialize with capacity number of tokens
2624
for i = 1, capacity do
2725
redis.call('LPUSH', tokens_key, 'token_' .. i)
2826
end
2927
-- Set expiry on tokens list to prevent infinite growth
30-
redis.call('EXPIRE', tokens_key, ttl_seconds * 10)
28+
-- redis.call('EXPIRE', tokens_key, ttl_seconds)
3129
end
3230

33-
-- Step 2: Register as holder (token was already popped by Python BRPOP)
34-
redis.call('SADD', holders_key, instance_id)
35-
redis.call('SETEX', holder_key, ttl_seconds, token)
36-
37-
-- Step 3: Set expiry on holders set to prevent infinite growth
38-
redis.call('EXPIRE', holders_key, ttl_seconds * 10)
39-
40-
local current_count = redis.call('SCARD', holders_key)
41-
42-
return {0, 'registered', current_count}
31+
return 0

packages/service-library/tests/redis/test_semaphore.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -290,43 +290,6 @@ async def _raising_context():
290290
assert await captured_semaphore.get_current_count() == 0
291291

292292

293-
async def test_semaphore_ttl_cleanup(
294-
redis_client_sdk: RedisClientSDK,
295-
semaphore_name: str,
296-
semaphore_capacity: int,
297-
short_ttl: datetime.timedelta,
298-
):
299-
# Create semaphore with explicit short TTL
300-
semaphore = DistributedSemaphore(
301-
redis_client=redis_client_sdk,
302-
key=semaphore_name,
303-
capacity=semaphore_capacity,
304-
ttl=short_ttl,
305-
)
306-
307-
# Manually add an expired entry
308-
expired_instance_id = "expired-instance"
309-
current_time = asyncio.get_event_loop().time()
310-
# Make sure it's definitely expired by using the short TTL
311-
expired_time = current_time - short_ttl.total_seconds() - 1
312-
313-
await redis_client_sdk.redis.zadd(
314-
semaphore.semaphore_key, {expired_instance_id: expired_time}
315-
)
316-
317-
# Verify the entry was added
318-
initial_count = await redis_client_sdk.redis.zcard(semaphore.semaphore_key)
319-
assert initial_count == 1
320-
321-
# Current count should clean up expired entries
322-
count = await semaphore.get_current_count()
323-
assert count == 0
324-
325-
# Verify expired entry was removed
326-
remaining = await redis_client_sdk.redis.zcard(semaphore.semaphore_key)
327-
assert remaining == 0
328-
329-
330293
async def test_multiple_semaphores_different_keys(
331294
redis_client_sdk: RedisClientSDK,
332295
faker: Faker,

0 commit comments

Comments
 (0)