|
1 | | -import asyncio |
2 | 1 | import datetime |
3 | 2 | import functools |
4 | 3 | import logging |
5 | | -import socket |
6 | 4 | from collections.abc import AsyncIterator, Callable, Coroutine |
7 | 5 | from contextlib import AbstractAsyncContextManager, asynccontextmanager |
8 | 6 | from typing import Any, ParamSpec, TypeVar |
9 | 7 |
|
10 | | -import arrow |
11 | | -from common_library.async_tools import cancel_wait_task |
12 | | -from common_library.logging.logging_errors import create_troubleshooting_log_kwargs |
13 | | - |
14 | | -from ..background_task import periodic |
15 | 8 | from ._client import RedisClientSDK |
16 | 9 | from ._constants import ( |
17 | 10 | DEFAULT_EXPECTED_LOCK_OVERALL_TIME, |
18 | 11 | DEFAULT_SEMAPHORE_TTL, |
19 | 12 | DEFAULT_SOCKET_TIMEOUT, |
20 | 13 | ) |
21 | | -from ._errors import ( |
22 | | - SemaphoreAcquisitionError, |
23 | | - SemaphoreLostError, |
24 | | - SemaphoreNotAcquiredError, |
25 | | -) |
26 | | -from ._semaphore import DistributedSemaphore |
| 14 | +from ._semaphore import distributed_semaphore |
27 | 15 |
|
28 | 16 | _logger = logging.getLogger(__name__) |
29 | 17 |
|
|
32 | 20 | R = TypeVar("R") |
33 | 21 |
|
34 | 22 |
|
35 | | -@asynccontextmanager |
36 | | -async def _managed_semaphore_execution( |
37 | | - semaphore: DistributedSemaphore, |
38 | | - semaphore_key: str, |
39 | | - ttl: datetime.timedelta, |
40 | | - execution_context: str, |
41 | | - expected_lock_overall_time: datetime.timedelta, |
42 | | -) -> AsyncIterator: |
43 | | - """Common semaphore management logic with auto-renewal.""" |
44 | | - # Acquire the semaphore first |
45 | | - if not await semaphore.acquire(): |
46 | | - raise SemaphoreAcquisitionError( |
47 | | - name=semaphore_key, instance_id=semaphore.instance_id |
48 | | - ) |
49 | | - |
50 | | - lock_acquisition_time = arrow.utcnow() |
51 | | - try: |
52 | | - # NOTE: Use TaskGroup for proper exception propagation, this ensures that in case of error the context manager will be properly exited |
53 | | - # and the semaphore released. |
54 | | - # If we use create_task() directly, exceptions in the task are not propagated to the parent task |
55 | | - # and the context manager may never exit, leading to semaphore leaks. |
56 | | - async with asyncio.TaskGroup() as tg: |
57 | | - started_event = asyncio.Event() |
58 | | - |
59 | | - # Create auto-renewal task |
60 | | - @periodic(interval=ttl / 3, raise_on_error=True) |
61 | | - async def _periodic_renewer() -> None: |
62 | | - await semaphore.reacquire() |
63 | | - if not started_event.is_set(): |
64 | | - started_event.set() |
65 | | - |
66 | | - # Start the renewal task |
67 | | - renewal_task = tg.create_task( |
68 | | - _periodic_renewer(), |
69 | | - name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}", |
70 | | - ) |
71 | | - await started_event.wait() |
72 | | - |
73 | | - yield |
74 | | - |
75 | | -        # NOTE: if we do not explicitly await the task inside the context manager
76 | | - # it sometimes hangs forever (Python issue?) |
77 | | - await cancel_wait_task(renewal_task, max_delay=None) |
78 | | - |
79 | | - except BaseExceptionGroup as eg: |
80 | | - semaphore_lost_errors, other_errors = eg.split(SemaphoreLostError) |
81 | | - # If there are any other errors, re-raise them |
82 | | - if other_errors: |
83 | | - assert len(other_errors.exceptions) == 1 # nosec |
84 | | - raise other_errors.exceptions[0] from eg |
85 | | - |
86 | | - assert semaphore_lost_errors is not None # nosec |
87 | | - assert len(semaphore_lost_errors.exceptions) == 1 # nosec |
88 | | - raise semaphore_lost_errors.exceptions[0] from eg |
89 | | - |
90 | | - finally: |
91 | | - # Always attempt to release the semaphore |
92 | | - try: |
93 | | - await semaphore.release() |
94 | | - except SemaphoreNotAcquiredError as exc: |
95 | | - _logger.exception( |
96 | | - **create_troubleshooting_log_kwargs( |
97 | | - f"Unexpected error while releasing semaphore '{semaphore_key}'", |
98 | | - error=exc, |
99 | | - error_context={ |
100 | | - "semaphore_key": semaphore_key, |
101 | | - "client_name": semaphore.redis_client.client_name, |
102 | | - "hostname": socket.gethostname(), |
103 | | - "execution_context": execution_context, |
104 | | - }, |
105 | | - tip="This might happen if the semaphore was lost before releasing it. " |
106 | | - "Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.", |
107 | | - ) |
108 | | - ) |
109 | | - finally: |
110 | | - lock_release_time = arrow.utcnow() |
111 | | - locking_time = lock_release_time - lock_acquisition_time |
112 | | - if locking_time > expected_lock_overall_time: |
113 | | - _logger.warning( |
114 | | - "Semaphore '%s' was held for %s which is longer than expected (%s). " |
115 | | - "TIP: consider reducing the locking time by optimizing the code inside " |
116 | | - "the critical section or increasing the default locking time", |
117 | | - semaphore_key, |
118 | | - locking_time, |
119 | | - expected_lock_overall_time, |
120 | | - ) |
121 | | - |
122 | | - |
123 | | -def _create_semaphore( |
124 | | - redis_client: RedisClientSDK | Callable[..., RedisClientSDK], |
125 | | - args: tuple[Any, ...], |
126 | | - *, |
127 | | - key: str | Callable[..., str], |
128 | | - capacity: int | Callable[..., int], |
129 | | - ttl: datetime.timedelta, |
130 | | - blocking: bool, |
131 | | - blocking_timeout: datetime.timedelta | None, |
132 | | - kwargs: dict[str, Any], |
133 | | -) -> tuple[DistributedSemaphore, str]: |
134 | | - """Create and configure a distributed semaphore from callable or static parameters.""" |
135 | | - semaphore_key = key(*args, **kwargs) if callable(key) else key |
136 | | - semaphore_capacity = capacity(*args, **kwargs) if callable(capacity) else capacity |
137 | | - client = redis_client(*args, **kwargs) if callable(redis_client) else redis_client |
138 | | - |
139 | | - assert isinstance(semaphore_key, str) # nosec |
140 | | - assert isinstance(semaphore_capacity, int) # nosec |
141 | | - assert isinstance(client, RedisClientSDK) # nosec |
142 | | - |
143 | | - semaphore = DistributedSemaphore( |
144 | | - redis_client=client, |
145 | | - key=semaphore_key, |
146 | | - capacity=semaphore_capacity, |
147 | | - ttl=ttl, |
148 | | - blocking=blocking, |
149 | | - blocking_timeout=blocking_timeout, |
150 | | - ) |
151 | | - |
152 | | - return semaphore, semaphore_key |
153 | | - |
154 | | - |
155 | 23 | def with_limited_concurrency( |
156 | 24 | redis_client: RedisClientSDK | Callable[..., RedisClientSDK], |
157 | 25 | *, |
@@ -202,23 +70,28 @@ def _decorator( |
202 | 70 | ) -> Callable[P, Coroutine[Any, Any, R]]: |
203 | 71 | @functools.wraps(coro) |
204 | 72 | async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: |
205 | | - semaphore, semaphore_key = _create_semaphore( |
206 | | - redis_client, |
207 | | - args, |
208 | | - key=key, |
209 | | - capacity=capacity, |
| 73 | + semaphore_key = key(*args, **kwargs) if callable(key) else key |
| 74 | + semaphore_capacity = ( |
| 75 | + capacity(*args, **kwargs) if callable(capacity) else capacity |
| 76 | + ) |
| 77 | + client = ( |
| 78 | + redis_client(*args, **kwargs) |
| 79 | + if callable(redis_client) |
| 80 | + else redis_client |
| 81 | + ) |
| 82 | + |
| 83 | + assert isinstance(semaphore_key, str) # nosec |
| 84 | + assert isinstance(semaphore_capacity, int) # nosec |
| 85 | + assert isinstance(client, RedisClientSDK) # nosec |
| 86 | + |
| 87 | + async with distributed_semaphore( |
| 88 | + redis_client=client, |
| 89 | + key=semaphore_key, |
| 90 | + capacity=semaphore_capacity, |
210 | 91 | ttl=ttl, |
211 | 92 | blocking=blocking, |
212 | 93 | blocking_timeout=blocking_timeout, |
213 | | - kwargs=kwargs, |
214 | | - ) |
215 | | - |
216 | | - async with _managed_semaphore_execution( |
217 | | - semaphore, |
218 | | - semaphore_key, |
219 | | - ttl, |
220 | | - f"coroutine_{coro.__name__}", |
221 | | - expected_lock_overall_time, |
| 94 | + expected_lock_overall_time=expected_lock_overall_time, |
222 | 95 | ): |
223 | 96 | return await coro(*args, **kwargs) |
224 | 97 |
|
@@ -279,24 +152,29 @@ def _decorator( |
279 | 152 | @functools.wraps(cm_func) |
280 | 153 | @asynccontextmanager |
281 | 154 | async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]: |
282 | | - semaphore, semaphore_key = _create_semaphore( |
283 | | - redis_client, |
284 | | - args, |
285 | | - key=key, |
286 | | - capacity=capacity, |
287 | | - ttl=ttl, |
288 | | - blocking=blocking, |
289 | | - blocking_timeout=blocking_timeout, |
290 | | - kwargs=kwargs, |
| 155 | + semaphore_key = key(*args, **kwargs) if callable(key) else key |
| 156 | + semaphore_capacity = ( |
| 157 | + capacity(*args, **kwargs) if callable(capacity) else capacity |
291 | 158 | ) |
| 159 | + client = ( |
| 160 | + redis_client(*args, **kwargs) |
| 161 | + if callable(redis_client) |
| 162 | + else redis_client |
| 163 | + ) |
| 164 | + |
| 165 | + assert isinstance(semaphore_key, str) # nosec |
| 166 | + assert isinstance(semaphore_capacity, int) # nosec |
| 167 | + assert isinstance(client, RedisClientSDK) # nosec |
292 | 168 |
|
293 | 169 | async with ( |
294 | | - _managed_semaphore_execution( |
295 | | - semaphore, |
296 | | - semaphore_key, |
297 | | - ttl, |
298 | | - f"context_manager_{cm_func.__name__}", |
299 | | - expected_lock_overall_time, |
| 170 | + distributed_semaphore( |
| 171 | + redis_client=client, |
| 172 | + key=semaphore_key, |
| 173 | + capacity=semaphore_capacity, |
| 174 | + ttl=ttl, |
| 175 | + blocking=blocking, |
| 176 | + blocking_timeout=blocking_timeout, |
| 177 | + expected_lock_overall_time=expected_lock_overall_time, |
300 | 178 | ), |
301 | 179 | cm_func(*args, **kwargs) as value, |
302 | 180 | ): |
|
0 commit comments