Commit d47c01f ("ongoing")
1 parent: 3c2c76b
File tree: 3 files changed (+197, -104 lines)


packages/service-library/src/servicelib/redis/_semaphore_decorator.py

Lines changed: 173 additions & 88 deletions
@@ -3,7 +3,8 @@
 import functools
 import logging
 import socket
-from collections.abc import Callable, Coroutine
+from collections.abc import AsyncIterator, Callable, Coroutine
+from contextlib import asynccontextmanager
 from typing import Any, ParamSpec, TypeVar
 
 from common_library.async_tools import cancel_wait_task
@@ -27,6 +28,102 @@
 R = TypeVar("R")
 
 
+@asynccontextmanager
+async def _managed_semaphore_execution(
+    semaphore: DistributedSemaphore,
+    semaphore_key: str,
+    ttl: datetime.timedelta,
+    execution_context: str,
+) -> AsyncIterator:
+    """Common semaphore management logic with auto-renewal."""
+    # Acquire the semaphore first
+    if not await semaphore.acquire():
+        raise SemaphoreAcquisitionError(name=semaphore_key, capacity=semaphore.capacity)
+
+    try:
+        # Use TaskGroup for proper exception propagation
+        async with asyncio.TaskGroup() as tg:
+            started_event = asyncio.Event()
+
+            # Create auto-renewal task
+            @periodic(interval=ttl / 3, raise_on_error=True)
+            async def _periodic_renewer() -> None:
+                await semaphore.reacquire()
+                if not started_event.is_set():
+                    started_event.set()
+
+            # Start the renewal task
+            renewal_task = tg.create_task(
+                _periodic_renewer(),
+                name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
+            )
+
+            # Wait for first renewal to complete (ensures task is running)
+            await started_event.wait()
+
+            # Yield control back to caller
+            yield
+
+            # Cancel renewal task when execution is done
+            await cancel_wait_task(renewal_task, max_delay=None)
+
+    except BaseExceptionGroup as eg:
+        # Re-raise the first exception in the group
+        raise eg.exceptions[0] from eg
+
+    finally:
+        # Always attempt to release the semaphore
+        try:
+            await semaphore.release()
+        except Exception as exc:
+            _logger.exception(
+                **create_troubleshootting_log_kwargs(
+                    "Unexpected error while releasing semaphore",
+                    error=exc,
+                    error_context={
+                        "semaphore_key": semaphore_key,
+                        "client_name": semaphore.redis_client.client_name,
+                        "hostname": socket.gethostname(),
+                        "execution_context": execution_context,
+                    },
+                    tip="This might happen if the semaphore was lost before releasing it. "
+                    "Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.",
+                )
+            )
+
+
+def _create_semaphore(
+    redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
+    args: tuple[Any, ...],
+    *,
+    key: str | Callable[..., str],
+    capacity: int | Callable[..., int],
+    ttl: datetime.timedelta,
+    blocking: bool,
+    blocking_timeout: datetime.timedelta | None,
+    kwargs: dict[str, Any],
+) -> tuple[DistributedSemaphore, str]:
+    """Create and configure a distributed semaphore from callable or static parameters."""
+    semaphore_key = key(*args, **kwargs) if callable(key) else key
+    semaphore_capacity = capacity(*args, **kwargs) if callable(capacity) else capacity
+    client = redis_client(*args, **kwargs) if callable(redis_client) else redis_client
+
+    assert isinstance(semaphore_key, str)  # nosec
+    assert isinstance(semaphore_capacity, int)  # nosec
+    assert isinstance(client, RedisClientSDK)  # nosec
+
+    semaphore = DistributedSemaphore(
+        redis_client=client,
+        key=semaphore_key,
+        capacity=semaphore_capacity,
+        ttl=ttl,
+        blocking=blocking,
+        blocking_timeout=blocking_timeout,
+    )
+
+    return semaphore, semaphore_key
+
+
 def with_limited_concurrency(
     redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
     *,
@@ -75,101 +172,89 @@ def _decorator(
     ) -> Callable[P, Coroutine[Any, Any, R]]:
         @functools.wraps(coro)
         async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
-            # Resolve callable parameters
-            semaphore_key = key(*args, **kwargs) if callable(key) else key
-            semaphore_capacity = (
-                capacity(*args, **kwargs) if callable(capacity) else capacity
-            )
-            client = (
-                redis_client(*args, **kwargs)
-                if callable(redis_client)
-                else redis_client
+            semaphore, semaphore_key = _create_semaphore(
+                redis_client,
+                args,
+                key=key,
+                capacity=capacity,
+                ttl=ttl,
+                blocking=blocking,
+                blocking_timeout=blocking_timeout,
+                kwargs=kwargs,
             )
 
-            assert isinstance(semaphore_key, str)  # nosec
-            assert isinstance(semaphore_capacity, int)  # nosec
-            assert isinstance(client, RedisClientSDK)  # nosec
+            async with _managed_semaphore_execution(
+                semaphore, semaphore_key, ttl, f"coroutine_{coro.__name__}"
+            ):
+                return await coro(*args, **kwargs)
+
+        return _wrapper
+
+    return _decorator
 
-            # Create the semaphore (without auto-renewal)
-            semaphore = DistributedSemaphore(
-                redis_client=client,
-                key=semaphore_key,
-                capacity=semaphore_capacity,
+
+def with_limited_concurrency_cm(
+    redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
+    *,
+    key: str | Callable[..., str],
+    capacity: int | Callable[..., int],
+    ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
+    blocking: bool = True,
+    blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT,
+) -> Callable[[Callable[P, AsyncIterator[R]]], Callable[P, AsyncIterator[R]]]:
+    """
+    Decorator to limit concurrent execution of async context managers using a distributed semaphore.
+
+    This decorator ensures that only a specified number of instances of the decorated
+    async context manager can be active concurrently across multiple processes/instances
+    using Redis as the coordination backend.
+
+    Args:
+        redis_client: Redis client for coordination (can be callable)
+        key: Unique identifier for the semaphore (can be callable)
+        capacity: Maximum number of concurrent executions (can be callable)
+        ttl: Time-to-live for semaphore entries (default: 5 minutes)
+        blocking: Whether to block when semaphore is full (default: True)
+        blocking_timeout: Maximum time to wait when blocking (default: socket timeout)
+
+    Example:
+        @asynccontextmanager
+        @with_limited_concurrency_cm(
+            redis_client,
+            key="cluster:my-cluster",
+            capacity=5,
+            blocking=True,
+            blocking_timeout=None
+        )
+        async def get_cluster_client():
+            async with pool.acquire() as client:
+                yield client
+
+    Raises:
+        SemaphoreAcquisitionError: If semaphore cannot be acquired and blocking=True
+    """
+
+    def _decorator(
+        cm_func: Callable[P, AsyncIterator[R]],
+    ) -> Callable[P, AsyncIterator[R]]:
+        @functools.wraps(cm_func)
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
+            semaphore, semaphore_key = _create_semaphore(
+                redis_client,
+                args,
+                key=key,
+                capacity=capacity,
                 ttl=ttl,
                 blocking=blocking,
                 blocking_timeout=blocking_timeout,
+                kwargs=kwargs,
             )
 
-            # Acquire the semaphore first
-            if not await semaphore.acquire():
-                raise SemaphoreAcquisitionError(
-                    name=semaphore_key, capacity=semaphore_capacity
-                )
-
-            try:
-                # Use TaskGroup for proper exception propagation (similar to exclusive decorator)
-                async with asyncio.TaskGroup() as tg:
-                    started_event = asyncio.Event()
-
-                    # Create auto-renewal task
-                    @periodic(interval=ttl / 3, raise_on_error=True)
-                    async def _periodic_renewer() -> None:
-                        await semaphore.reacquire()
-                        if not started_event.is_set():
-                            started_event.set()
-
-                    # Start the renewal task
-                    renewal_task = tg.create_task(
-                        _periodic_renewer(),
-                        name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
-                    )
-
-                    # Wait for first renewal to complete (ensures task is running)
-                    await started_event.wait()
-
-                    # Run the user work
-                    work_task = tg.create_task(
-                        coro(*args, **kwargs),
-                        name=f"semaphore/work_{coro.__module__}.{coro.__name__}",
-                    )
-                    result = await work_task
-
-                    # Cancel renewal task (work is done)
-                    # NOTE: if we do not explicitely await the task inside the context manager
-                    # it sometimes hangs forever (Python issue?)
-                    await cancel_wait_task(renewal_task, max_delay=None)
-
-                    return result
-
-            except BaseExceptionGroup as eg:
-                # Handle exceptions similar to exclusive decorator
-                # If renewal fails, the TaskGroup will propagate the exception
-                # and cancel the work task automatically
-
-                # Re-raise the first exception in the group
-                raise eg.exceptions[0] from eg
-
-            finally:
-                # Always attempt to release the semaphore, regardless of Python state
-                # The Redis-side state is the source of truth, not the Python _acquired flag
-                try:
-                    await semaphore.release()
-                except Exception as exc:
-                    # Log any other release errors but don't re-raise
-                    _logger.exception(
-                        **create_troubleshootting_log_kwargs(
-                            "Unexpected error while releasing semaphore",
-                            error=exc,
-                            error_context={
-                                "semaphore_key": semaphore_key,
-                                "client_name": client.client_name,
-                                "hostname": socket.gethostname(),
-                                "coroutine": coro.__name__,
-                            },
-                            tip="This might happen if the semaphore was lost before releasing it. "
-                            "Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.",
-                        )
-                    )
+            async with _managed_semaphore_execution(
+                semaphore, semaphore_key, ttl, f"context_manager_{cm_func.__name__}"
+            ):
+                async for value in cm_func(*args, **kwargs):
+                    yield value
 
         return _wrapper
 
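Taken together, the refactor splits the old inline wrapper logic into `_create_semaphore` (parameter resolution) and `_managed_semaphore_execution` (acquire, auto-renew every `ttl / 3`, release), shared by both decorators. A minimal usage sketch of the two public decorators follows; the client setup and the decorated functions (`heavy_job`, `shared_resource`) are hypothetical and not part of this commit:

    # Hypothetical usage sketch of the decorators defined in this file.
    from contextlib import asynccontextmanager

    from servicelib.redis._client import RedisClientSDK
    from servicelib.redis._semaphore_decorator import (
        with_limited_concurrency,
        with_limited_concurrency_cm,
    )

    redis_client: RedisClientSDK = ...  # obtained from the application, setup omitted

    # At most 3 concurrent executions across all processes sharing this Redis
    @with_limited_concurrency(redis_client, key="jobs:heavy", capacity=3)
    async def heavy_job(job_id: str) -> None:
        ...

    # At most 5 concurrently active context managers cluster-wide; note that
    # @asynccontextmanager goes on top, as in the docstring example above
    @asynccontextmanager
    @with_limited_concurrency_cm(redis_client, key="pool:shared", capacity=5)
    async def shared_resource():
        yield "resource"

While the decorated body runs, the helper keeps the semaphore alive by calling `semaphore.reacquire()` every `ttl / 3`; if renewal fails, the `TaskGroup` propagates the error and cancels the work, and the `finally` block still attempts to release the semaphore.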

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py

Lines changed: 24 additions & 0 deletions
@@ -24,6 +24,10 @@
 from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
 from servicelib.logging_errors import create_troubleshootting_log_kwargs
 from servicelib.logging_utils import log_catch, log_context
+from servicelib.redis._client import RedisClientSDK
+from servicelib.redis._semaphore_decorator import (
+    with_limited_concurrency_cm,
+)
 from servicelib.utils import limited_as_completed, limited_gather
 
 from ...core.errors import (
@@ -53,11 +57,13 @@
 from ..db.repositories.comp_tasks import CompTasksRepository
 from ._constants import (
     MAX_CONCURRENT_PIPELINE_SCHEDULING,
+    MODULE_NAME_WORKER,
 )
 from ._models import TaskStateTracker
 from ._scheduler_base import BaseCompScheduler
 from ._utils import (
     WAITING_FOR_START_STATES,
+    get_redis_lock_key,
 )
 
 _logger = logging.getLogger(__name__)
@@ -67,7 +73,25 @@
 _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: Final[str] = "check_time"
 
 
+def _get_redis_client_from_scheduler(scheduler: "DaskScheduler") -> RedisClientSDK:
+    return scheduler.redis_client
+
+
+def _unique_key_builder(_app, user_id: UserID, run_metadata: RunMetadataDict) -> str:
+    return f"user_id_{user_id}-wallet_id_{run_metadata.get('wallet_id')}"
+
+
 @asynccontextmanager
+@with_limited_concurrency_cm(
+    _get_redis_client_from_scheduler,
+    key=get_redis_lock_key(
+        MODULE_NAME_WORKER,
+        unique_lock_key_builder=_unique_key_builder,
+    ),
+    capacity=1,
+    blocking=True,
+    blocking_timeout=None,
+)
 async def _cluster_dask_client(
     user_id: UserID,
     scheduler: "DaskScheduler",

services/director-v2/src/simcore_service_director_v2/modules/dask_clients_pool.py

Lines changed: 0 additions & 16 deletions
@@ -10,12 +10,6 @@
 from models_library.clusters import BaseCluster, ClusterTypeInModel
 from pydantic import AnyUrl
 from servicelib.logging_utils import log_context
-from servicelib.redis._semaphore_decorator import with_limited_concurrency
-from settings_library.redis import RedisDatabase
-from simcore_service_director_v2.modules.comp_scheduler._utils import (
-    get_redis_lock_key,
-)
-from simcore_service_director_v2.modules.redis import get_redis_client_manager
 
 from ..core.errors import (
     ComputationalBackendNotConnectedError,
@@ -120,16 +114,6 @@ async def acquire(
         `release_client_ref` to release the client reference when done.
         """
 
-        @with_limited_concurrency(
-            get_redis_client_manager(self.app).client(RedisDatabase.LOCKS),
-            key=get_redis_lock_key(
-                "dask-clients-pool",
-                unique_lock_key_builder=lambda: f"{cluster.name}-{cluster.endpoint}",
-            ),
-            capacity=20,
-            blocking=True,
-            blocking_timeout=None,
-        )
         async def _concurently_safe_acquire_client() -> DaskClient:
            async with self._client_acquisition_lock:
                 with log_context(
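With the distributed semaphore removed here, client acquisition in the pool is serialized only by the in-process lock visible in the surrounding context lines. A trimmed-down sketch of that remaining pattern (the class and method bodies are illustrative; only `_client_acquisition_lock` comes from the diff):

    import asyncio

    class DaskClientsPoolSketch:
        # The asyncio.Lock serializes client creation within one event loop,
        # but does not coordinate across processes.
        def __init__(self) -> None:
            self._client_acquisition_lock = asyncio.Lock()

        async def _acquire_client(self):
            async with self._client_acquisition_lock:
                ...  # create or reuse the DaskClient here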
