|
1 | 1 | import asyncio |
2 | 2 | import datetime |
3 | | -import functools |
4 | 3 | import logging |
5 | | -import socket |
6 | 4 | import uuid |
7 | | -from collections.abc import Callable, Coroutine |
8 | 5 | from types import TracebackType |
9 | | -from typing import Annotated, Any, ParamSpec, TypeVar |
| 6 | +from typing import Annotated |
10 | 7 |
|
11 | 8 | from pydantic import ( |
12 | 9 | BaseModel, |
|
17 | 14 | field_validator, |
18 | 15 | ) |
19 | 16 |
|
20 | | -from ..background_task import periodic |
21 | | -from ..logging_errors import create_troubleshootting_log_kwargs |
22 | 17 | from ..logging_utils import log_catch |
23 | 18 | from ._client import RedisClientSDK |
24 | 19 | from ._constants import ( |
|
32 | 27 | _logger = logging.getLogger(__name__) |
33 | 28 |
|
34 | 29 |
|
35 | | -async def _renew_semaphore_entry(semaphore: "DistributedSemaphore") -> None: |
36 | | - """ |
37 | | - Manually renew a semaphore entry by updating its timestamp and TTL. |
38 | | -
|
39 | | - This function is intended to be called by decorators or external renewal mechanisms. |
40 | | -
|
41 | | - Args: |
42 | | - semaphore: The semaphore instance to renew |
43 | | -
|
44 | | - Raises: |
45 | | - Exception: If the renewal operation fails |
46 | | - """ |
47 | | - |
48 | | - current_time = asyncio.get_event_loop().time() |
49 | | - ttl_seconds = semaphore.ttl.total_seconds() |
50 | | - |
51 | | - # Update timestamp in sorted set and refresh holder key |
52 | | - async with semaphore.redis_client.redis.pipeline(transaction=True) as pipe: |
53 | | - await pipe.zadd(semaphore.semaphore_key, {semaphore.instance_id: current_time}) |
54 | | - await pipe.expire(semaphore.holder_key, int(ttl_seconds)) |
55 | | - await pipe.execute() |
56 | | - |
57 | | - |
58 | 30 | class DistributedSemaphore(BaseModel): |
59 | 31 | """ |
| 32 | + Warning: This should only be used directly via the decorator |
60 | 33 | A distributed semaphore implementation using Redis. |
61 | 34 |
|
62 | 35 | This semaphore allows limiting the number of concurrent operations across |
@@ -283,155 +256,3 @@ async def __aexit__( |
283 | 256 | ) -> None: |
284 | 257 | if self._acquired: |
285 | 258 | await self.release() |
286 | | - |
287 | | - |
288 | | -P = ParamSpec("P") |
289 | | -R = TypeVar("R") |
290 | | - |
291 | | - |
292 | | -def with_limited_concurrency( |
293 | | - redis_client: RedisClientSDK | Callable[..., RedisClientSDK], |
294 | | - *, |
295 | | - key: str | Callable[..., str], |
296 | | - capacity: int | Callable[..., int], |
297 | | - ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL, |
298 | | - blocking: bool = True, |
299 | | - blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT, |
300 | | -) -> Callable[ |
301 | | - [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] |
302 | | -]: |
303 | | - """ |
304 | | - Decorator to limit concurrent execution of a function using a distributed semaphore. |
305 | | -
|
306 | | - This decorator ensures that only a specified number of instances of the decorated |
307 | | - function can run concurrently across multiple processes/instances using Redis |
308 | | - as the coordination backend. |
309 | | -
|
310 | | - Args: |
311 | | - redis_client: Redis client for coordination (can be callable) |
312 | | - key: Unique identifier for the semaphore (can be callable) |
313 | | - capacity: Maximum number of concurrent executions (can be callable) |
314 | | - ttl: Time-to-live for semaphore entries (default: 5 minutes) |
315 | | - blocking: Whether to block when semaphore is full (default: True) |
316 | | - blocking_timeout: Maximum time to wait when blocking (default: socket timeout) |
317 | | -
|
318 | | - Example: |
319 | | - @with_limited_concurrency( |
320 | | - redis_client, |
321 | | - key=f"{user_id}-{wallet_id}", |
322 | | - capacity=20, |
323 | | - blocking=True, |
324 | | - blocking_timeout=None |
325 | | - ) |
326 | | - async def process_user_wallet(user_id: str, wallet_id: str): |
327 | | - # Only 20 instances of this function can run concurrently |
328 | | - # for the same user_id-wallet_id combination |
329 | | - await do_processing() |
330 | | -
|
331 | | - Raises: |
332 | | - SemaphoreAcquisitionError: If semaphore cannot be acquired and blocking=True |
333 | | - """ |
334 | | - |
335 | | - def _decorator( |
336 | | - coro: Callable[P, Coroutine[Any, Any, R]], |
337 | | - ) -> Callable[P, Coroutine[Any, Any, R]]: |
338 | | - @functools.wraps(coro) |
339 | | - async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: |
340 | | - # Resolve callable parameters |
341 | | - semaphore_key = key(*args, **kwargs) if callable(key) else key |
342 | | - semaphore_capacity = ( |
343 | | - capacity(*args, **kwargs) if callable(capacity) else capacity |
344 | | - ) |
345 | | - client = ( |
346 | | - redis_client(*args, **kwargs) |
347 | | - if callable(redis_client) |
348 | | - else redis_client |
349 | | - ) |
350 | | - |
351 | | - assert isinstance(semaphore_key, str) # nosec |
352 | | - assert isinstance(semaphore_capacity, int) # nosec |
353 | | - assert isinstance(client, RedisClientSDK) # nosec |
354 | | - |
355 | | - # Create the semaphore (without auto-renewal) |
356 | | - semaphore = DistributedSemaphore( |
357 | | - redis_client=client, |
358 | | - key=semaphore_key, |
359 | | - capacity=semaphore_capacity, |
360 | | - ttl=ttl, |
361 | | - blocking=blocking, |
362 | | - blocking_timeout=blocking_timeout, |
363 | | - ) |
364 | | - |
365 | | - # Acquire the semaphore first |
366 | | - if not await semaphore.acquire(): |
367 | | - raise SemaphoreAcquisitionError( |
368 | | - name=semaphore_key, capacity=semaphore_capacity |
369 | | - ) |
370 | | - |
371 | | - try: |
372 | | - # Use TaskGroup for proper exception propagation (similar to exclusive decorator) |
373 | | - async with asyncio.TaskGroup() as tg: |
374 | | - started_event = asyncio.Event() |
375 | | - |
376 | | - # Create auto-renewal task |
377 | | - @periodic(interval=ttl / 3, raise_on_error=True) |
378 | | - async def _periodic_renewer() -> None: |
379 | | - await _renew_semaphore_entry(semaphore) |
380 | | - started_event.set() |
381 | | - |
382 | | - # Start the renewal task |
383 | | - renewal_task = tg.create_task( |
384 | | - _periodic_renewer(), |
385 | | - name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}", |
386 | | - ) |
387 | | - |
388 | | - # Wait for first renewal to complete (ensures task is running) |
389 | | - await started_event.wait() |
390 | | - |
391 | | - # Run the user work |
392 | | - work_task = tg.create_task( |
393 | | - coro(*args, **kwargs), |
394 | | - name=f"semaphore/work_{coro.__module__}.{coro.__name__}", |
395 | | - ) |
396 | | - |
397 | | - result = await work_task |
398 | | - |
399 | | - # Cancel renewal task (work is done) |
400 | | - renewal_task.cancel() |
401 | | - |
402 | | - return result |
403 | | - |
404 | | - except BaseExceptionGroup as eg: |
405 | | - # Handle exceptions similar to exclusive decorator |
406 | | - # If renewal fails, the TaskGroup will propagate the exception |
407 | | - # and cancel the work task automatically |
408 | | - |
409 | | - # Re-raise the first exception in the group |
410 | | - if eg.exceptions: |
411 | | - raise eg.exceptions[0] from eg |
412 | | - raise |
413 | | - |
414 | | - finally: |
415 | | - # Always release the semaphore |
416 | | - if semaphore.is_acquired(): |
417 | | - try: |
418 | | - await semaphore.release() |
419 | | - except SemaphoreNotAcquiredError as exc: |
420 | | - _logger.exception( |
421 | | - **create_troubleshootting_log_kwargs( |
422 | | - "Unexpected error while releasing semaphore", |
423 | | - error=exc, |
424 | | - error_context={ |
425 | | - "semaphore_key": semaphore_key, |
426 | | - "client_name": client.client_name, |
427 | | - "hostname": socket.gethostname(), |
428 | | - "coroutine": coro.__name__, |
429 | | - }, |
430 | | - tip="This might happen if the semaphore was lost before releasing it. " |
431 | | - "Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.", |
432 | | - ) |
433 | | - ) |
434 | | - |
435 | | - return _wrapper |
436 | | - |
437 | | - return _decorator |
0 commit comments