1313 field_validator ,
1414)
1515from redis .commands .core import AsyncScript
16- from tenacity import (
17- RetryError ,
18- before_sleep_log ,
19- retry ,
20- retry_if_not_result ,
21- stop_after_delay ,
22- stop_never ,
23- wait_random_exponential ,
24- )
2516
2617from ._client import RedisClientSDK
2718from ._constants import (
3627 SemaphoreNotAcquiredError ,
3728)
3829from ._semaphore_lua import (
39- ACQUIRE_SEMAPHORE_SCRIPT ,
40- COUNT_SEMAPHORE_SCRIPT ,
41- RELEASE_SEMAPHORE_SCRIPT ,
42- RENEW_SEMAPHORE_SCRIPT ,
30+ ACQUIRE_FAIR_SEMAPHORE_V2_SCRIPT ,
31+ COUNT_FAIR_SEMAPHORE_V2_SCRIPT ,
32+ RELEASE_FAIR_SEMAPHORE_V2_SCRIPT ,
33+ RENEW_FAIR_SEMAPHORE_V2_SCRIPT ,
4334 SCRIPT_BAD_EXIT_CODE ,
4435 SCRIPT_OK_EXIT_CODE ,
4536)
@@ -114,16 +105,16 @@ def _register_scripts(cls, redis_client: RedisClientSDK) -> None:
114105 the script is only registered once."""
115106 if cls .acquire_script is None :
116107 cls .acquire_script = redis_client .redis .register_script (
117- ACQUIRE_SEMAPHORE_SCRIPT
108+ ACQUIRE_FAIR_SEMAPHORE_V2_SCRIPT
118109 )
119110 cls .count_script = redis_client .redis .register_script (
120- COUNT_SEMAPHORE_SCRIPT
111+ COUNT_FAIR_SEMAPHORE_V2_SCRIPT
121112 )
122113 cls .release_script = redis_client .redis .register_script (
123- RELEASE_SEMAPHORE_SCRIPT
114+ RELEASE_FAIR_SEMAPHORE_V2_SCRIPT
124115 )
125116 cls .renew_script = redis_client .redis .register_script (
126- RENEW_SEMAPHORE_SCRIPT
117+ RENEW_FAIR_SEMAPHORE_V2_SCRIPT
127118 )
128119
129120 def __init__ (self , ** data ) -> None :
@@ -136,6 +127,18 @@ def semaphore_key(self) -> str:
136127 """Redis key for the semaphore sorted set."""
137128 return f"{ SEMAPHORE_KEY_PREFIX } { self .key } "
138129
130+ @computed_field # type: ignore[prop-decorator]
131+ @property
132+ def tokens_key (self ) -> str :
133+ """Redis key for the token pool LIST."""
134+ return f"{ SEMAPHORE_KEY_PREFIX } { self .key } :tokens"
135+
136+ @computed_field # type: ignore[prop-decorator]
137+ @property
138+ def holders_key (self ) -> str :
139+ """Redis key for the holders SET."""
140+ return f"{ SEMAPHORE_KEY_PREFIX } { self .key } :holders"
141+
139142 @computed_field # type: ignore[prop-decorator]
140143 @property
141144 def holder_key (self ) -> str :
@@ -172,109 +175,103 @@ async def acquire(self) -> bool:
172175 SemaphoreAcquisitionError: If acquisition fails and blocking=True
173176 """
174177
175- if not self .blocking :
176- # Non-blocking: try once
177- return await self ._try_acquire ()
178-
179- # Blocking
180- @retry (
181- wait = wait_random_exponential (min = 0.1 , max = 2 ),
182- reraise = True ,
183- stop = (
184- stop_after_delay (self .blocking_timeout .total_seconds ())
185- if self .blocking_timeout
186- else stop_never
187- ),
188- retry = retry_if_not_result (lambda acquired : acquired ),
189- before_sleep = before_sleep_log (_logger , logging .DEBUG ),
190- )
191- async def _blocking_acquire () -> bool :
192- return await self ._try_acquire ()
193-
194- try :
195- return await _blocking_acquire ()
196- except RetryError as exc :
197- raise SemaphoreAcquisitionError (
198- name = self .key , capacity = self .capacity
199- ) from exc
200-
201- async def release (self ) -> None :
202- """
203- Release the semaphore atomically using Lua script.
204-
205- Raises:
206- SemaphoreNotAcquiredError: If semaphore was not acquired by this instance
207- """
208178 ttl_seconds = int (self .ttl .total_seconds ())
179+ blocking_timeout_seconds = 0.1
180+ if self .blocking :
181+ blocking_timeout_seconds = (
182+ self .blocking_timeout .total_seconds () if self .blocking_timeout else 0
183+ )
209184
210- # Execute the release Lua script atomically
185+ # Execute the Lua script atomically
211186 cls = type (self )
212- assert cls .release_script is not None # nosec
213- result = await cls .release_script ( # pylint: disable=not-callable
214- keys = (
215- self .semaphore_key ,
216- self .holder_key ,
217- ),
218- args = (
187+ assert cls .acquire_script is not None # nosec
188+ result = await cls .acquire_script ( # pylint: disable=not-callable
189+ keys = [self .tokens_key , self .holders_key , self .holder_key ],
190+ args = [
219191 self .instance_id ,
220- str (ttl_seconds ),
221- ),
192+ self .capacity ,
193+ ttl_seconds ,
194+ blocking_timeout_seconds ,
195+ ],
222196 client = self .redis_client .redis ,
223197 )
224198
199+ # Lua script returns: [exit_code, status, current_count, expired_count]
225200 assert isinstance (result , list ) # nosec
226- exit_code , status , current_count , expired_count = result
227- result = status
201+ exit_code , status , token , current_count = result
228202
229- if result == "released" :
230- assert exit_code == SCRIPT_OK_EXIT_CODE # nosec
203+ if exit_code == SCRIPT_OK_EXIT_CODE :
204+ _logger .debug (
205+ "Acquired semaphore '%s' with token %s (instance: %s, count: %s)" ,
206+ self .key ,
207+ token ,
208+ self .instance_id ,
209+ current_count ,
210+ )
211+ return True
212+ if status == "timeout" :
213+ if self .blocking :
214+ _logger .debug (
215+ "Timeout acquiring semaphore '%s' (instance: %s, count: %s)" ,
216+ self .key ,
217+ self .instance_id ,
218+ current_count ,
219+ )
220+ raise SemaphoreAcquisitionError (name = self .key , capacity = self .capacity )
231221 _logger .debug (
232- "Released semaphore '%s' (instance: %s, count: %s, expired : %s)" ,
222+ "Timeout acquiring semaphore '%s' (instance: %s, count: %s)" ,
233223 self .key ,
234224 self .instance_id ,
235225 current_count ,
236- expired_count ,
237226 )
238- else :
239- # Instance wasn't in the semaphore set - this shouldn't happen
240- # but let's handle it gracefully
241- assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec
242- raise SemaphoreNotAcquiredError (name = self .key )
227+ return False
243228
244- async def _try_acquire (self ) -> bool :
245- ttl_seconds = int (self .ttl .total_seconds ())
229+ _logger .debug (
230+ "Failed to acquire semaphore '%s' - %s (count: %s)" ,
231+ self .key ,
232+ status ,
233+ current_count ,
234+ )
235+ raise SemaphoreAcquisitionError (name = self .key , capacity = self .capacity )
246236
247- # Execute the Lua script atomically
237+ async def release (self ) -> None :
238+ """
239+ Release the semaphore atomically using Lua script.
240+
241+ Raises:
242+ SemaphoreNotAcquiredError: If semaphore was not acquired by this instance
243+ """
244+
245+ # Execute the release Lua script atomically
248246 cls = type (self )
249- assert cls .acquire_script is not None # nosec
250- result = await cls .acquire_script ( # pylint: disable=not-callable
251- keys = ( self .semaphore_key , self .holder_key ) ,
252- args = ( self .instance_id , str ( self . capacity ), str ( ttl_seconds )) ,
247+ assert cls .release_script is not None # nosec
248+ result = await cls .release_script ( # pylint: disable=not-callable
249+ keys = [ self .tokens_key , self .holders_key , self . holder_key ] ,
250+ args = [ self .instance_id ] ,
253251 client = self .redis_client .redis ,
254252 )
255253
256- # Lua script returns: [exit_code, status, current_count, expired_count]
257254 assert isinstance (result , list ) # nosec
258- exit_code , status , current_count , expired_count = result
259-
255+ exit_code , status , current_count = result
260256 if exit_code == SCRIPT_OK_EXIT_CODE :
257+ assert status == "released" # nosec
261258 _logger .debug (
262- "Acquired semaphore '%s' (instance: %s, count: %s, expired : %s)" ,
259+ "Released semaphore '%s' (instance: %s, count: %s)" ,
263260 self .key ,
264261 self .instance_id ,
265262 current_count ,
266- expired_count ,
267263 )
268- return True
269264
270- _logger .debug (
271- "Failed to acquire semaphore '%s' - %s (count: %s, expired: %s)" ,
265+ # Instance was already expired or not acquired
266+ assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec
267+ _logger .error (
268+ "Failed to release semaphore '%s' - %s (instance: %s, count: %s)" ,
272269 self .key ,
273270 status ,
271+ self .instance_id ,
274272 current_count ,
275- expired_count ,
276273 )
277- return False
274+ raise SemaphoreNotAcquiredError ( name = self . key )
278275
279276 async def reacquire (self ) -> None :
280277 """
@@ -293,72 +290,50 @@ async def reacquire(self) -> None:
293290 cls = type (self )
294291 assert cls .renew_script is not None # nosec
295292 result = await cls .renew_script ( # pylint: disable=not-callable
296- keys = (self .semaphore_key , self .holder_key ),
297- args = (
298- self .instance_id ,
299- str (ttl_seconds ),
300- ),
293+ keys = [self .holders_key , self .holder_key ],
294+ args = [self .instance_id , ttl_seconds ],
301295 client = self .redis_client .redis ,
302296 )
303297
304298 assert isinstance (result , list ) # nosec
305- exit_code , status , current_count , expired_count = result
299+ exit_code , status , current_count = result
306300
307- # Lua script returns: 'renewed' or status message
308- if status == "renewed" :
309- assert exit_code == SCRIPT_OK_EXIT_CODE # nosec
301+ if exit_code == SCRIPT_OK_EXIT_CODE :
302+ assert status == "renewed" # nosec
310303 _logger .debug (
311- "Renewed semaphore '%s' (instance: %s, count: %s, expired: %s )" ,
304+ "Renewed semaphore '%s' (instance: %s, count: %s)" ,
312305 self .key ,
313306 self .instance_id ,
314307 current_count ,
315- expired_count ,
316308 )
317- else :
318- assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec
319- if status == "expired" :
320- _logger .warning (
321- "Semaphore '%s' holder key expired (instance: %s, count: %s, expired: %s)" ,
322- self .key ,
323- self .instance_id ,
324- current_count ,
325- expired_count ,
326- )
327- elif status == "not_held" :
328- _logger .warning (
329- "Semaphore '%s' not held (instance: %s, count: %s, expired: %s)" ,
330- self .key ,
331- self .instance_id ,
332- current_count ,
333- expired_count ,
334- )
309+ return
310+ assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec
335311
336- raise SemaphoreLostError (name = self .key , instance_id = self .instance_id )
312+ _logger .warning (
313+ "Semaphore '%s' holder key was lost (instance: %s, status: %s, count: %s)" ,
314+ self .key ,
315+ self .instance_id ,
316+ status ,
317+ current_count ,
318+ )
319+
320+ raise SemaphoreLostError (name = self .key , instance_id = self .instance_id )
337321
338322 async def get_current_count (self ) -> int :
339323 """Get the current number of semaphore holders"""
340- ttl_seconds = int (self .ttl .total_seconds ())
341324
342- # Execute the count Lua script atomically
343325 cls = type (self )
344326 assert cls .count_script is not None # nosec
345327 result = await cls .count_script ( # pylint: disable=not-callable
346- keys = ( self .semaphore_key ,) ,
347- args = ( str ( ttl_seconds ),) ,
328+ keys = [ self .holders_key , self . tokens_key ] ,
329+ args = [ self . capacity ] ,
348330 client = self .redis_client .redis ,
349331 )
350332
351333 assert isinstance (result , list ) # nosec
352- current_count , expired_count = result
353-
354- if int (expired_count ) > 0 :
355- _logger .debug (
356- "Cleaned up %s expired entries from semaphore '%s'" ,
357- expired_count ,
358- self .key ,
359- )
334+ current_holders , available_tokens , capacity = result
360335
361- return int (current_count )
336+ return int (current_holders )
362337
363338 async def get_available_count (self ) -> int :
364339 """Get the number of available semaphore slots"""
0 commit comments