Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions packages/service-library/src/servicelib/redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from collections.abc import Awaitable, Callable
from datetime import timedelta
from typing import Any
from typing import Any, ParamSpec, TypeVar

import arrow

Expand All @@ -12,10 +12,16 @@

_logger = logging.getLogger(__file__)

P = ParamSpec("P")
R = TypeVar("R")


def exclusive(
redis: RedisClientSDK, *, lock_key: str, lock_value: bytes | str | None = None
):
redis: RedisClientSDK | Callable[..., RedisClientSDK],
*,
lock_key: str | Callable[..., str],
lock_value: bytes | str | None = None,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
"""
Define a method to run exclusively across
processes by leveraging a Redis Lock.
Expand All @@ -24,12 +30,30 @@ def exclusive(
redis: the redis client SDK
lock_key: a string as the name of the lock (good practice: app_name:lock_name)
lock_value: some additional data that can be retrieved by another client

Raises:
- ValueError if used incorrectly
- CouldNotAcquireLockError if the lock could not be acquired
"""

def decorator(func):
if not lock_key:
msg = "lock_key cannot be empty string!"
raise ValueError(msg)

def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with redis.lock_context(lock_key=lock_key, lock_value=lock_value):
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
redis_lock_key = (
lock_key(*args, **kwargs) if callable(lock_key) else lock_key
)
assert isinstance(redis_lock_key, str) # nosec

redis_client = redis(*args, **kwargs) if callable(redis) else redis
assert isinstance(redis_client, RedisClientSDK) # nosec

async with redis_client.lock_context(
lock_key=redis_lock_key, lock_value=lock_value
):
return await func(*args, **kwargs)

return wrapper
Expand Down
127 changes: 104 additions & 23 deletions packages/service-library/tests/test_redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from itertools import chain
from typing import Awaitable
from unittest.mock import Mock

import arrow
Expand Down Expand Up @@ -32,39 +33,117 @@ async def _is_locked(redis_client_sdk: RedisClientSDK, lock_name: str) -> bool:

@pytest.fixture
def lock_name(faker: Faker) -> str:
return faker.uuid4() # type: ignore
return faker.pystr()


def _exclusive_sleeping_task(
redis_client_sdk: RedisClientSDK | Callable[..., RedisClientSDK],
lock_name: str | Callable[..., str],
sleep_duration: float,
) -> Callable[..., Awaitable[float]]:
@exclusive(redis_client_sdk, lock_key=lock_name)
async def _() -> float:
resolved_client = (
redis_client_sdk() if callable(redis_client_sdk) else redis_client_sdk
)
resolved_lock_name = lock_name() if callable(lock_name) else lock_name
assert await _is_locked(resolved_client, resolved_lock_name)
await asyncio.sleep(sleep_duration)
assert await _is_locked(resolved_client, resolved_lock_name)
return sleep_duration

return _


@pytest.fixture
def sleep_duration(faker: Faker) -> float:
return faker.pyfloat(positive=True, min_value=0.2, max_value=0.8)


async def _contained_client(
async def test_exclusive_decorator(
get_redis_client_sdk: Callable[
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
],
lock_name: str,
task_duration: float,
) -> None:
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk:
assert not await _is_locked(redis_client_sdk, lock_name)

@exclusive(redis_client_sdk, lock_key=lock_name)
async def _some_task() -> None:
assert await _is_locked(redis_client_sdk, lock_name)
await asyncio.sleep(task_duration)
assert await _is_locked(redis_client_sdk, lock_name)

await _some_task()
sleep_duration: float,
):

assert not await _is_locked(redis_client_sdk, lock_name)
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:
for _ in range(3):
assert (
await _exclusive_sleeping_task(
redis_client, lock_name, sleep_duration
)()
== sleep_duration
)


