Skip to content

Commit b2e0b38

Browse files
authored
Merge branch 'master' into pr-osparc-connect-opentelemetry-to-missing-services
2 parents 8c0f106 + 0781e63 commit b2e0b38

File tree

11 files changed

+481
-283
lines changed

11 files changed

+481
-283
lines changed

packages/service-library/src/servicelib/redis_utils.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from collections.abc import Awaitable, Callable
55
from datetime import timedelta
6-
from typing import Any
6+
from typing import Any, ParamSpec, TypeVar
77

88
import arrow
99

@@ -12,10 +12,16 @@
1212

1313
_logger = logging.getLogger(__file__)
1414

15+
P = ParamSpec("P")
16+
R = TypeVar("R")
17+
1518

1619
def exclusive(
17-
redis: RedisClientSDK, *, lock_key: str, lock_value: bytes | str | None = None
18-
):
20+
redis: RedisClientSDK | Callable[..., RedisClientSDK],
21+
*,
22+
lock_key: str | Callable[..., str],
23+
lock_value: bytes | str | None = None,
24+
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
1925
"""
2026
Define a method to run exclusively across
2127
processes by leveraging a Redis Lock.
@@ -24,12 +30,30 @@ def exclusive(
2430
redis: the redis client SDK
2531
lock_key: a string as the name of the lock (good practice: app_name:lock_name)
2632
lock_value: some additional data that can be retrieved by another client
33+
34+
Raises:
35+
- ValueError if used incorrectly
36+
- CouldNotAcquireLockError if the lock could not be acquired
2737
"""
2838

29-
def decorator(func):
39+
if not lock_key:
40+
msg = "lock_key cannot be empty string!"
41+
raise ValueError(msg)
42+
43+
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
3044
@functools.wraps(func)
31-
async def wrapper(*args, **kwargs):
32-
async with redis.lock_context(lock_key=lock_key, lock_value=lock_value):
45+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
46+
redis_lock_key = (
47+
lock_key(*args, **kwargs) if callable(lock_key) else lock_key
48+
)
49+
assert isinstance(redis_lock_key, str) # nosec
50+
51+
redis_client = redis(*args, **kwargs) if callable(redis) else redis
52+
assert isinstance(redis_client, RedisClientSDK) # nosec
53+
54+
async with redis_client.lock_context(
55+
lock_key=redis_lock_key, lock_value=lock_value
56+
):
3357
return await func(*args, **kwargs)
3458

3559
return wrapper

packages/service-library/tests/test_redis_utils.py

Lines changed: 104 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from contextlib import AbstractAsyncContextManager
66
from datetime import timedelta
77
from itertools import chain
8+
from typing import Awaitable
89
from unittest.mock import Mock
910

1011
import arrow
@@ -32,39 +33,117 @@ async def _is_locked(redis_client_sdk: RedisClientSDK, lock_name: str) -> bool:
3233

3334
@pytest.fixture
def lock_name(faker: Faker) -> str:
    # random string used as the Redis lock name for each test
    return faker.pystr()
37+
38+
39+
def _exclusive_sleeping_task(
40+
redis_client_sdk: RedisClientSDK | Callable[..., RedisClientSDK],
41+
lock_name: str | Callable[..., str],
42+
sleep_duration: float,
43+
) -> Callable[..., Awaitable[float]]:
44+
@exclusive(redis_client_sdk, lock_key=lock_name)
45+
async def _() -> float:
46+
resolved_client = (
47+
redis_client_sdk() if callable(redis_client_sdk) else redis_client_sdk
48+
)
49+
resolved_lock_name = lock_name() if callable(lock_name) else lock_name
50+
assert await _is_locked(resolved_client, resolved_lock_name)
51+
await asyncio.sleep(sleep_duration)
52+
assert await _is_locked(resolved_client, resolved_lock_name)
53+
return sleep_duration
54+
55+
return _
56+
57+
58+
@pytest.fixture
def sleep_duration(faker: Faker) -> float:
    # short random duration: long enough to observe lock overlap, short
    # enough to keep the tests fast
    return faker.pyfloat(positive=True, min_value=0.2, max_value=0.8)
3661

3762

