Skip to content

Commit dc1be02

Browse files
committed
on the road
1 parent e2a6d6e commit dc1be02

File tree

1 file changed

+106
-131
lines changed

1 file changed

+106
-131
lines changed

packages/service-library/src/servicelib/redis/_semaphore.py

Lines changed: 106 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,6 @@
1313
field_validator,
1414
)
1515
from redis.commands.core import AsyncScript
16-
from tenacity import (
17-
RetryError,
18-
before_sleep_log,
19-
retry,
20-
retry_if_not_result,
21-
stop_after_delay,
22-
stop_never,
23-
wait_random_exponential,
24-
)
2516

2617
from ._client import RedisClientSDK
2718
from ._constants import (
@@ -36,10 +27,10 @@
3627
SemaphoreNotAcquiredError,
3728
)
3829
from ._semaphore_lua import (
39-
ACQUIRE_SEMAPHORE_SCRIPT,
40-
COUNT_SEMAPHORE_SCRIPT,
41-
RELEASE_SEMAPHORE_SCRIPT,
42-
RENEW_SEMAPHORE_SCRIPT,
30+
ACQUIRE_FAIR_SEMAPHORE_V2_SCRIPT,
31+
COUNT_FAIR_SEMAPHORE_V2_SCRIPT,
32+
RELEASE_FAIR_SEMAPHORE_V2_SCRIPT,
33+
RENEW_FAIR_SEMAPHORE_V2_SCRIPT,
4334
SCRIPT_BAD_EXIT_CODE,
4435
SCRIPT_OK_EXIT_CODE,
4536
)
@@ -114,16 +105,16 @@ def _register_scripts(cls, redis_client: RedisClientSDK) -> None:
114105
the script is only registered once."""
115106
if cls.acquire_script is None:
116107
cls.acquire_script = redis_client.redis.register_script(
117-
ACQUIRE_SEMAPHORE_SCRIPT
108+
ACQUIRE_FAIR_SEMAPHORE_V2_SCRIPT
118109
)
119110
cls.count_script = redis_client.redis.register_script(
120-
COUNT_SEMAPHORE_SCRIPT
111+
COUNT_FAIR_SEMAPHORE_V2_SCRIPT
121112
)
122113
cls.release_script = redis_client.redis.register_script(
123-
RELEASE_SEMAPHORE_SCRIPT
114+
RELEASE_FAIR_SEMAPHORE_V2_SCRIPT
124115
)
125116
cls.renew_script = redis_client.redis.register_script(
126-
RENEW_SEMAPHORE_SCRIPT
117+
RENEW_FAIR_SEMAPHORE_V2_SCRIPT
127118
)
128119

129120
def __init__(self, **data) -> None:
@@ -136,6 +127,18 @@ def semaphore_key(self) -> str:
136127
"""Redis key for the semaphore sorted set."""
137128
return f"{SEMAPHORE_KEY_PREFIX}{self.key}"
138129

130+
@computed_field # type: ignore[prop-decorator]
131+
@property
132+
def tokens_key(self) -> str:
133+
"""Redis key for the token pool LIST."""
134+
return f"{SEMAPHORE_KEY_PREFIX}{self.key}:tokens"
135+
136+
@computed_field # type: ignore[prop-decorator]
137+
@property
138+
def holders_key(self) -> str:
139+
"""Redis key for the holders SET."""
140+
return f"{SEMAPHORE_KEY_PREFIX}{self.key}:holders"
141+
139142
@computed_field # type: ignore[prop-decorator]
140143
@property
141144
def holder_key(self) -> str:
@@ -172,109 +175,103 @@ async def acquire(self) -> bool:
172175
SemaphoreAcquisitionError: If acquisition fails and blocking=True
173176
"""
174177

175-
if not self.blocking:
176-
# Non-blocking: try once
177-
return await self._try_acquire()
178-
179-
# Blocking
180-
@retry(
181-
wait=wait_random_exponential(min=0.1, max=2),
182-
reraise=True,
183-
stop=(
184-
stop_after_delay(self.blocking_timeout.total_seconds())
185-
if self.blocking_timeout
186-
else stop_never
187-
),
188-
retry=retry_if_not_result(lambda acquired: acquired),
189-
before_sleep=before_sleep_log(_logger, logging.DEBUG),
190-
)
191-
async def _blocking_acquire() -> bool:
192-
return await self._try_acquire()
193-
194-
try:
195-
return await _blocking_acquire()
196-
except RetryError as exc:
197-
raise SemaphoreAcquisitionError(
198-
name=self.key, capacity=self.capacity
199-
) from exc
200-
201-
async def release(self) -> None:
202-
"""
203-
Release the semaphore atomically using Lua script.
204-
205-
Raises:
206-
SemaphoreNotAcquiredError: If semaphore was not acquired by this instance
207-
"""
208178
ttl_seconds = int(self.ttl.total_seconds())
179+
blocking_timeout_seconds = 0.1
180+
if self.blocking:
181+
blocking_timeout_seconds = (
182+
self.blocking_timeout.total_seconds() if self.blocking_timeout else 0
183+
)
209184

210-
# Execute the release Lua script atomically
185+
# Execute the Lua script atomically
211186
cls = type(self)
212-
assert cls.release_script is not None # nosec
213-
result = await cls.release_script( # pylint: disable=not-callable
214-
keys=(
215-
self.semaphore_key,
216-
self.holder_key,
217-
),
218-
args=(
187+
assert cls.acquire_script is not None # nosec
188+
result = await cls.acquire_script( # pylint: disable=not-callable
189+
keys=[self.tokens_key, self.holders_key, self.holder_key],
190+
args=[
219191
self.instance_id,
220-
str(ttl_seconds),
221-
),
192+
self.capacity,
193+
ttl_seconds,
194+
blocking_timeout_seconds,
195+
],
222196
client=self.redis_client.redis,
223197
)
224198

199+
# Lua script returns: [exit_code, status, token, current_count]
225200
assert isinstance(result, list) # nosec
226-
exit_code, status, current_count, expired_count = result
227-
result = status
201+
exit_code, status, token, current_count = result
228202

229-
if result == "released":
230-
assert exit_code == SCRIPT_OK_EXIT_CODE # nosec
203+
if exit_code == SCRIPT_OK_EXIT_CODE:
204+
_logger.debug(
205+
"Acquired semaphore '%s' with token %s (instance: %s, count: %s)",
206+
self.key,
207+
token,
208+
self.instance_id,
209+
current_count,
210+
)
211+
return True
212+
if status == "timeout":
213+
if self.blocking:
214+
_logger.debug(
215+
"Timeout acquiring semaphore '%s' (instance: %s, count: %s)",
216+
self.key,
217+
self.instance_id,
218+
current_count,
219+
)
220+
raise SemaphoreAcquisitionError(name=self.key, capacity=self.capacity)
231221
_logger.debug(
232-
"Released semaphore '%s' (instance: %s, count: %s, expired: %s)",
222+
"Timeout acquiring semaphore '%s' (instance: %s, count: %s)",
233223
self.key,
234224
self.instance_id,
235225
current_count,
236-
expired_count,
237226
)
238-
else:
239-
# Instance wasn't in the semaphore set - this shouldn't happen
240-
# but let's handle it gracefully
241-
assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec
242-
raise SemaphoreNotAcquiredError(name=self.key)
227+
return False
243228

244-
async def _try_acquire(self) -> bool:
245-
ttl_seconds = int(self.ttl.total_seconds())
229+
_logger.debug(
230+
"Failed to acquire semaphore '%s' - %s (count: %s)",
231+
self.key,
232+
status,
233+
current_count,
234+
)
235+
raise SemaphoreAcquisitionError(name=self.key, capacity=self.capacity)
246236

247-
# Execute the Lua script atomically
237+
async def release(self) -> None:
238+
"""
239+
Release the semaphore atomically using Lua script.
240+
241+
Raises:
242+
SemaphoreNotAcquiredError: If semaphore was not acquired by this instance
243+
"""
244+
245+
# Execute the release Lua script atomically
248246
cls = type(self)
249-
assert cls.acquire_script is not None # nosec
250-
result = await cls.acquire_script( # pylint: disable=not-callable
251-
keys=(self.semaphore_key, self.holder_key),
252-
args=(self.instance_id, str(self.capacity), str(ttl_seconds)),
247+
assert cls.release_script is not None # nosec
248+
result = await cls.release_script( # pylint: disable=not-callable
249+
keys=[self.tokens_key, self.holders_key, self.holder_key],
250+
args=[self.instance_id],
253251
client=self.redis_client.redis,
254252
)
255253

256-
# Lua script returns: [exit_code, status, current_count, expired_count]
257254
assert isinstance(result, list) # nosec
258-
exit_code, status, current_count, expired_count = result
259-
255+
exit_code, status, current_count = result
260256
if exit_code == SCRIPT_OK_EXIT_CODE:
257+
assert status == "released" # nosec
261258
_logger.debug(
262-
"Acquired semaphore '%s' (instance: %s, count: %s, expired: %s)",
259+
"Released semaphore '%s' (instance: %s, count: %s)",
263260
self.key,
264261
self.instance_id,
265262
current_count,
266-
expired_count,
267263
)
268-
return True
269264

270-
_logger.debug(
271-
"Failed to acquire semaphore '%s' - %s (count: %s, expired: %s)",
265+
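# Instance was already expired or not acquired (NOTE(review): the success branch above has no `return`, so a successful release appears to fall through to here as well - confirm)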
266+
assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec
267+
_logger.error(
268+
"Failed to release semaphore '%s' - %s (instance: %s, count: %s)",
272269
self.key,
273270
status,
271+
self.instance_id,
274272
current_count,
275-
expired_count,
276273
)
277-
return False
274+
raise SemaphoreNotAcquiredError(name=self.key)
278275

279276
async def reacquire(self) -> None:
280277
"""
@@ -293,72 +290,50 @@ async def reacquire(self) -> None:
293290
cls = type(self)
294291
assert cls.renew_script is not None # nosec
295292
result = await cls.renew_script( # pylint: disable=not-callable
296-
keys=(self.semaphore_key, self.holder_key),
297-
args=(
298-
self.instance_id,
299-
str(ttl_seconds),
300-
),
293+
keys=[self.holders_key, self.holder_key],
294+
args=[self.instance_id, ttl_seconds],
301295
client=self.redis_client.redis,
302296
)
303297

304298
assert isinstance(result, list) # nosec
305-
exit_code, status, current_count, expired_count = result
299+
exit_code, status, current_count = result
306300

307-
# Lua script returns: 'renewed' or status message
308-
if status == "renewed":
309-
assert exit_code == SCRIPT_OK_EXIT_CODE # nosec
301+
if exit_code == SCRIPT_OK_EXIT_CODE:
302+
assert status == "renewed" # nosec
310303
_logger.debug(
311-
"Renewed semaphore '%s' (instance: %s, count: %s, expired: %s)",
304+
"Renewed semaphore '%s' (instance: %s, count: %s)",
312305
self.key,
313306
self.instance_id,
314307
current_count,
315-
expired_count,
316308
)
317-
else:
318-
assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec
319-
if status == "expired":
320-
_logger.warning(
321-
"Semaphore '%s' holder key expired (instance: %s, count: %s, expired: %s)",
322-
self.key,
323-
self.instance_id,
324-
current_count,
325-
expired_count,
326-
)
327-
elif status == "not_held":
328-
_logger.warning(
329-
"Semaphore '%s' not held (instance: %s, count: %s, expired: %s)",
330-
self.key,
331-
self.instance_id,
332-
current_count,
333-
expired_count,
334-
)
309+
return
310+
assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec
335311

336-
raise SemaphoreLostError(name=self.key, instance_id=self.instance_id)
312+
_logger.warning(
313+
"Semaphore '%s' holder key was lost (instance: %s, status: %s, count: %s)",
314+
self.key,
315+
self.instance_id,
316+
status,
317+
current_count,
318+
)
319+
320+
raise SemaphoreLostError(name=self.key, instance_id=self.instance_id)
337321

338322
async def get_current_count(self) -> int:
339323
"""Get the current number of semaphore holders"""
340-
ttl_seconds = int(self.ttl.total_seconds())
341324

342-
# Execute the count Lua script atomically
343325
cls = type(self)
344326
assert cls.count_script is not None # nosec
345327
result = await cls.count_script( # pylint: disable=not-callable
346-
keys=(self.semaphore_key,),
347-
args=(str(ttl_seconds),),
328+
keys=[self.holders_key, self.tokens_key],
329+
args=[self.capacity],
348330
client=self.redis_client.redis,
349331
)
350332

351333
assert isinstance(result, list) # nosec
352-
current_count, expired_count = result
353-
354-
if int(expired_count) > 0:
355-
_logger.debug(
356-
"Cleaned up %s expired entries from semaphore '%s'",
357-
expired_count,
358-
self.key,
359-
)
334+
current_holders, available_tokens, capacity = result
360335

361-
return int(current_count)
336+
return int(current_holders)
362337

363338
async def get_available_count(self) -> int:
364339
"""Get the number of available semaphore slots"""

0 commit comments

Comments
 (0)