2121)
2222from redis .commands .core import AsyncScript
2323from tenacity import (
24- RetryError ,
2524 retry ,
2625 retry_if_exception_type ,
27- stop_after_delay ,
28- stop_never ,
2926 wait_random_exponential ,
3027)
31- from tenacity .stop import stop_base
3228
3329from ..background_task import periodic
3430from ._client import RedisClientSDK
@@ -195,6 +191,55 @@ async def _ensure_semaphore_initialized(self) -> None:
195191 client = self .redis_client .redis ,
196192 )
197193
194+ async def _blocking_acquire (self ) -> str | None :
195+ @retry (
196+ wait = wait_random_exponential (min = 0.1 , max = 0.5 ),
197+ retry = retry_if_exception_type (redis .exceptions .TimeoutError ),
198+ )
199+ async def _acquire_forever_on_socket_timeout () -> list [str ] | None :
200+ # NOTE: brpop returns None on timeout
201+
202+ tokens_key_token : list [str ] | None = await handle_redis_returns_union_types (
203+ self .redis_client .redis .brpop (
204+ [self .tokens_key ],
205+ timeout = None , # NOTE: we always block forever since tenacity takes care of timing out
206+ )
207+ )
208+ return tokens_key_token
209+
210+ try :
211+ # NOTE: redis-py library timeouts when the defined socket timeout triggers (e.g. DEFAULT_SOCKET_TIMEOUT)
212+ # The BRPOP command itself could timeout but the redis-py socket timeout defeats the purpose
213+ # so we always block forever on BRPOP, tenacity takes care of retrying when a socket timeout happens
214+ # and we use asyncio.timeout to enforce the blocking_timeout if defined
215+ async with asyncio .timeout (
216+ self .blocking_timeout .total_seconds () if self .blocking_timeout else None
217+ ):
218+ tokens_key_token = await _acquire_forever_on_socket_timeout ()
219+ assert tokens_key_token is not None # nosec
220+ assert len (tokens_key_token ) == 2 # nosec # noqa: PLR2004
221+ assert tokens_key_token [0 ] == self .tokens_key # nosec
222+ return tokens_key_token [1 ]
223+ except TimeoutError as e :
224+ raise SemaphoreAcquisitionError (
225+ name = self .key , instance_id = self .instance_id
226+ ) from e
227+
228+ async def _non_blocking_acquire (self ) -> str | None :
229+ token : str | list [str ] | None = await handle_redis_returns_union_types (
230+ self .redis_client .redis .rpop (self .tokens_key )
231+ )
232+ if token is None :
233+ _logger .debug (
234+ "Semaphore '%s' not acquired (no tokens available) (instance: %s)" ,
235+ self .key ,
236+ self .instance_id ,
237+ )
238+ return None
239+
240+ assert isinstance (token , str ) # nosec
241+ return token
242+
198243 async def acquire (self ) -> bool :
199244 """
200245 Acquire the semaphore.
@@ -217,60 +262,14 @@ async def acquire(self) -> bool:
217262
218263 ttl_seconds = self .ttl .total_seconds ()
219264
220- # Determine retry stop condition based on blocking configuration
221- stop_condition : stop_base = stop_after_delay (0 )
222- if self .blocking :
223- stop_condition = (
224- stop_after_delay (self .blocking_timeout )
225- if self .blocking_timeout
226- else stop_never
227- )
228-
229- try :
230-
231- @retry (
232- stop = stop_condition ,
233- wait = wait_random_exponential (min = 0.1 , max = 0.5 ),
234- retry = retry_if_exception_type (redis .exceptions .TimeoutError ),
235- )
236- async def _try_acquire () -> list [str ] | None :
237- # NOTE: brpop returns None on timeout
238- # NOTE: redis-py library timeouts when the socket times out which is defined
239- # elsewhere on the client (e.g. DEFAULT_SOCKET_TIMEOUT)
240- # we always block forever since tenacity takes care of timing out
241- # therefore we can distinguish between a proper timeout (returns None) and a socket
242- # timeout (raises an exception)
243-
244- tokens_key_token : list [str ] | None = (
245- await handle_redis_returns_union_types (
246- self .redis_client .redis .brpop (
247- [self .tokens_key ],
248- timeout = None , # NOTE: we always block forever since tenacity takes care of timing out
249- )
250- )
251- )
252- return tokens_key_token
253-
254- tokens_key_token = await _try_acquire ()
255- except RetryError as e :
256- # NOTE: if we end up here that means we could not acquire the semaphore
257- _logger .debug (
258- "Timeout acquiring semaphore '%s' (instance: %s)" ,
259- self .key ,
260- self .instance_id ,
261- )
262- if self .blocking :
263- raise SemaphoreAcquisitionError (
264- name = self .key , instance_id = self .instance_id
265- ) from e
266- return False
267-
268- # If we got here it means we acquired a token
269- assert tokens_key_token is not None # nosec
270- assert len (tokens_key_token ) == 2 # nosec # noqa: PLR2004
271- assert tokens_key_token [0 ] == self .tokens_key # nosec
272- self ._token = tokens_key_token [1 ]
265+ if self .blocking is False :
266+ self ._token = await self ._non_blocking_acquire ()
267+ if not self ._token :
268+ return False
269+ else :
270+ self ._token = await self ._blocking_acquire ()
273271
272+ assert self ._token is not None # nosec
274273 # set up the semaphore holder with a TTL
275274 cls = type (self )
276275 assert cls .acquire_script is not None # nosec
0 commit comments