@pytest.mark.parametrize("task_duration", [0.1, 1, 2])
async def test_exclusive_sequentially(
async def test_exclusive_decorator_with_key_builder(
get_redis_client_sdk: Callable[
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
],
lock_name: str,
task_duration: float,
sleep_duration: float,
):
await _contained_client(get_redis_client_sdk, lock_name, task_duration)
def _get_lock_name(*args, **kwargs) -> str:
assert args is not None
assert kwargs is not None
return lock_name

async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:
for _ in range(3):
assert (
await _exclusive_sleeping_task(
redis_client, _get_lock_name, sleep_duration
)()
== sleep_duration
)


async def test_exclusive_decorator_with_client_builder(
get_redis_client_sdk: Callable[
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
],
lock_name: str,
sleep_duration: float,
):
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:

def _get_redis_client_builder(*args, **kwargs) -> RedisClientSDK:
assert args is not None
assert kwargs is not None
return redis_client

for _ in range(3):
assert (
await _exclusive_sleeping_task(
_get_redis_client_builder, lock_name, sleep_duration
)()
== sleep_duration
)


async def _acquire_lock_and_exclusively_sleep(
get_redis_client_sdk: Callable[
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
],
lock_name: str | Callable[..., str],
sleep_duration: float,
) -> None:
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk:
redis_lock_name = lock_name() if callable(lock_name) else lock_name
assert not await _is_locked(redis_client_sdk, redis_lock_name)

@exclusive(redis_client_sdk, lock_key=lock_name)
async def _() -> float:
assert await _is_locked(redis_client_sdk, redis_lock_name)
await asyncio.sleep(sleep_duration)
assert await _is_locked(redis_client_sdk, redis_lock_name)
return sleep_duration

assert await _() == sleep_duration

assert not await _is_locked(redis_client_sdk, redis_lock_name)


async def test_exclusive_parallel_lock_is_released_and_reacquired(
Expand All @@ -76,17 +155,19 @@ async def test_exclusive_parallel_lock_is_released_and_reacquired(
parallel_tasks = 10
results = await logged_gather(
*[
_contained_client(get_redis_client_sdk, lock_name, task_duration=0.1)
_acquire_lock_and_exclusively_sleep(
get_redis_client_sdk, lock_name, sleep_duration=0.1
)
for _ in range(parallel_tasks)
],
reraise=False
reraise=False,
)
assert results.count(None) == 1
assert [isinstance(x, CouldNotAcquireLockError) for x in results].count(
True
) == parallel_tasks - 1

# check lock is being released
# check lock is released
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk:
assert not await _is_locked(redis_client_sdk, lock_name)

Expand Down Expand Up @@ -168,7 +249,7 @@ async def test_start_exclusive_periodic_task_parallel_all_finish(
_assert_task_completes_once(get_redis_client_sdk, stop_after=60)
for _ in range(parallel_tasks)
],
reraise=False
reraise=False,
)

# check no error occurred
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Annotated

from fastapi import Depends, FastAPI, Request

from ...core.settings import ComputationalBackendSettings
Expand All @@ -11,7 +13,7 @@ def get_scheduler(request: Request) -> BaseCompScheduler:


def get_scheduler_settings(
app: FastAPI = Depends(get_app),
app: Annotated[FastAPI, Depends(get_app)]
) -> ComputationalBackendSettings:
settings: ComputationalBackendSettings = (
app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
import logging
from collections.abc import Callable, Coroutine
from typing import Any, cast

from fastapi import FastAPI
from servicelib.logging_utils import log_context

from . import _scheduler_factory
from ._base_scheduler import BaseCompScheduler
from ._task import on_app_shutdown, on_app_startup

_logger = logging.getLogger(__name__)


def on_app_startup(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]:
async def start_scheduler() -> None:
with log_context(
_logger, level=logging.INFO, msg="starting computational scheduler"
):
app.state.scheduler = scheduler = await _scheduler_factory.create_from_db(
app
)
scheduler.recover_scheduling()

return start_scheduler


def on_app_shutdown(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]:
async def stop_scheduler() -> None:
await get_scheduler(app).shutdown()

return stop_scheduler


def get_scheduler(app: FastAPI) -> BaseCompScheduler:
return cast(BaseCompScheduler, app.state.scheduler)


def setup(app: FastAPI):
Expand All @@ -12,4 +43,5 @@ def setup(app: FastAPI):
__all__: tuple[str, ...] = (
"setup",
"BaseCompScheduler",
"get_scheduler",
)
Loading
Loading