38-
async def _contained_client(
63+
async def test_exclusive_decorator(
    get_redis_client_sdk: Callable[
        [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
    ],
    lock_name: str,
    sleep_duration: float,
):
    """The decorated task runs under the lock and can be re-run sequentially."""
    async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:
        for _ in range(3):
            assert (
                await _exclusive_sleeping_task(
                    redis_client, lock_name, sleep_duration
                )()
                == sleep_duration
            )
5779

5880

59-
@pytest.mark.parametrize("task_duration", [0.1, 1, 2])
60-
async def test_exclusive_sequentially(
81+
async def test_exclusive_decorator_with_key_builder(
    get_redis_client_sdk: Callable[
        [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
    ],
    lock_name: str,
    sleep_duration: float,
):
    """``exclusive`` also accepts a callable building the lock key from the
    decorated function's call arguments."""

    def _get_lock_name(*args, **kwargs) -> str:
        assert args is not None
        assert kwargs is not None
        return lock_name

    async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:
        for _ in range(3):
            assert (
                await _exclusive_sleeping_task(
                    redis_client, _get_lock_name, sleep_duration
                )()
                == sleep_duration
            )
101+
102+
103+
async def test_exclusive_decorator_with_client_builder(
    get_redis_client_sdk: Callable[
        [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
    ],
    lock_name: str,
    sleep_duration: float,
):
    """``exclusive`` also accepts a callable resolving the Redis client from
    the decorated function's call arguments."""
    async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:

        def _get_redis_client_builder(*args, **kwargs) -> RedisClientSDK:
            assert args is not None
            assert kwargs is not None
            return redis_client

        for _ in range(3):
            assert (
                await _exclusive_sleeping_task(
                    _get_redis_client_builder, lock_name, sleep_duration
                )()
                == sleep_duration
            )
124+
125+
126+
async def _acquire_lock_and_exclusively_sleep(
    get_redis_client_sdk: Callable[
        [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
    ],
    lock_name: str | Callable[..., str],
    sleep_duration: float,
) -> None:
    """Acquire the lock via the decorator, sleep while holding it, and verify
    the lock is released afterwards.

    Raises CouldNotAcquireLockError (from the decorator) if another holder
    already owns the lock — used by the parallel test to count losers.
    """
    async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk:
        redis_lock_name = lock_name() if callable(lock_name) else lock_name
        assert not await _is_locked(redis_client_sdk, redis_lock_name)

        @exclusive(redis_client_sdk, lock_key=lock_name)
        async def _() -> float:
            assert await _is_locked(redis_client_sdk, redis_lock_name)
            await asyncio.sleep(sleep_duration)
            assert await _is_locked(redis_client_sdk, redis_lock_name)
            return sleep_duration

        assert await _() == sleep_duration

        assert not await _is_locked(redis_client_sdk, redis_lock_name)
68147

69148

70149
async def test_exclusive_parallel_lock_is_released_and_reacquired(
@@ -76,17 +155,19 @@ async def test_exclusive_parallel_lock_is_released_and_reacquired(
76155
parallel_tasks = 10
77156
results = await logged_gather(
78157
*[
79-
_contained_client(get_redis_client_sdk, lock_name, task_duration=0.1)
158+
_acquire_lock_and_exclusively_sleep(
159+
get_redis_client_sdk, lock_name, sleep_duration=0.1
160+
)
80161
for _ in range(parallel_tasks)
81162
],
82-
reraise=False
163+
reraise=False,
83164
)
84165
assert results.count(None) == 1
85166
assert [isinstance(x, CouldNotAcquireLockError) for x in results].count(
86167
True
87168
) == parallel_tasks - 1
88169

89-
# check lock is being released
170+
# check lock is released
90171
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk:
91172
assert not await _is_locked(redis_client_sdk, lock_name)
92173

@@ -168,7 +249,7 @@ async def test_start_exclusive_periodic_task_parallel_all_finish(
168249
_assert_task_completes_once(get_redis_client_sdk, stop_after=60)
169250
for _ in range(parallel_tasks)
170251
],
171-
reraise=False
252+
reraise=False,
172253
)
173254

174255
# check no error occurred

services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Annotated
2+
13
from fastapi import Depends, FastAPI, Request
24

35
from ...core.settings import ComputationalBackendSettings
@@ -11,7 +13,7 @@ def get_scheduler(request: Request) -> BaseCompScheduler:
1113

1214

1315
def get_scheduler_settings(
14-
app: FastAPI = Depends(get_app),
16+
app: Annotated[FastAPI, Depends(get_app)]
1517
) -> ComputationalBackendSettings:
1618
settings: ComputationalBackendSettings = (
1719
app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,38 @@
1+
import logging
2+
from collections.abc import Callable, Coroutine
3+
from typing import Any, cast
4+
15
from fastapi import FastAPI
6+
from servicelib.logging_utils import log_context
27

8+
from . import _scheduler_factory
39
from ._base_scheduler import BaseCompScheduler
4-
from ._task import on_app_shutdown, on_app_startup
10+
11+
_logger = logging.getLogger(__name__)
12+
13+
14+
def on_app_startup(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]:
    """Return the FastAPI startup handler that creates the computational
    scheduler from the DB state, stores it on ``app.state.scheduler`` and
    resumes scheduling."""

    async def start_scheduler() -> None:
        with log_context(
            _logger, level=logging.INFO, msg="starting computational scheduler"
        ):
            app.state.scheduler = scheduler = await _scheduler_factory.create_from_db(
                app
            )
            scheduler.recover_scheduling()

    return start_scheduler
25+
26+
27+
def on_app_shutdown(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]:
    """Return the FastAPI shutdown handler that stops the scheduler stored on
    the application state."""

    async def stop_scheduler() -> None:
        await get_scheduler(app).shutdown()

    return stop_scheduler
32+
33+
34+
def get_scheduler(app: FastAPI) -> BaseCompScheduler:
    """Retrieve the scheduler instance stored on ``app.state`` at startup."""
    return cast(BaseCompScheduler, app.state.scheduler)
536

637

738
def setup(app: FastAPI):
@@ -12,4 +43,5 @@ def setup(app: FastAPI):
1243
__all__: tuple[str, ...] = (
1344
"setup",
1445
"BaseCompScheduler",
46+
"get_scheduler",
1547
)

0 commit comments

Comments
 (0)