Skip to content

Commit 761ecb8

Browse files
committed
use underlying context manager
1 parent b46f1c7 commit 761ecb8

File tree

1 file changed

+40
-162
lines changed

1 file changed

+40
-162
lines changed

packages/service-library/src/servicelib/redis/_semaphore_decorator.py

Lines changed: 40 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,17 @@
1-
import asyncio
21
import datetime
32
import functools
43
import logging
5-
import socket
64
from collections.abc import AsyncIterator, Callable, Coroutine
75
from contextlib import AbstractAsyncContextManager, asynccontextmanager
86
from typing import Any, ParamSpec, TypeVar
97

10-
import arrow
11-
from common_library.async_tools import cancel_wait_task
12-
from common_library.logging.logging_errors import create_troubleshooting_log_kwargs
13-
14-
from ..background_task import periodic
158
from ._client import RedisClientSDK
169
from ._constants import (
1710
DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
1811
DEFAULT_SEMAPHORE_TTL,
1912
DEFAULT_SOCKET_TIMEOUT,
2013
)
21-
from ._errors import (
22-
SemaphoreAcquisitionError,
23-
SemaphoreLostError,
24-
SemaphoreNotAcquiredError,
25-
)
26-
from ._semaphore import DistributedSemaphore
14+
from ._semaphore import distributed_semaphore
2715

2816
_logger = logging.getLogger(__name__)
2917

@@ -32,126 +20,6 @@
3220
R = TypeVar("R")
3321

3422

