Skip to content

Commit ffd5c03

Browse files
committed
work with lost tokens
1 parent 92ca699 commit ffd5c03

File tree

3 files changed

+109
-6
lines changed

3 files changed

+109
-6
lines changed

packages/service-library/src/servicelib/redis/_semaphore.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ class DistributedSemaphore(BaseModel):
107107
release_script: ClassVar[AsyncScript | None] = None
108108
renew_script: ClassVar[AsyncScript | None] = None
109109

110+
_token: str | None = None # currently held token, if any
111+
110112
@classmethod
111113
def _register_scripts(cls, redis_client: RedisClientSDK) -> None:
112114
"""Register Lua scripts with Redis if not already done.
@@ -258,15 +260,15 @@ async def _try_acquire() -> list[str] | None:
258260
assert tokens_key_token is not None # nosec
259261
assert len(tokens_key_token) == 2 # nosec # noqa: PLR2004
260262
assert tokens_key_token[0] == self.tokens_key # nosec
261-
token = tokens_key_token[1]
263+
self._token = tokens_key_token[1]
262264

263265
# set up the semaphore holder with a TTL
264266
cls = type(self)
265267
assert cls.acquire_script is not None # nosec
266268
result = await cls.acquire_script( # pylint: disable=not-callable
267269
keys=[self.holders_set, self.holder_key],
268270
args=[
269-
token,
271+
self._token,
270272
self.instance_id,
271273
ttl_seconds,
272274
],
@@ -300,9 +302,12 @@ async def release(self) -> None:
300302
# Execute the release Lua script atomically
301303
cls = type(self)
302304
assert cls.release_script is not None # nosec
305+
release_args = [self.instance_id]
306+
if self._token is not None:
307+
release_args.append(self._token)
303308
result = await cls.release_script( # pylint: disable=not-callable
304309
keys=[self.tokens_key, self.holders_set, self.holder_key],
305-
args=[self.instance_id],
310+
args=release_args,
306311
client=self.redis_client.redis,
307312
)
308313

packages/service-library/src/servicelib/redis/lua/release_fair_semaphore_v2.lua

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
-- KEYS[3]: holder_key (individual holder TTL key for this instance)
55

66
-- ARGV[1]: instance_id
7+
-- ARGV[2]: passed_token (the token held by this instance or nil if unknown)
78
--
89
-- Returns: {exit_code, status, current_count}
910
-- exit_code: 0 if released, 255 if failed
@@ -14,6 +15,7 @@ local holders_key = KEYS[2]
1415
local holder_key = KEYS[3]
1516

1617
local instance_id = ARGV[1]
18+
local passed_token = ARGV[2]
1719

1820
-- Step 1: Check if this instance is currently a holder
1921
local is_holder = redis.call('SISMEMBER', holders_key, instance_id)
@@ -29,6 +31,11 @@ if not token then
2931
-- this indicates a lost semaphore (e.g. due to TTL expiry)
3032
-- remove from holders set and return error
3133
redis.call('SREM', holders_key, instance_id)
34+
-- if the token was passed return it to the pool
35+
if passed_token then
36+
redis.call('LPUSH', tokens_key, passed_token)
37+
end
38+
-- Note: we do NOT push a recovered token since we don't know its state
3239
return {255, 'expired', redis.call('SCARD', holders_key)}
3340
end
3441

packages/service-library/tests/redis/test_semaphore.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ async def test_semaphore_initialization(
5858
assert semaphore.blocking is True
5959
assert semaphore.instance_id is not None
6060
assert semaphore.semaphore_key == f"{SEMAPHORE_KEY_PREFIX}{semaphore_name}"
61-
assert semaphore.holder_key.startswith(
62-
f"{SEMAPHORE_KEY_PREFIX}{semaphore_name}:holders:"
63-
)
61+
assert semaphore.tokens_key.startswith(f"{semaphore.semaphore_key}:")
62+
assert semaphore.holders_set.startswith(f"{semaphore.semaphore_key}:")
63+
assert semaphore.holder_key.startswith(f"{semaphore.semaphore_key}:")
6464

6565

6666
async def test_invalid_semaphore_initialization(
@@ -102,6 +102,28 @@ async def test_invalid_semaphore_initialization(
102102
)
103103

104104

105+
async def _assert_semaphore_redis_state(
106+
redis_client_sdk: RedisClientSDK,
107+
semaphore: DistributedSemaphore,
108+
*,
109+
expected_count: int,
110+
expected_free_tokens: int,
111+
expected_expired: bool = False,
112+
):
113+
"""Helper to assert the internal Redis state of the semaphore"""
114+
holders = await redis_client_sdk.redis.smembers(semaphore.holders_set)
115+
assert len(holders) == expected_count
116+
if expected_count > 0:
117+
assert semaphore.instance_id in holders
118+
holder_key_exists = await redis_client_sdk.redis.exists(semaphore.holder_key)
119+
if expected_expired:
120+
assert holder_key_exists == 0
121+
else:
122+
assert holder_key_exists == 1
123+
tokens = await redis_client_sdk.redis.lrange(semaphore.tokens_key, 0, -1)
124+
assert len(tokens) == expected_free_tokens
125+
126+
105127
async def test_semaphore_acquire_release_basic(
106128
redis_client_sdk: RedisClientSDK,
107129
semaphore_name: str,
@@ -119,46 +141,88 @@ async def test_semaphore_acquire_release_basic(
119141
assert await semaphore.current_count() == 0
120142
assert await semaphore.size() == semaphore_capacity
121143
assert await semaphore.is_acquired() is False
144+
await _assert_semaphore_redis_state(
145+
redis_client_sdk,
146+
semaphore,
147+
expected_count=0,
148+
expected_free_tokens=semaphore_capacity,
149+
)
122150

123151
# Acquire
124152
result = await semaphore.acquire()
125153
assert result is True
126154
assert await semaphore.current_count() == 1
127155
assert await semaphore.size() == semaphore_capacity - 1
128156
assert await semaphore.is_acquired() is True
157+
await _assert_semaphore_redis_state(
158+
redis_client_sdk,
159+
semaphore,
160+
expected_count=1,
161+
expected_free_tokens=semaphore_capacity - 1,
162+
)
129163

130164
# Acquire again on same instance should return True immediately and keep the same count (reentrant)
131165
result = await semaphore.acquire()
132166
assert result is True
133167
assert await semaphore.current_count() == 1
134168
assert await semaphore.size() == semaphore_capacity - 1
135169
assert await semaphore.is_acquired() is True
170+
await _assert_semaphore_redis_state(
171+
redis_client_sdk,
172+
semaphore,
173+
expected_count=1,
174+
expected_free_tokens=semaphore_capacity - 1,
175+
)
136176

137177
# reacquire should just work
138178
await semaphore.reacquire()
139179
assert await semaphore.current_count() == 1
140180
assert await semaphore.size() == semaphore_capacity - 1
141181
assert await semaphore.is_acquired() is True
182+
await _assert_semaphore_redis_state(
183+
redis_client_sdk,
184+
semaphore,
185+
expected_count=1,
186+
expected_free_tokens=semaphore_capacity - 1,
187+
)
142188

143189
# Release
144190
await semaphore.release()
145191
assert await semaphore.current_count() == 0
146192
assert await semaphore.size() == semaphore_capacity
147193
assert await semaphore.is_acquired() is False
194+
await _assert_semaphore_redis_state(
195+
redis_client_sdk,
196+
semaphore,
197+
expected_count=0,
198+
expected_free_tokens=semaphore_capacity,
199+
)
148200

149201
# reacquire after release should fail
150202
with pytest.raises(
151203
SemaphoreNotAcquiredError,
152204
match=f"Semaphore '{semaphore_name}' was not acquired by this instance",
153205
):
154206
await semaphore.reacquire()
207+
await _assert_semaphore_redis_state(
208+
redis_client_sdk,
209+
semaphore,
210+
expected_count=0,
211+
expected_free_tokens=semaphore_capacity,
212+
)
155213

156214
# so does release again
157215
with pytest.raises(
158216
SemaphoreNotAcquiredError,
159217
match=f"Semaphore '{semaphore_name}' was not acquired by this instance",
160218
):
161219
await semaphore.release()
220+
await _assert_semaphore_redis_state(
221+
redis_client_sdk,
222+
semaphore,
223+
expected_count=0,
224+
expected_free_tokens=semaphore_capacity,
225+
)
162226

163227

164228
async def test_semaphore_acquire_release_with_ttl_expiry(
@@ -176,21 +240,48 @@ async def test_semaphore_acquire_release_with_ttl_expiry(
176240
await semaphore.acquire()
177241
assert await semaphore.current_count() == 1
178242
assert await semaphore.size() == semaphore_capacity - 1
243+
await _assert_semaphore_redis_state(
244+
redis_client_sdk,
245+
semaphore,
246+
expected_count=1,
247+
expected_free_tokens=semaphore_capacity - 1,
248+
)
249+
179250
# wait for TTL to expire
180251
await asyncio.sleep(with_short_default_semaphore_ttl.total_seconds() + 0.1)
252+
await _assert_semaphore_redis_state(
253+
redis_client_sdk,
254+
semaphore,
255+
expected_count=1,
256+
expected_free_tokens=semaphore_capacity - 1,
257+
expected_expired=True,
258+
)
181259

182260
# TTL expired, reacquire should fail
183261
with pytest.raises(
184262
SemaphoreLostError,
185263
match=f"Semaphore '{semaphore_name}' was lost by this instance",
186264
):
187265
await semaphore.reacquire()
266+
await _assert_semaphore_redis_state(
267+
redis_client_sdk,
268+
semaphore,
269+
expected_count=1,
270+
expected_free_tokens=semaphore_capacity - 1,
271+
expected_expired=True,
272+
)
188273
# and release should also fail
189274
with pytest.raises(
190275
SemaphoreLostError,
191276
match=f"Semaphore '{semaphore_name}' was lost by this instance",
192277
):
193278
await semaphore.release()
279+
await _assert_semaphore_redis_state(
280+
redis_client_sdk,
281+
semaphore,
282+
expected_count=0,
283+
expected_free_tokens=semaphore_capacity,
284+
)
194285

195286

196287
async def test_semaphore_context_manager(

0 commit comments

Comments
 (0)