
Commit b4a42a1

ongoing fair queuing semaphore
1 parent 16dafee

File tree

9 files changed: +1327 −0 lines

packages/service-library/src/servicelib/redis/fair_semaphore.py

Lines changed: 404 additions & 0 deletions
Large diffs are not rendered by default.
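
Since the 404-line `fair_semaphore.py` diff is collapsed, here is a rough sketch of the FairSemaphore surface as inferred from its call sites in the decorator module below. This is an assumption-based outline, not the actual file; every name not visible in the call sites is hypothetical.

    # Hedged sketch of FairSemaphore, inferred only from how the decorator
    # module below constructs and calls it. The real implementation differs.
    import datetime


    class FairSemaphore:
        def __init__(
            self,
            *,
            key: str,
            capacity: int,
            ttl: datetime.timedelta,
            timeout: datetime.timedelta,
            cleanup_interval: datetime.timedelta,
            enable_auto_cleanup: bool = True,
        ) -> None: ...

        acquired: bool  # True while this instance holds a token

        async def acquire(self) -> bool:
            """Block (via BRPOP on the token list) until a token is available
            or the timeout expires; returns True on success."""

        async def renew(self) -> None:
            """Refresh this holder's TTL key; raises SemaphoreLostError
            if the slot was reclaimed."""

        async def release(self) -> None:
            """Return the token to the pool and deregister the holder."""

        async def count(self) -> dict:
            """Return stats such as {'current_holders': ..., 'available_tokens': ...}."""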
Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
"""Fair semaphore decorator with automatic renewal and crash recovery."""

import asyncio
import datetime
import functools
import logging
from collections.abc import Callable, Coroutine
from contextlib import asynccontextmanager
from typing import Any, ParamSpec, TypeVar

from common_library.logging.logging_errors import create_troubleshooting_log_kwargs

from ._constants import (
    DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
    DEFAULT_SEMAPHORE_TTL,
    DEFAULT_SOCKET_TIMEOUT,
)
from ._errors import (
    SemaphoreAcquisitionError,
    SemaphoreLostError,
    SemaphoreNotAcquiredError,
)
from .fair_semaphore import FairSemaphore

_logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")


@asynccontextmanager
async def _managed_fair_semaphore_execution(
    semaphore: FairSemaphore,
    semaphore_key: str,
    ttl: datetime.timedelta,
    execution_context: str,
    enable_auto_renewal: bool = True,
):
    """Context manager for fair semaphore with auto-renewal."""

    async def _auto_renewal() -> None:
        """Background task to automatically renew the semaphore."""
        if not enable_auto_renewal:
            return

        renewal_interval = ttl.total_seconds() / 3  # renew at 1/3 of the TTL

        while semaphore.acquired:
            try:
                await asyncio.sleep(renewal_interval)
                if semaphore.acquired:  # re-check after sleeping
                    await semaphore.renew()
                    _logger.debug("Renewed fair semaphore %s", semaphore_key)
            except SemaphoreLostError:
                _logger.error(
                    "Fair semaphore %s was lost during execution", semaphore_key
                )
                break
            except Exception as e:  # pylint: disable=broad-except
                _logger.warning(
                    "Failed to renew fair semaphore %s: %s", semaphore_key, e
                )
                break

    renewal_task = None
    try:
        # Acquire the semaphore (blocks until available)
        if not await semaphore.acquire():
            raise SemaphoreAcquisitionError(
                f"Failed to acquire fair semaphore {semaphore_key}"
            )

        _logger.info(
            "Acquired fair semaphore %s for %s", semaphore_key, execution_context
        )

        # Start the auto-renewal task if enabled
        if enable_auto_renewal:
            renewal_task = asyncio.create_task(_auto_renewal())

        yield

    except Exception as e:
        _logger.error(
            "Error in fair semaphore-protected execution: %s",
            e,
            extra=create_troubleshooting_log_kwargs(
                context=execution_context,
                semaphore_key=semaphore_key,
            ),
        )
        raise
    finally:
        # Cancel the renewal task
        if renewal_task and not renewal_task.done():
            renewal_task.cancel()
            try:
                await renewal_task
            except asyncio.CancelledError:
                pass

        # Release the semaphore
        if semaphore.acquired:
            try:
                await semaphore.release()
                _logger.info("Released fair semaphore %s", semaphore_key)
            except Exception as e:  # pylint: disable=broad-except
                _logger.error(
                    "Failed to release fair semaphore %s: %s", semaphore_key, e
                )


def fair_semaphore(
    *,
    key: str,
    capacity: int,
    ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
    timeout: datetime.timedelta = DEFAULT_SOCKET_TIMEOUT,
    expected_execution_time: datetime.timedelta = DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
    cleanup_interval: datetime.timedelta = datetime.timedelta(seconds=30),
    enable_auto_cleanup: bool = True,
    enable_auto_renewal: bool = True,
) -> Callable[
    [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
]:
    """
    Decorator that protects async functions with a fair distributed semaphore.

    Uses Redis BRPOP for true FIFO fairness: the first requester gets the
    first free slot, so no starvation is possible, and crashed holders are
    recovered automatically.

    Args:
        key: unique semaphore identifier
        capacity: maximum number of concurrent executions allowed
        ttl: how long each holder may keep the semaphore before renewal
        timeout: how long to wait for the semaphore (0 = wait forever)
        expected_execution_time: expected total execution time (unused, kept for compatibility)
        cleanup_interval: how often to run cleanup for crashed clients
        enable_auto_cleanup: whether to run background cleanup
        enable_auto_renewal: whether to automatically renew the TTL during execution

    Example:
        @fair_semaphore(
            key="api_calls",
            capacity=10,
            ttl=datetime.timedelta(seconds=30),
            timeout=datetime.timedelta(seconds=60),
        )
        async def call_external_api():
            # Blocks fairly until the semaphore is available:
            # at most 10 concurrent executions, served in
            # first-come-first-served order.
            ...
    """

    def decorator(
        func: Callable[P, Coroutine[Any, Any, R]],
    ) -> Callable[P, Coroutine[Any, Any, R]]:
        @functools.wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            semaphore = FairSemaphore(
                key=key,
                capacity=capacity,
                ttl=ttl,
                timeout=timeout,
                cleanup_interval=cleanup_interval,
                enable_auto_cleanup=enable_auto_cleanup,
            )

            execution_context = f"{func.__module__}.{func.__qualname__}"

            async with _managed_fair_semaphore_execution(
                semaphore=semaphore,
                semaphore_key=key,
                ttl=ttl,
                execution_context=execution_context,
                enable_auto_renewal=enable_auto_renewal,
            ):
                return await func(*args, **kwargs)

        return wrapper

    return decorator


class FairSemaphoreContext:
    """Async context manager for manual fair semaphore control."""

    def __init__(
        self,
        key: str,
        capacity: int,
        ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
        timeout: datetime.timedelta = DEFAULT_SOCKET_TIMEOUT,
        cleanup_interval: datetime.timedelta = datetime.timedelta(seconds=30),
        enable_auto_cleanup: bool = True,
        enable_auto_renewal: bool = True,
    ):
        self.semaphore = FairSemaphore(
            key=key,
            capacity=capacity,
            ttl=ttl,
            timeout=timeout,
            cleanup_interval=cleanup_interval,
            enable_auto_cleanup=enable_auto_cleanup,
        )
        self.ttl = ttl
        self.enable_auto_renewal = enable_auto_renewal
        self._renewal_task: asyncio.Task | None = None

    async def __aenter__(self) -> FairSemaphore:
        """Acquire the semaphore and start auto-renewal."""
        await self.semaphore.acquire()

        # Start auto-renewal if enabled
        if self.enable_auto_renewal:

            async def _auto_renewal() -> None:
                renewal_interval = self.ttl.total_seconds() / 3
                while self.semaphore.acquired:
                    try:
                        await asyncio.sleep(renewal_interval)
                        if self.semaphore.acquired:
                            await self.semaphore.renew()
                    except (SemaphoreLostError, SemaphoreNotAcquiredError):
                        break
                    except Exception as e:  # pylint: disable=broad-except
                        _logger.warning("Auto-renewal failed: %s", e)

            self._renewal_task = asyncio.create_task(_auto_renewal())

        return self.semaphore

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop renewal and release the semaphore."""
        if self._renewal_task and not self._renewal_task.done():
            self._renewal_task.cancel()
            try:
                await self._renewal_task
            except asyncio.CancelledError:
                pass

        if self.semaphore.acquired:
            await self.semaphore.release()


# Convenience function for creating fair semaphore contexts
def fair_semaphore_context(
    key: str,
    capacity: int,
    ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
    timeout: datetime.timedelta = DEFAULT_SOCKET_TIMEOUT,
    cleanup_interval: datetime.timedelta = datetime.timedelta(seconds=30),
    enable_auto_cleanup: bool = True,
    enable_auto_renewal: bool = True,
) -> FairSemaphoreContext:
    """
    Create an async context manager for fair semaphore usage.

    Example:
        async with fair_semaphore_context(
            "my_resource",
            capacity=5,
            timeout=datetime.timedelta(seconds=30),
        ) as sem:
            # Protected code here: fair access guaranteed,
            # `sem` is the FairSemaphore instance.
            stats = await sem.count()
            print(f"Current holders: {stats['current_holders']}")
    """
    return FairSemaphoreContext(
        key=key,
        capacity=capacity,
        ttl=ttl,
        timeout=timeout,
        cleanup_interval=cleanup_interval,
        enable_auto_cleanup=enable_auto_cleanup,
        enable_auto_renewal=enable_auto_renewal,
    )
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
-- Fair distributed semaphore using a token pool (BRPOP-based).
--
-- Blocking commands such as BRPOP cannot run inside a Lua script, so the
-- client pops a token from the tokens LIST first and then calls this
-- script to register itself as a holder.
--
-- KEYS[1]: holders_key (SET of current holder instance IDs)
-- KEYS[2]: holder_key (individual holder TTL key for this instance)
-- ARGV[1]: token (the token popped from the tokens LIST)
-- ARGV[2]: instance_id
-- ARGV[3]: ttl_seconds
--
-- Returns: {exit_code, status, token, current_count}
-- exit_code: 0 if acquired (timeouts are handled client-side by BRPOP)
-- status: 'acquired'

local holders_key = KEYS[1]
local holder_key = KEYS[2]

local token = ARGV[1]
local instance_id = ARGV[2]
local ttl_seconds = tonumber(ARGV[3])

-- Step 1: Register as holder
redis.call('SADD', holders_key, instance_id)
redis.call('SETEX', holder_key, ttl_seconds, token)

-- Step 2: Set expiry on the holders set to prevent infinite growth
redis.call('EXPIRE', holders_key, ttl_seconds * 10)

local current_count = redis.call('SCARD', holders_key)

return {0, 'acquired', token, current_count}
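
Since the blocking pop has to happen client-side, the full acquire flow presumably seeds the token LIST with `capacity` tokens once, then each waiter does BRPOP followed by this registration script. A minimal sketch with redis-py's asyncio client, assuming the key layout hinted at above (all key patterns and helper names here are illustrative assumptions, not the library's actual API):

    # Hypothetical client-side acquire flow for the registration script above.
    # Key patterns and helper names are assumptions for illustration only.
    import uuid

    from redis.asyncio import Redis

    REGISTER_SCRIPT = "..."  # the Lua registration script above


    async def acquire_token(
        redis: Redis, key: str, ttl_s: int, timeout_s: int
    ) -> str | None:
        tokens_key = f"semaphores:tokens:{key}"
        holders_key = f"semaphores:holders_set:{key}"
        instance_id = uuid.uuid4().hex
        holder_key = f"semaphores:holders:{key}:{instance_id}"

        # Step 1: block fairly until a token is available; Redis serves
        # blocked BRPOP clients in FIFO order, which gives the fairness
        popped = await redis.brpop([tokens_key], timeout=timeout_s)
        if popped is None:
            return None  # timed out waiting for capacity
        _, token = popped

        # Step 2: atomically register as a holder with a TTL key,
        # so a crashed client can be detected and its token recovered
        await redis.eval(
            REGISTER_SCRIPT, 2, holders_key, holder_key, token, instance_id, ttl_s
        )
        return token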
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
-- Cleanup orphaned tokens from crashed clients
-- KEYS[1]: tokens_key (LIST of available tokens)
-- KEYS[2]: holders_key (SET of current holders)
-- KEYS[3]: holder_prefix (prefix for holder keys, e.g. "semaphores:holders:key:")
-- ARGV[1]: capacity (total semaphore capacity)
--
-- Returns: {recovered_tokens, current_holders, available_tokens, total_cleaned}
-- This script should be run periodically to recover tokens from crashed clients

local tokens_key = KEYS[1]
local holders_key = KEYS[2]
local holder_prefix = KEYS[3]

local capacity = tonumber(ARGV[1])

-- Step 1: Get all current holders
local current_holders = redis.call('SMEMBERS', holders_key)
local recovered_tokens = 0
local cleaned_holders = {}

-- Step 2: Check each holder to see if its TTL key still exists
for i = 1, #current_holders do
    local holder_id = current_holders[i]
    local holder_key = holder_prefix .. holder_id
    local exists = redis.call('EXISTS', holder_key)

    if exists == 0 then
        -- The holder key expired but the holder is still in the SET:
        -- this indicates a crashed client - clean up and recover the token
        redis.call('SREM', holders_key, holder_id)
        redis.call('LPUSH', tokens_key, 'token_recovered_' .. holder_id)
        recovered_tokens = recovered_tokens + 1
        table.insert(cleaned_holders, holder_id)
    end
end

-- Step 3: Ensure the total number of tokens matches the capacity
local remaining_holders = redis.call('SCARD', holders_key)
local available_tokens_count = redis.call('LLEN', tokens_key)
local total_tokens = remaining_holders + available_tokens_count

-- If tokens are missing (due to crashes or Redis issues), add them back
local missing_tokens = capacity - total_tokens
for i = 1, missing_tokens do
    redis.call('LPUSH', tokens_key, 'token_missing_' .. i)
    recovered_tokens = recovered_tokens + 1
end

-- If there are somehow too many tokens (should not happen), remove the extras
local excess_tokens = total_tokens - capacity
for i = 1, excess_tokens do
    redis.call('RPOP', tokens_key)
end

-- Step 4: Refresh expiry on the data structures so they are not evicted prematurely
local final_holders = redis.call('SCARD', holders_key)
local final_available = redis.call('LLEN', tokens_key)

if final_holders > 0 then
    redis.call('EXPIRE', holders_key, 3600) -- 1 hour expiry
end
if final_available > 0 then
    redis.call('EXPIRE', tokens_key, 3600) -- 1 hour expiry
end

return {recovered_tokens, final_holders, final_available, #cleaned_holders}
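
A caller would run this recovery script on a timer, matching the `cleanup_interval` default of 30 seconds used by the decorator. A hedged sketch of such a loop, with the same assumed key layout as before (names are illustrative, not the library's API):

    # Hypothetical periodic cleanup task invoking the recovery script above.
    # Key patterns are assumptions; interval mirrors the decorator default.
    import asyncio

    from redis.asyncio import Redis

    CLEANUP_SCRIPT = "..."  # the Lua cleanup script above


    async def cleanup_loop(
        redis: Redis, key: str, capacity: int, interval_s: float = 30.0
    ) -> None:
        tokens_key = f"semaphores:tokens:{key}"
        holders_key = f"semaphores:holders_set:{key}"
        holder_prefix = f"semaphores:holders:{key}:"

        while True:
            recovered, holders, available, cleaned = await redis.eval(
                CLEANUP_SCRIPT, 3, tokens_key, holders_key, holder_prefix, capacity
            )
            if recovered:
                # tokens were returned to the pool on behalf of crashed clients
                print(f"recovered {recovered} token(s), {holders} holder(s) remain")
            await asyncio.sleep(interval_s)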
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
-- Count current semaphore holders (simplified for the token pool design)
-- KEYS[1]: holders_key (SET of current holders)
-- KEYS[2]: tokens_key (LIST of available tokens)
-- ARGV[1]: capacity (total semaphore capacity)
--
-- Returns: {current_holders, available_tokens, total_capacity}

local holders_key = KEYS[1]
local tokens_key = KEYS[2]

local capacity = tonumber(ARGV[1])

-- Count current holders and available tokens
local current_holders = redis.call('SCARD', holders_key)
local available_tokens = redis.call('LLEN', tokens_key)

return {current_holders, available_tokens, capacity}
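
Reading these counters from Python is a one-liner around EVAL; a hedged sketch under the same assumed key layout (names illustrative):

    # Hypothetical reader for the count script above; key patterns are assumptions.
    from redis.asyncio import Redis

    COUNT_SCRIPT = "..."  # the Lua count script above


    async def semaphore_stats(redis: Redis, key: str, capacity: int) -> dict[str, int]:
        holders_key = f"semaphores:holders_set:{key}"
        tokens_key = f"semaphores:tokens:{key}"
        holders, available, cap = await redis.eval(
            COUNT_SCRIPT, 2, holders_key, tokens_key, capacity
        )
        # holders + available can briefly differ from capacity between a
        # client crash and the next cleanup run
        return {"current_holders": holders, "available_tokens": available, "capacity": cap}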