35-
@asynccontextmanager
36-
async def _managed_semaphore_execution(
37-
semaphore: DistributedSemaphore,
38-
semaphore_key: str,
39-
ttl: datetime.timedelta,
40-
execution_context: str,
41-
expected_lock_overall_time: datetime.timedelta,
42-
) -> AsyncIterator:
43-
"""Common semaphore management logic with auto-renewal."""
44-
# Acquire the semaphore first
45-
if not await semaphore.acquire():
46-
raise SemaphoreAcquisitionError(
47-
name=semaphore_key, instance_id=semaphore.instance_id
48-
)
49-
50-
lock_acquisition_time = arrow.utcnow()
51-
try:
52-
# NOTE: Use TaskGroup for proper exception propagation, this ensures that in case of error the context manager will be properly exited
53-
# and the semaphore released.
54-
# If we use create_task() directly, exceptions in the task are not propagated to the parent task
55-
# and the context manager may never exit, leading to semaphore leaks.
56-
async with asyncio.TaskGroup() as tg:
57-
started_event = asyncio.Event()
58-
59-
# Create auto-renewal task
60-
@periodic(interval=ttl / 3, raise_on_error=True)
61-
async def _periodic_renewer() -> None:
62-
await semaphore.reacquire()
63-
if not started_event.is_set():
64-
started_event.set()
65-
66-
# Start the renewal task
67-
renewal_task = tg.create_task(
68-
_periodic_renewer(),
69-
name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
70-
)
71-
await started_event.wait()
72-
73-
yield
74-
75-
# NOTE: if we do not explicitely await the task inside the context manager
76-
# it sometimes hangs forever (Python issue?)
77-
await cancel_wait_task(renewal_task, max_delay=None)
78-
79-
except BaseExceptionGroup as eg:
80-
semaphore_lost_errors, other_errors = eg.split(SemaphoreLostError)
81-
# If there are any other errors, re-raise them
82-
if other_errors:
83-
assert len(other_errors.exceptions) == 1 # nosec
84-
raise other_errors.exceptions[0] from eg
85-
86-
assert semaphore_lost_errors is not None # nosec
87-
assert len(semaphore_lost_errors.exceptions) == 1 # nosec
88-
raise semaphore_lost_errors.exceptions[0] from eg
89-
90-
finally:
91-
# Always attempt to release the semaphore
92-
try:
93-
await semaphore.release()
94-
except SemaphoreNotAcquiredError as exc:
95-
_logger.exception(
96-
**create_troubleshooting_log_kwargs(
97-
f"Unexpected error while releasing semaphore '{semaphore_key}'",
98-
error=exc,
99-
error_context={
100-
"semaphore_key": semaphore_key,
101-
"client_name": semaphore.redis_client.client_name,
102-
"hostname": socket.gethostname(),
103-
"execution_context": execution_context,
104-
},
105-
tip="This might happen if the semaphore was lost before releasing it. "
106-
"Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.",
107-
)
108-
)
109-
finally:
110-
lock_release_time = arrow.utcnow()
111-
locking_time = lock_release_time - lock_acquisition_time
112-
if locking_time > expected_lock_overall_time:
113-
_logger.warning(
114-
"Semaphore '%s' was held for %s which is longer than expected (%s). "
115-
"TIP: consider reducing the locking time by optimizing the code inside "
116-
"the critical section or increasing the default locking time",
117-
semaphore_key,
118-
locking_time,
119-
expected_lock_overall_time,
120-
)
121-
122-
123-
def _create_semaphore(
124-
redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
125-
args: tuple[Any, ...],
126-
*,
127-
key: str | Callable[..., str],
128-
capacity: int | Callable[..., int],
129-
ttl: datetime.timedelta,
130-
blocking: bool,
131-
blocking_timeout: datetime.timedelta | None,
132-
kwargs: dict[str, Any],
133-
) -> tuple[DistributedSemaphore, str]:
134-
"""Create and configure a distributed semaphore from callable or static parameters."""
135-
semaphore_key = key(*args, **kwargs) if callable(key) else key
136-
semaphore_capacity = capacity(*args, **kwargs) if callable(capacity) else capacity
137-
client = redis_client(*args, **kwargs) if callable(redis_client) else redis_client
138-
139-
assert isinstance(semaphore_key, str) # nosec
140-
assert isinstance(semaphore_capacity, int) # nosec
141-
assert isinstance(client, RedisClientSDK) # nosec
142-
143-
semaphore = DistributedSemaphore(
144-
redis_client=client,
145-
key=semaphore_key,
146-
capacity=semaphore_capacity,
147-
ttl=ttl,
148-
blocking=blocking,
149-
blocking_timeout=blocking_timeout,
150-
)
151-
152-
return semaphore, semaphore_key
153-
154-
15523
def with_limited_concurrency(
15624
redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
15725
*,
@@ -202,23 +70,28 @@ def _decorator(
20270
) -> Callable[P, Coroutine[Any, Any, R]]:
20371
@functools.wraps(coro)
20472
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
205-
semaphore, semaphore_key = _create_semaphore(
206-
redis_client,
207-
args,
208-
key=key,
209-
capacity=capacity,
73+
semaphore_key = key(*args, **kwargs) if callable(key) else key
74+
semaphore_capacity = (
75+
capacity(*args, **kwargs) if callable(capacity) else capacity
76+
)
77+
client = (
78+
redis_client(*args, **kwargs)
79+
if callable(redis_client)
80+
else redis_client
81+
)
82+
83+
assert isinstance(semaphore_key, str) # nosec
84+
assert isinstance(semaphore_capacity, int) # nosec
85+
assert isinstance(client, RedisClientSDK) # nosec
86+
87+
async with distributed_semaphore(
88+
redis_client=client,
89+
key=semaphore_key,
90+
capacity=semaphore_capacity,
21091
ttl=ttl,
21192
blocking=blocking,
21293
blocking_timeout=blocking_timeout,
213-
kwargs=kwargs,
214-
)
215-
216-
async with _managed_semaphore_execution(
217-
semaphore,
218-
semaphore_key,
219-
ttl,
220-
f"coroutine_{coro.__name__}",
221-
expected_lock_overall_time,
94+
expected_lock_overall_time=expected_lock_overall_time,
22295
):
22396
return await coro(*args, **kwargs)
22497

@@ -279,24 +152,29 @@ def _decorator(
279152
@functools.wraps(cm_func)
280153
@asynccontextmanager
281154
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
282-
semaphore, semaphore_key = _create_semaphore(
283-
redis_client,
284-
args,
285-
key=key,
286-
capacity=capacity,
287-
ttl=ttl,
288-
blocking=blocking,
289-
blocking_timeout=blocking_timeout,
290-
kwargs=kwargs,
155+
semaphore_key = key(*args, **kwargs) if callable(key) else key
156+
semaphore_capacity = (
157+
capacity(*args, **kwargs) if callable(capacity) else capacity
291158
)
159+
client = (
160+
redis_client(*args, **kwargs)
161+
if callable(redis_client)
162+
else redis_client
163+
)
164+
165+
assert isinstance(semaphore_key, str) # nosec
166+
assert isinstance(semaphore_capacity, int) # nosec
167+
assert isinstance(client, RedisClientSDK) # nosec
292168

293169
async with (
294-
_managed_semaphore_execution(
295-
semaphore,
296-
semaphore_key,
297-
ttl,
298-
f"context_manager_{cm_func.__name__}",
299-
expected_lock_overall_time,
170+
distributed_semaphore(
171+
redis_client=client,
172+
key=semaphore_key,
173+
capacity=semaphore_capacity,
174+
ttl=ttl,
175+
blocking=blocking,
176+
blocking_timeout=blocking_timeout,
177+
expected_lock_overall_time=expected_lock_overall_time,
300178
),
301179
cm_func(*args, **kwargs) as value,
302180
):

0 commit comments

Comments
 (0)