Skip to content

Commit dd06440

Browse files
committed
ok this is nice
1 parent a2636b1 commit dd06440

File tree

2 files changed

+101
-3
lines changed

2 files changed

+101
-3
lines changed

packages/service-library/src/servicelib/redis/_semaphore.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,42 @@
3333

3434
_logger = logging.getLogger(__name__)
3535

36+
# Lua script for atomic semaphore acquisition.
# Redis runs the whole script as a single unit, so the cleanup, the capacity
# check and the insertion below cannot interleave with other clients.
ACQUIRE_SEMAPHORE_SCRIPT = """
-- Atomic semaphore acquisition script
-- KEYS[1]: semaphore_key (ZSET storing holders with timestamps)
-- KEYS[2]: holder_key (individual holder TTL key)
-- ARGV[1]: instance_id
-- ARGV[2]: capacity (max concurrent holders)
-- ARGV[3]: ttl_seconds
-- ARGV[4]: current_time (Redis server time)
--
-- Returns a 4-element array:
--   {1, 'acquired', holder_count_after_acquire, expired_count}
--   {0, 'capacity_full', holder_count, expired_count}

local semaphore_key = KEYS[1]
local holder_key = KEYS[2]
local instance_id = ARGV[1]
local capacity = tonumber(ARGV[2])
local ttl_seconds = tonumber(ARGV[3])
local current_time = tonumber(ARGV[4])

-- Step 1: Clean up expired entries (score at or below now - ttl)
local expiry_threshold = current_time - ttl_seconds
local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold)

-- Step 2: Check current capacity after cleanup
local current_count = redis.call('ZCARD', semaphore_key)

-- Step 3: Try to acquire if under capacity
if current_count < capacity then
    -- Atomically add to semaphore and set holder key.
    -- ZADD returns the number of NEWLY added members: 0 when instance_id was
    -- already present (only its score is refreshed), so the reported count
    -- must grow by that amount and not by a hard-coded one.
    local added = redis.call('ZADD', semaphore_key, current_time, instance_id)
    redis.call('SETEX', holder_key, ttl_seconds, '1')

    return {1, 'acquired', current_count + added, expired_count}
else
    return {0, 'capacity_full', current_count, expired_count}
end
"""
71+
3672

3773
class DistributedSemaphore(BaseModel):
3874
"""
@@ -205,8 +241,63 @@ async def get_redis_time(self) -> float:
205241
seconds, microseconds = time_response
206242
return float(seconds) + float(microseconds) / 1_000_000
207243

208-
async def _try_acquire(self) -> bool:
209-
"""Atomically try to acquire the semaphore using Redis operations"""
244+
async def _acquire_with_lua(self) -> bool:
    """
    Attempt a single semaphore acquisition through the atomic Lua script.

    The script receives two keys (semaphore key, holder key) and four
    string arguments (instance id, capacity, TTL in whole seconds, Redis
    time). Its reply is unpacked as
    ``[success, status, current_count, expired_count]``.

    Any exception raised while evaluating the script or handling its reply
    is logged as a warning and the attempt is delegated to
    ``_try_acquire_fallback``.

    Returns:
        True when the script reports success (first element equals 1),
        False when it reports a failure; otherwise whatever the fallback
        returns.
    """
    now = await self.get_redis_time()
    # int() drops the fractional part of the TTL.
    # NOTE(review): a TTL below one second becomes 0 here - confirm that
    # such TTLs are rejected elsewhere.
    whole_ttl = int(self.ttl.total_seconds())

    try:
        # One round trip: the script runs atomically on the Redis server
        reply = await self.redis_client.redis.eval(
            ACQUIRE_SEMAPHORE_SCRIPT,
            2,  # Number of keys
            self.semaphore_key,
            self.holder_key,
            self.instance_id,
            str(self.capacity),
            str(whole_ttl),
            str(now),
        )

        # Normalise the reply to a list before unpacking its four fields
        if isinstance(reply, list | tuple):
            fields = list(reply)
        else:
            fields = [reply]
        acquired_flag, reason, holders, purged = fields

        if acquired_flag != 1:
            _logger.debug(
                "Failed to acquire semaphore '%s' - %s (count: %s, expired: %s)",
                self.key,
                reason,
                holders,
                purged,
            )
            return False

        _logger.debug(
            "Acquired semaphore '%s' (instance: %s, count: %s, expired: %s)",
            self.key,
            self.instance_id,
            holders,
            purged,
        )
        return True

    except Exception as exc:
        _logger.warning(
            "Error executing acquisition Lua script for semaphore '%s': %s",
            self.key,
            exc,
        )
        # Script path failed: retry through the transaction-based method
        return await self._try_acquire_fallback()
298+
299+
async def _try_acquire_fallback(self) -> bool:
300+
"""Fallback implementation using Redis transactions (original method)"""
210301
current_time = await self.get_redis_time()
211302
ttl_seconds = self.ttl.total_seconds()
212303

@@ -238,6 +329,10 @@ async def _try_acquire(self) -> bool:
238329

239330
return False
240331

332+
async def _try_acquire(self) -> bool:
333+
"""Try to acquire the semaphore using Lua script with fallback"""
334+
return await self._acquire_with_lua()
335+
241336
async def get_current_count(self) -> int:
242337
"""Get the current number of semaphore holders"""
243338
current_time = await self.get_redis_time()

packages/service-library/tests/redis/test_semaphore.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,11 @@ async def test_redis_connection_failure_during_acquire(
335335
blocking=False,
336336
)
337337

338-
# Mock Redis to raise an exception
338+
# Mock Redis eval to raise an exception (which should trigger fallback)
339+
# and also mock zcard in fallback to raise an error
339340
with mock.patch.object(
341+
redis_client_sdk.redis, "eval", side_effect=Exception("Redis eval error")
342+
), mock.patch.object(
340343
redis_client_sdk.redis, "zcard", side_effect=Exception("Redis error")
341344
):
342345
result = await semaphore.acquire()

0 commit comments

Comments
 (0)