@@ -3,7 +3,8 @@
 import functools
 import logging
 import socket
-from collections.abc import Callable, Coroutine
+from collections.abc import AsyncIterator, Callable, Coroutine
+from contextlib import asynccontextmanager
 from typing import Any, ParamSpec, TypeVar
 
 from common_library.async_tools import cancel_wait_task
@@ -27,6 +28,102 @@
 R = TypeVar("R")
 
 
+@asynccontextmanager
+async def _managed_semaphore_execution(
+    semaphore: DistributedSemaphore,
+    semaphore_key: str,
+    ttl: datetime.timedelta,
+    execution_context: str,
+) -> AsyncIterator:
+    """Common semaphore management logic with auto-renewal."""
+    # Acquire the semaphore first
+    if not await semaphore.acquire():
+        raise SemaphoreAcquisitionError(name=semaphore_key, capacity=semaphore.capacity)
+
+    try:
+        # Use TaskGroup for proper exception propagation
+        async with asyncio.TaskGroup() as tg:
+            started_event = asyncio.Event()
+
+            # Create auto-renewal task
+            @periodic(interval=ttl / 3, raise_on_error=True)
+            async def _periodic_renewer() -> None:
+                await semaphore.reacquire()
+                if not started_event.is_set():
+                    started_event.set()
+
+            # Start the renewal task
+            renewal_task = tg.create_task(
+                _periodic_renewer(),
+                name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
+            )
+
+            # Wait for first renewal to complete (ensures task is running)
+            await started_event.wait()
+
+            # Yield control back to caller
+            yield
+
+            # Cancel renewal task when execution is done
+            await cancel_wait_task(renewal_task, max_delay=None)
+
+    except BaseExceptionGroup as eg:
+        # Re-raise the first exception in the group
+        raise eg.exceptions[0] from eg
+
+    finally:
+        # Always attempt to release the semaphore
+        try:
+            await semaphore.release()
+        except Exception as exc:
+            _logger.exception(
+                **create_troubleshootting_log_kwargs(
+                    "Unexpected error while releasing semaphore",
+                    error=exc,
+                    error_context={
+                        "semaphore_key": semaphore_key,
+                        "client_name": semaphore.redis_client.client_name,
+                        "hostname": socket.gethostname(),
+                        "execution_context": execution_context,
+                    },
+                    tip="This might happen if the semaphore was lost before releasing it. "
+                    "Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.",
+                )
+            )
+
+
+def _create_semaphore(
+    redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
+    args: tuple[Any, ...],
+    *,
+    key: str | Callable[..., str],
+    capacity: int | Callable[..., int],
+    ttl: datetime.timedelta,
+    blocking: bool,
+    blocking_timeout: datetime.timedelta | None,
+    kwargs: dict[str, Any],
+) -> tuple[DistributedSemaphore, str]:
+    """Create and configure a distributed semaphore from callable or static parameters."""
+    semaphore_key = key(*args, **kwargs) if callable(key) else key
+    semaphore_capacity = capacity(*args, **kwargs) if callable(capacity) else capacity
+    client = redis_client(*args, **kwargs) if callable(redis_client) else redis_client
+
+    assert isinstance(semaphore_key, str)  # nosec
+    assert isinstance(semaphore_capacity, int)  # nosec
+    assert isinstance(client, RedisClientSDK)  # nosec
+
+    semaphore = DistributedSemaphore(
+        redis_client=client,
+        key=semaphore_key,
+        capacity=semaphore_capacity,
+        ttl=ttl,
+        blocking=blocking,
+        blocking_timeout=blocking_timeout,
+    )
+
+    return semaphore, semaphore_key
+
+
 def with_limited_concurrency(
     redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
     *,
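
The `_managed_semaphore_execution` helper above implements an acquire → auto-renew → release lifecycle: a `TaskGroup` owns a renewal task that re-extends the semaphore TTL every `ttl / 3`, an event gates entry until the first renewal has run, and a failed renewal cancels the caller's work through normal `TaskGroup` exception propagation. The following self-contained sketch (not part of the diff) shows the same pattern with hypothetical stand-ins: `FakeSemaphore` replaces `DistributedSemaphore`, and a plain loop approximates the `@periodic` helper:

import asyncio
import contextlib
import datetime
from collections.abc import AsyncIterator


class FakeSemaphore:
    """Hypothetical stand-in for DistributedSemaphore (no Redis involved)."""

    async def acquire(self) -> bool:
        print("acquired")
        return True

    async def reacquire(self) -> None:
        print("ttl renewed")

    async def release(self) -> None:
        print("released")


@contextlib.asynccontextmanager
async def managed(semaphore: FakeSemaphore, ttl: datetime.timedelta) -> AsyncIterator[None]:
    if not await semaphore.acquire():
        raise RuntimeError("semaphore not acquired")
    try:
        async with asyncio.TaskGroup() as tg:
            started = asyncio.Event()

            async def renewer() -> None:
                # Renew at ttl/3 so a couple of renewals may fail before the TTL lapses
                while True:
                    await semaphore.reacquire()
                    started.set()
                    await asyncio.sleep(ttl.total_seconds() / 3)

            renewal_task = tg.create_task(renewer())
            await started.wait()  # do not enter the block before renewal is live
            yield
            # Work finished: stop renewing, and await the cancellation so the
            # TaskGroup is not left with a pending child
            renewal_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await renewal_task
    except BaseExceptionGroup as eg:
        raise eg.exceptions[0] from eg  # unwrap for callers expecting plain exceptions
    finally:
        await semaphore.release()  # the semaphore backend stays the source of truth


async def main() -> None:
    async with managed(FakeSemaphore(), ttl=datetime.timedelta(seconds=3)):
        await asyncio.sleep(2.5)  # the protected work; TTL is renewed underneath


asyncio.run(main())

Unwrapping `eg.exceptions[0]` preserves the pre-TaskGroup contract for callers that expect plain exceptions rather than `ExceptionGroup`s; the trade-off is that secondary exceptions in the group are dropped from the raise itself, though they remain reachable via `__cause__`.
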
@@ -75,101 +172,89 @@ def _decorator(
     ) -> Callable[P, Coroutine[Any, Any, R]]:
         @functools.wraps(coro)
         async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
-            # Resolve callable parameters
-            semaphore_key = key(*args, **kwargs) if callable(key) else key
-            semaphore_capacity = (
-                capacity(*args, **kwargs) if callable(capacity) else capacity
-            )
-            client = (
-                redis_client(*args, **kwargs)
-                if callable(redis_client)
-                else redis_client
+            semaphore, semaphore_key = _create_semaphore(
+                redis_client,
+                args,
+                key=key,
+                capacity=capacity,
+                ttl=ttl,
+                blocking=blocking,
+                blocking_timeout=blocking_timeout,
+                kwargs=kwargs,
             )
 
-            assert isinstance(semaphore_key, str)  # nosec
-            assert isinstance(semaphore_capacity, int)  # nosec
-            assert isinstance(client, RedisClientSDK)  # nosec
+            async with _managed_semaphore_execution(
+                semaphore, semaphore_key, ttl, f"coroutine_{coro.__name__}"
+            ):
+                return await coro(*args, **kwargs)
+
+        return _wrapper
+
+    return _decorator
+
 
-            # Create the semaphore (without auto-renewal)
-            semaphore = DistributedSemaphore(
-                redis_client=client,
-                key=semaphore_key,
-                capacity=semaphore_capacity,
+def with_limited_concurrency_cm(
+    redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
+    *,
+    key: str | Callable[..., str],
+    capacity: int | Callable[..., int],
+    ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
+    blocking: bool = True,
+    blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT,
+) -> Callable[[Callable[P, AsyncIterator[R]]], Callable[P, AsyncIterator[R]]]:
+    """
+    Decorator to limit concurrent execution of async context managers using a distributed semaphore.
+
+    This decorator ensures that only a specified number of instances of the decorated
+    async context manager can be active concurrently across multiple processes/instances
+    using Redis as the coordination backend.
+
+    Args:
+        redis_client: Redis client for coordination (can be callable)
+        key: Unique identifier for the semaphore (can be callable)
+        capacity: Maximum number of concurrent executions (can be callable)
+        ttl: Time-to-live for semaphore entries (default: 5 minutes)
+        blocking: Whether to block when semaphore is full (default: True)
+        blocking_timeout: Maximum time to wait when blocking (default: socket timeout)
+
+    Example:
+        @asynccontextmanager
+        @with_limited_concurrency_cm(
+            redis_client,
+            key="cluster:my-cluster",
+            capacity=5,
+            blocking=True,
+            blocking_timeout=None
+        )
+        async def get_cluster_client():
+            async with pool.acquire() as client:
+                yield client
+
+    Raises:
+        SemaphoreAcquisitionError: If semaphore cannot be acquired and blocking=True
+    """
+
+    def _decorator(
+        cm_func: Callable[P, AsyncIterator[R]],
+    ) -> Callable[P, AsyncIterator[R]]:
+        @functools.wraps(cm_func)
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
+            semaphore, semaphore_key = _create_semaphore(
+                redis_client,
+                args,
+                key=key,
+                capacity=capacity,
                 ttl=ttl,
                 blocking=blocking,
                 blocking_timeout=blocking_timeout,
+                kwargs=kwargs,
             )
 
-            # Acquire the semaphore first
-            if not await semaphore.acquire():
-                raise SemaphoreAcquisitionError(
-                    name=semaphore_key, capacity=semaphore_capacity
-                )
-
-            try:
-                # Use TaskGroup for proper exception propagation (similar to exclusive decorator)
-                async with asyncio.TaskGroup() as tg:
-                    started_event = asyncio.Event()
-
-                    # Create auto-renewal task
-                    @periodic(interval=ttl / 3, raise_on_error=True)
-                    async def _periodic_renewer() -> None:
-                        await semaphore.reacquire()
-                        if not started_event.is_set():
-                            started_event.set()
-
-                    # Start the renewal task
-                    renewal_task = tg.create_task(
-                        _periodic_renewer(),
-                        name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
-                    )
-
-                    # Wait for first renewal to complete (ensures task is running)
-                    await started_event.wait()
-
-                    # Run the user work
-                    work_task = tg.create_task(
-                        coro(*args, **kwargs),
-                        name=f"semaphore/work_{coro.__module__}.{coro.__name__}",
-                    )
-                    result = await work_task
-
-                    # Cancel renewal task (work is done)
-                    # NOTE: if we do not explicitly await the task inside the context manager
-                    # it sometimes hangs forever (Python issue?)
-                    await cancel_wait_task(renewal_task, max_delay=None)
-
-                    return result
-
-            except BaseExceptionGroup as eg:
-                # Handle exceptions similar to exclusive decorator
-                # If renewal fails, the TaskGroup will propagate the exception
-                # and cancel the work task automatically
-
-                # Re-raise the first exception in the group
-                raise eg.exceptions[0] from eg
-
-            finally:
-                # Always attempt to release the semaphore, regardless of Python state
-                # The Redis-side state is the source of truth, not the Python _acquired flag
-                try:
-                    await semaphore.release()
-                except Exception as exc:
-                    # Log any other release errors but don't re-raise
-                    _logger.exception(
-                        **create_troubleshootting_log_kwargs(
-                            "Unexpected error while releasing semaphore",
-                            error=exc,
-                            error_context={
-                                "semaphore_key": semaphore_key,
-                                "client_name": client.client_name,
-                                "hostname": socket.gethostname(),
-                                "coroutine": coro.__name__,
-                            },
-                            tip="This might happen if the semaphore was lost before releasing it. "
-                            "Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.",
-                        )
-                    )
+            async with _managed_semaphore_execution(
+                semaphore, semaphore_key, ttl, f"context_manager_{cm_func.__name__}"
+            ):
+                async for value in cm_func(*args, **kwargs):
+                    yield value
 
         return _wrapper
 
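To close the loop, here is how the refactored coroutine decorator would be applied, mirroring the docstring example above. `create_redis_client`, the wallet-based key scheme, and the chosen capacity are illustrative assumptions, not part of this commit; only `with_limited_concurrency` and its parameter semantics come from the code above:

import datetime


def create_redis_client():
    """Assumed factory returning the application's RedisClientSDK."""
    ...


@with_limited_concurrency(
    create_redis_client,  # a callable is re-resolved on every invocation
    key=lambda *, wallet_id, **_: f"wallet:{wallet_id}",  # per-wallet semaphore key
    capacity=2,  # at most 2 concurrent executions per key, across all replicas
    ttl=datetime.timedelta(minutes=5),  # auto-renewed every ttl/3 while the work runs
)
async def start_cluster(*, wallet_id: str) -> None:
    ...  # protected work; by default this blocks until a slot frees up

Because `_wrapper` now funnels every call through `_create_semaphore`, callable `redis_client`, `key`, and `capacity` arguments are evaluated with the call's own `*args`/`**kwargs`, so a single decorated function can map onto many independent semaphores (one per wallet in this sketch).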