diff --git a/packages/service-library/setup.cfg b/packages/service-library/setup.cfg index 874495da36bd..714b873009e3 100644 --- a/packages/service-library/setup.cfg +++ b/packages/service-library/setup.cfg @@ -21,6 +21,7 @@ markers = testit: "marks test to run during development" performance_test: "performance test" no_cleanup_check_rabbitmq_server_has_no_errors: "no check in rabbitmq logs" + heavy_load: "marks test as heavy load" [mypy] plugins = diff --git a/packages/service-library/setup.py b/packages/service-library/setup.py index 521b491b918e..2ddd96c9ece1 100644 --- a/packages/service-library/setup.py +++ b/packages/service-library/setup.py @@ -38,7 +38,7 @@ def read_reqs(reqs_path: Path) -> set[str]: "python_requires": "~=3.11", "install_requires": tuple(PROD_REQUIREMENTS), "packages": find_packages(where="src"), - "package_data": {"": ["py.typed"]}, + "package_data": {"": ["py.typed", "redis/lua/*.lua"]}, "package_dir": {"": "src"}, "test_suite": "tests", "tests_require": tuple(TEST_REQUIREMENTS), diff --git a/packages/service-library/src/servicelib/long_running_tasks/task.py b/packages/service-library/src/servicelib/long_running_tasks/task.py index 256a60fb4b59..e3f9346106cc 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/task.py +++ b/packages/service-library/src/servicelib/long_running_tasks/task.py @@ -11,7 +11,6 @@ from common_library.async_tools import cancel_wait_task from models_library.api_schemas_long_running_tasks.base import TaskProgress from pydantic import NonNegativeFloat, PositiveFloat -from servicelib.utils import limited_gather from settings_library.redis import RedisDatabase, RedisSettings from tenacity import ( AsyncRetrying, @@ -24,6 +23,7 @@ from ..logging_errors import create_troubleshootting_log_kwargs from ..logging_utils import log_catch, log_context from ..redis import RedisClientSDK, exclusive +from ..utils import limited_gather from ._redis_store import RedisStore from ._serialization import dumps from .errors import ( diff --git a/packages/service-library/src/servicelib/redis/__init__.py b/packages/service-library/src/servicelib/redis/__init__.py index 152b3053893c..08d1ff40c47d 100644 --- a/packages/service-library/src/servicelib/redis/__init__.py +++ b/packages/service-library/src/servicelib/redis/__init__.py @@ -6,6 +6,8 @@ CouldNotConnectToRedisError, LockLostError, ProjectLockError, + SemaphoreAcquisitionError, + SemaphoreNotAcquiredError, ) from ._models import RedisManagerDBConfig from ._project_document_version import ( @@ -18,24 +20,26 @@ is_project_locked, with_project_locked, ) +from ._semaphore_decorator import with_limited_concurrency from ._utils import handle_redis_returns_union_types __all__: tuple[str, ...] 
= ( + "PROJECT_DB_UPDATE_REDIS_LOCK_KEY", + "PROJECT_DOCUMENT_VERSION_KEY", "CouldNotAcquireLockError", "CouldNotConnectToRedisError", - "exclusive", - "increment_and_return_project_document_version", - "get_project_locked_state", - "handle_redis_returns_union_types", - "is_project_locked", "LockLostError", - "PROJECT_DB_UPDATE_REDIS_LOCK_KEY", - "PROJECT_DOCUMENT_VERSION_KEY", "ProjectLockError", "RedisClientSDK", "RedisClientsManager", "RedisManagerDBConfig", + "SemaphoreAcquisitionError", + "SemaphoreNotAcquiredError", + "exclusive", + "get_project_locked_state", + "handle_redis_returns_union_types", + "increment_and_return_project_document_version", + "is_project_locked", + "with_limited_concurrency", "with_project_locked", ) - -# nopycln: file diff --git a/packages/service-library/src/servicelib/redis/_constants.py b/packages/service-library/src/servicelib/redis/_constants.py index 6a10c6b75b0e..37aa19cb5a82 100644 --- a/packages/service-library/src/servicelib/redis/_constants.py +++ b/packages/service-library/src/servicelib/redis/_constants.py @@ -6,6 +6,9 @@ DEFAULT_LOCK_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10) DEFAULT_SOCKET_TIMEOUT: Final[datetime.timedelta] = datetime.timedelta(seconds=30) +DEFAULT_SEMAPHORE_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10) +SEMAPHORE_KEY_PREFIX: Final[str] = "semaphores:" +SEMAPHORE_HOLDER_KEY_PREFIX: Final[str] = "semaphores:holders:" DEFAULT_DECODE_RESPONSES: Final[bool] = True DEFAULT_HEALTH_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5) diff --git a/packages/service-library/src/servicelib/redis/_errors.py b/packages/service-library/src/servicelib/redis/_errors.py index 7fc3c7823ae0..e83b40e4ec62 100644 --- a/packages/service-library/src/servicelib/redis/_errors.py +++ b/packages/service-library/src/servicelib/redis/_errors.py @@ -4,8 +4,7 @@ from common_library.errors_classes import OsparcErrorMixin -class BaseRedisError(OsparcErrorMixin, RuntimeError): - ... +class BaseRedisError(OsparcErrorMixin, RuntimeError): ... 
class CouldNotAcquireLockError(BaseRedisError): @@ -25,3 +24,15 @@ class LockLostError(BaseRedisError): ProjectLockError: TypeAlias = redis.exceptions.LockError # NOTE: backwards compatible + + +class SemaphoreAcquisitionError(BaseRedisError): + msg_template: str = "Could not acquire semaphore '{name}' (capacity: {capacity})" + + +class SemaphoreNotAcquiredError(BaseRedisError): + msg_template: str = "Semaphore '{name}' was not acquired by this instance" + + +class SemaphoreLostError(BaseRedisError): + msg_template: str = "Semaphore '{name}' was lost by this instance `{instance_id}`" diff --git a/packages/service-library/src/servicelib/redis/_semaphore.py b/packages/service-library/src/servicelib/redis/_semaphore.py new file mode 100644 index 000000000000..b62fbc7d238d --- /dev/null +++ b/packages/service-library/src/servicelib/redis/_semaphore.py @@ -0,0 +1,379 @@ +import datetime +import logging +import uuid +from types import TracebackType +from typing import Annotated, ClassVar + +from common_library.basic_types import DEFAULT_FACTORY +from pydantic import ( + BaseModel, + Field, + PositiveInt, + computed_field, + field_validator, +) +from redis.commands.core import AsyncScript +from tenacity import ( + RetryError, + before_sleep_log, + retry, + retry_if_not_result, + stop_after_delay, + stop_never, + wait_random_exponential, +) + +from ._client import RedisClientSDK +from ._constants import ( + DEFAULT_SEMAPHORE_TTL, + DEFAULT_SOCKET_TIMEOUT, + SEMAPHORE_HOLDER_KEY_PREFIX, + SEMAPHORE_KEY_PREFIX, +) +from ._errors import ( + SemaphoreAcquisitionError, + SemaphoreLostError, + SemaphoreNotAcquiredError, +) +from ._semaphore_lua import ( + ACQUIRE_SEMAPHORE_SCRIPT, + COUNT_SEMAPHORE_SCRIPT, + RELEASE_SEMAPHORE_SCRIPT, + RENEW_SEMAPHORE_SCRIPT, + SCRIPT_BAD_EXIT_CODE, + SCRIPT_OK_EXIT_CODE, +) + +_logger = logging.getLogger(__name__) + + +class DistributedSemaphore(BaseModel): + """ + Warning: This should only be used directly via the decorator + + A distributed semaphore implementation using Redis. + + This semaphore allows limiting the number of concurrent operations across + multiple processes/instances using Redis as the coordination backend. 
+ + Args: + redis_client: Redis client for coordination + key: Unique identifier for the semaphore + capacity: Maximum number of concurrent holders + ttl: Time-to-live for semaphore entries (auto-cleanup) + blocking: Whether acquire() should block until available + blocking_timeout: Maximum time to wait when blocking (None = no timeout) + + Example: + async with DistributedSemaphore( + redis_client, "my_resource", capacity=3 + ): + # Only 3 instances can execute this block concurrently + await do_limited_work() + """ + + model_config = { + "arbitrary_types_allowed": True, # For RedisClientSDK + } + + # Configuration fields with validation + redis_client: RedisClientSDK + key: Annotated[ + str, Field(min_length=1, description="Unique identifier for the semaphore") + ] + capacity: Annotated[ + PositiveInt, Field(description="Maximum number of concurrent holders") + ] + ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL + blocking: Annotated[ + bool, Field(description="Whether acquire() should block until available") + ] = True + blocking_timeout: Annotated[ + datetime.timedelta | None, + Field(description="Maximum time to wait when blocking"), + ] = DEFAULT_SOCKET_TIMEOUT + instance_id: Annotated[ + str, + Field( + description="Unique instance identifier", + default_factory=lambda: f"{uuid.uuid4()}", + ), + ] = DEFAULT_FACTORY + + # Class and/or Private state attributes (not part of the model) + acquire_script: ClassVar[AsyncScript | None] = None + count_script: ClassVar[AsyncScript | None] = None + release_script: ClassVar[AsyncScript | None] = None + renew_script: ClassVar[AsyncScript | None] = None + + @classmethod + def _register_scripts(cls, redis_client: RedisClientSDK) -> None: + """Register Lua scripts with Redis if not already done. + This is done once per class, not per instance. Internally the Redis client + caches the script SHA, so this is efficient. Even if called multiple times, + the script is only registered once.""" + if cls.acquire_script is None: + cls.acquire_script = redis_client.redis.register_script( + ACQUIRE_SEMAPHORE_SCRIPT + ) + cls.count_script = redis_client.redis.register_script( + COUNT_SEMAPHORE_SCRIPT + ) + cls.release_script = redis_client.redis.register_script( + RELEASE_SEMAPHORE_SCRIPT + ) + cls.renew_script = redis_client.redis.register_script( + RENEW_SEMAPHORE_SCRIPT + ) + + def __init__(self, **data) -> None: + super().__init__(**data) + self.__class__._register_scripts(self.redis_client) # noqa: SLF001 + + @computed_field # type: ignore[prop-decorator] + @property + def semaphore_key(self) -> str: + """Redis key for the semaphore sorted set.""" + return f"{SEMAPHORE_KEY_PREFIX}{self.key}" + + @computed_field # type: ignore[prop-decorator] + @property + def holder_key(self) -> str: + """Redis key for this instance's holder entry.""" + return f"{SEMAPHORE_HOLDER_KEY_PREFIX}{self.key}:{self.instance_id}" + + # Additional validation + @field_validator("ttl") + @classmethod + def validate_ttl(cls, v: datetime.timedelta) -> datetime.timedelta: + if v.total_seconds() <= 0: + msg = "TTL must be positive" + raise ValueError(msg) + return v + + @field_validator("blocking_timeout") + @classmethod + def validate_timeout( + cls, v: datetime.timedelta | None + ) -> datetime.timedelta | None: + if v is not None and v.total_seconds() <= 0: + msg = "Timeout must be positive" + raise ValueError(msg) + return v + + async def acquire(self) -> bool: + """ + Acquire the semaphore. 
+ + Returns: + True if acquired successfully, False if not acquired and non-blocking + + Raises: + SemaphoreAcquisitionError: If acquisition fails and blocking=True + """ + + if not self.blocking: + # Non-blocking: try once + return await self._try_acquire() + + # Blocking + @retry( + wait=wait_random_exponential(min=0.1, max=2), + reraise=True, + stop=( + stop_after_delay(self.blocking_timeout.total_seconds()) + if self.blocking_timeout + else stop_never + ), + retry=retry_if_not_result(lambda acquired: acquired), + before_sleep=before_sleep_log(_logger, logging.DEBUG), + ) + async def _blocking_acquire() -> bool: + return await self._try_acquire() + + try: + return await _blocking_acquire() + except RetryError as exc: + raise SemaphoreAcquisitionError( + name=self.key, capacity=self.capacity + ) from exc + + async def release(self) -> None: + """ + Release the semaphore atomically using Lua script. + + Raises: + SemaphoreNotAcquiredError: If semaphore was not acquired by this instance + """ + ttl_seconds = int(self.ttl.total_seconds()) + + # Execute the release Lua script atomically + cls = type(self) + assert cls.release_script is not None # nosec + result = await cls.release_script( # pylint: disable=not-callable + keys=( + self.semaphore_key, + self.holder_key, + ), + args=( + self.instance_id, + str(ttl_seconds), + ), + client=self.redis_client.redis, + ) + + assert isinstance(result, list) # nosec + exit_code, status, current_count, expired_count = result + result = status + + if result == "released": + assert exit_code == SCRIPT_OK_EXIT_CODE # nosec + _logger.debug( + "Released semaphore '%s' (instance: %s, count: %s, expired: %s)", + self.key, + self.instance_id, + current_count, + expired_count, + ) + else: + # Instance wasn't in the semaphore set - this shouldn't happen + # but let's handle it gracefully + assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec + raise SemaphoreNotAcquiredError(name=self.key) + + async def _try_acquire(self) -> bool: + ttl_seconds = int(self.ttl.total_seconds()) + + # Execute the Lua script atomically + cls = type(self) + assert cls.acquire_script is not None # nosec + result = await cls.acquire_script( # pylint: disable=not-callable + keys=(self.semaphore_key, self.holder_key), + args=(self.instance_id, str(self.capacity), str(ttl_seconds)), + client=self.redis_client.redis, + ) + + # Lua script returns: [exit_code, status, current_count, expired_count] + assert isinstance(result, list) # nosec + exit_code, status, current_count, expired_count = result + + if exit_code == SCRIPT_OK_EXIT_CODE: + _logger.debug( + "Acquired semaphore '%s' (instance: %s, count: %s, expired: %s)", + self.key, + self.instance_id, + current_count, + expired_count, + ) + return True + + _logger.debug( + "Failed to acquire semaphore '%s' - %s (count: %s, expired: %s)", + self.key, + status, + current_count, + expired_count, + ) + return False + + async def reacquire(self) -> None: + """ + Atomically renew a semaphore entry using Lua script. + + This function is intended to be called by decorators or external renewal mechanisms. 
+ + + Raises: + SemaphoreLostError: If the semaphore was lost or expired + """ + + ttl_seconds = int(self.ttl.total_seconds()) + + # Execute the renewal Lua script atomically + cls = type(self) + assert cls.renew_script is not None # nosec + result = await cls.renew_script( # pylint: disable=not-callable + keys=(self.semaphore_key, self.holder_key), + args=( + self.instance_id, + str(ttl_seconds), + ), + client=self.redis_client.redis, + ) + + assert isinstance(result, list) # nosec + exit_code, status, current_count, expired_count = result + + # Lua script returns: 'renewed' or status message + if status == "renewed": + assert exit_code == SCRIPT_OK_EXIT_CODE # nosec + _logger.debug( + "Renewed semaphore '%s' (instance: %s, count: %s, expired: %s)", + self.key, + self.instance_id, + current_count, + expired_count, + ) + else: + assert exit_code == SCRIPT_BAD_EXIT_CODE # nosec + if status == "expired": + _logger.warning( + "Semaphore '%s' holder key expired (instance: %s, count: %s, expired: %s)", + self.key, + self.instance_id, + current_count, + expired_count, + ) + elif status == "not_held": + _logger.warning( + "Semaphore '%s' not held (instance: %s, count: %s, expired: %s)", + self.key, + self.instance_id, + current_count, + expired_count, + ) + + raise SemaphoreLostError(name=self.key, instance_id=self.instance_id) + + async def get_current_count(self) -> int: + """Get the current number of semaphore holders""" + ttl_seconds = int(self.ttl.total_seconds()) + + # Execute the count Lua script atomically + cls = type(self) + assert cls.count_script is not None # nosec + result = await cls.count_script( # pylint: disable=not-callable + keys=(self.semaphore_key,), + args=(str(ttl_seconds),), + client=self.redis_client.redis, + ) + + assert isinstance(result, list) # nosec + current_count, expired_count = result + + if int(expired_count) > 0: + _logger.debug( + "Cleaned up %s expired entries from semaphore '%s'", + expired_count, + self.key, + ) + + return int(current_count) + + async def get_available_count(self) -> int: + """Get the number of available semaphore slots""" + current_count = await self.get_current_count() + return max(0, self.capacity - current_count) + + # Context manager support + async def __aenter__(self) -> "DistributedSemaphore": + await self.acquire() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.release() diff --git a/packages/service-library/src/servicelib/redis/_semaphore_decorator.py b/packages/service-library/src/servicelib/redis/_semaphore_decorator.py new file mode 100644 index 000000000000..8ae687dd4416 --- /dev/null +++ b/packages/service-library/src/servicelib/redis/_semaphore_decorator.py @@ -0,0 +1,277 @@ +import asyncio +import datetime +import functools +import logging +import socket +from collections.abc import AsyncIterator, Callable, Coroutine +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, ParamSpec, TypeVar + +from common_library.async_tools import cancel_wait_task + +from ..background_task import periodic +from ..logging_errors import create_troubleshootting_log_kwargs +from ._client import RedisClientSDK +from ._constants import ( + DEFAULT_SEMAPHORE_TTL, + DEFAULT_SOCKET_TIMEOUT, +) +from ._errors import ( + SemaphoreAcquisitionError, + SemaphoreLostError, + SemaphoreNotAcquiredError, +) +from ._semaphore import DistributedSemaphore + +_logger = 
logging.getLogger(__name__) + + +P = ParamSpec("P") +R = TypeVar("R") + + +@asynccontextmanager +async def _managed_semaphore_execution( + semaphore: DistributedSemaphore, + semaphore_key: str, + ttl: datetime.timedelta, + execution_context: str, +) -> AsyncIterator: + """Common semaphore management logic with auto-renewal.""" + # Acquire the semaphore first + if not await semaphore.acquire(): + raise SemaphoreAcquisitionError(name=semaphore_key, capacity=semaphore.capacity) + + try: + # NOTE: Use TaskGroup for proper exception propagation, this ensures that in case of error the context manager will be properly exited + # and the semaphore released. + # If we use create_task() directly, exceptions in the task are not propagated to the parent task + # and the context manager may never exit, leading to semaphore leaks. + async with asyncio.TaskGroup() as tg: + started_event = asyncio.Event() + + # Create auto-renewal task + @periodic(interval=ttl / 3, raise_on_error=True) + async def _periodic_renewer() -> None: + await semaphore.reacquire() + if not started_event.is_set(): + started_event.set() + + # Start the renewal task + renewal_task = tg.create_task( + _periodic_renewer(), + name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}", + ) + await started_event.wait() + + yield + + # NOTE: if we do not explicitely await the task inside the context manager + # it sometimes hangs forever (Python issue?) + await cancel_wait_task(renewal_task, max_delay=None) + + except BaseExceptionGroup as eg: + semaphore_lost_errors, other_errors = eg.split(SemaphoreLostError) + # If there are any other errors, re-raise them + if other_errors: + assert len(other_errors.exceptions) == 1 # nosec + raise other_errors.exceptions[0] from eg + + assert semaphore_lost_errors is not None # nosec + assert len(semaphore_lost_errors.exceptions) == 1 # nosec + raise semaphore_lost_errors.exceptions[0] from eg + + finally: + # Always attempt to release the semaphore + try: + await semaphore.release() + except SemaphoreNotAcquiredError as exc: + _logger.exception( + **create_troubleshootting_log_kwargs( + f"Unexpected error while releasing semaphore '{semaphore_key}'", + error=exc, + error_context={ + "semaphore_key": semaphore_key, + "client_name": semaphore.redis_client.client_name, + "hostname": socket.gethostname(), + "execution_context": execution_context, + }, + tip="This might happen if the semaphore was lost before releasing it. 
" + "Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.", + ) + ) + + +def _create_semaphore( + redis_client: RedisClientSDK | Callable[..., RedisClientSDK], + args: tuple[Any, ...], + *, + key: str | Callable[..., str], + capacity: int | Callable[..., int], + ttl: datetime.timedelta, + blocking: bool, + blocking_timeout: datetime.timedelta | None, + kwargs: dict[str, Any], +) -> tuple[DistributedSemaphore, str]: + """Create and configure a distributed semaphore from callable or static parameters.""" + semaphore_key = key(*args, **kwargs) if callable(key) else key + semaphore_capacity = capacity(*args, **kwargs) if callable(capacity) else capacity + client = redis_client(*args, **kwargs) if callable(redis_client) else redis_client + + assert isinstance(semaphore_key, str) # nosec + assert isinstance(semaphore_capacity, int) # nosec + assert isinstance(client, RedisClientSDK) # nosec + + semaphore = DistributedSemaphore( + redis_client=client, + key=semaphore_key, + capacity=semaphore_capacity, + ttl=ttl, + blocking=blocking, + blocking_timeout=blocking_timeout, + ) + + return semaphore, semaphore_key + + +def with_limited_concurrency( + redis_client: RedisClientSDK | Callable[..., RedisClientSDK], + *, + key: str | Callable[..., str], + capacity: int | Callable[..., int], + ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL, + blocking: bool = True, + blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT, +) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] +]: + """ + Decorator to limit concurrent execution of a function using a distributed semaphore. + + This decorator ensures that only a specified number of instances of the decorated + function can run concurrently across multiple processes/instances using Redis + as the coordination backend. 
+
+    Args:
+        redis_client: Redis client for coordination (can be callable)
+        key: Unique identifier for the semaphore (can be callable)
+        capacity: Maximum number of concurrent executions (can be callable)
+        ttl: Time-to-live for semaphore entries (default: 10 seconds)
+        blocking: Whether to block when semaphore is full (default: True)
+        blocking_timeout: Maximum time to wait when blocking (default: 30 seconds)
+
+    Example:
+        @with_limited_concurrency(
+            redis_client,
+            key=f"{user_id}-{wallet_id}",
+            capacity=20,
+            blocking=True,
+            blocking_timeout=None
+        )
+        async def process_user_wallet(user_id: str, wallet_id: str):
+            # Only 20 instances of this function can run concurrently
+            # for the same user_id-wallet_id combination
+            await do_processing()
+
+    Raises:
+        SemaphoreAcquisitionError: If semaphore cannot be acquired and blocking=True
+    """
+
+    def _decorator(
+        coro: Callable[P, Coroutine[Any, Any, R]],
+    ) -> Callable[P, Coroutine[Any, Any, R]]:
+        @functools.wraps(coro)
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            semaphore, semaphore_key = _create_semaphore(
+                redis_client,
+                args,
+                key=key,
+                capacity=capacity,
+                ttl=ttl,
+                blocking=blocking,
+                blocking_timeout=blocking_timeout,
+                kwargs=kwargs,
+            )
+
+            async with _managed_semaphore_execution(
+                semaphore, semaphore_key, ttl, f"coroutine_{coro.__name__}"
+            ):
+                return await coro(*args, **kwargs)
+
+        return _wrapper
+
+    return _decorator
+
+
+def with_limited_concurrency_cm(
+    redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
+    *,
+    key: str | Callable[..., str],
+    capacity: int | Callable[..., int],
+    ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
+    blocking: bool = True,
+    blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT,
+) -> Callable[
+    [Callable[P, AbstractAsyncContextManager[R]]],
+    Callable[P, AbstractAsyncContextManager[R]],
+]:
+    """
+    Decorator to limit concurrent execution of async context managers using a distributed semaphore.
+
+    This decorator ensures that only a specified number of instances of the decorated
+    async context manager can be active concurrently across multiple processes/instances
+    using Redis as the coordination backend.
+
+    Args:
+        redis_client: Redis client for coordination (can be callable)
+        key: Unique identifier for the semaphore (can be callable)
+        capacity: Maximum number of concurrent executions (can be callable)
+        ttl: Time-to-live for semaphore entries (default: 10 seconds)
+        blocking: Whether to block when semaphore is full (default: True)
+        blocking_timeout: Maximum time to wait when blocking (default: 30 seconds)
+
+    Example:
+        @asynccontextmanager
+        @with_limited_concurrency_cm(
+            redis_client,
+            key="cluster:my-cluster",
+            capacity=5,
+            blocking=True,
+            blocking_timeout=None
+        )
+        async def get_cluster_client():
+            async with pool.acquire() as client:
+                yield client
+
+    Raises:
+        SemaphoreAcquisitionError: If semaphore cannot be acquired and blocking=True
+    """
+
+    def _decorator(
+        cm_func: Callable[P, AbstractAsyncContextManager[R]],
+    ) -> Callable[P, AbstractAsyncContextManager[R]]:
+        @functools.wraps(cm_func)
+        @asynccontextmanager
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
+            semaphore, semaphore_key = _create_semaphore(
+                redis_client,
+                args,
+                key=key,
+                capacity=capacity,
+                ttl=ttl,
+                blocking=blocking,
+                blocking_timeout=blocking_timeout,
+                kwargs=kwargs,
+            )
+
+            async with (
+                _managed_semaphore_execution(
+                    semaphore, semaphore_key, ttl, f"context_manager_{cm_func.__name__}"
+                ),
+                cm_func(*args, **kwargs) as value,
+            ):
+                yield value
+
+        return _wrapper
+
+    return _decorator
diff --git a/packages/service-library/src/servicelib/redis/_semaphore_lua.py b/packages/service-library/src/servicelib/redis/_semaphore_lua.py
new file mode 100644
index 000000000000..8bf685b30a86
--- /dev/null
+++ b/packages/service-library/src/servicelib/redis/_semaphore_lua.py
@@ -0,0 +1,35 @@
+"""Used to load a Lua script from the package resources into memory.
+
+Example:
+    >>> from servicelib.redis._semaphore_lua import ACQUIRE_SEMAPHORE_SCRIPT
+    # This will register the script in redis and return a Script object
+    # which can be used to execute the script. Even from multiple processes
+    # the script will be loaded only once in redis as the redis server computes
+    # the SHA1 of the script and uses it to identify it.
+    >>> from redis.asyncio import Redis
+    >>> redis = Redis(...)
+ >>> my_acquire_script = redis.register_script( + ACQUIRE_SEMAPHORE_SCRIPT + >>> my_acquire_script(keys=[...], args=[...]) +""" + +from functools import lru_cache +from importlib import resources +from typing import Final + + +@lru_cache +def _load_script(script_name: str) -> str: + with resources.as_file( + resources.files("servicelib.redis.lua") / f"{script_name}.lua" + ) as script_file: + return script_file.read_text(encoding="utf-8").strip() + + +ACQUIRE_SEMAPHORE_SCRIPT: Final[str] = _load_script("acquire_semaphore") +RELEASE_SEMAPHORE_SCRIPT: Final[str] = _load_script("release_semaphore") +RENEW_SEMAPHORE_SCRIPT: Final[str] = _load_script("renew_semaphore") +COUNT_SEMAPHORE_SCRIPT: Final[str] = _load_script("count_semaphore") + +SCRIPT_BAD_EXIT_CODE: Final[int] = 255 +SCRIPT_OK_EXIT_CODE: Final[int] = 0 diff --git a/packages/service-library/src/servicelib/redis/lua/acquire_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/acquire_semaphore.lua new file mode 100644 index 000000000000..b73608677909 --- /dev/null +++ b/packages/service-library/src/servicelib/redis/lua/acquire_semaphore.lua @@ -0,0 +1,40 @@ +-- Atomically acquire a distributed semaphore +-- KEYS[1]: semaphore_key (ZSET storing holders with timestamps) +-- KEYS[2]: holder_key (individual holder TTL key) +-- ARGV[1]: instance_id +-- ARGV[2]: capacity (max concurrent holders) +-- ARGV[3]: ttl_seconds +-- +-- Returns: {exit_code, status, current_count, expired_count} +-- exit_code: 0 if acquired, 255 if failed +-- status: 'acquired' or 'capacity_full' +-- current_count: number of holders after operation +-- expired_count: number of expired entries cleaned up + +local semaphore_key = KEYS[1] +local holder_key = KEYS[2] +local instance_id = ARGV[1] +local capacity = tonumber(ARGV[2]) +local ttl_seconds = tonumber(ARGV[3]) + +-- Get current Redis server time +local time_result = redis.call('TIME') +local current_time = tonumber(time_result[1]) + (tonumber(time_result[2]) / 1000000) + +-- Step 1: Clean up expired entries +local expiry_threshold = current_time - ttl_seconds +local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold) + +-- Step 2: Check current capacity after cleanup +local current_count = redis.call('ZCARD', semaphore_key) + +-- Step 3: Try to acquire if under capacity +if current_count < capacity then + -- Atomically add to semaphore and set holder key + redis.call('ZADD', semaphore_key, current_time, instance_id) + redis.call('SETEX', holder_key, ttl_seconds, '1') + + return {0, 'acquired', current_count + 1, expired_count} +else + return {255, 'capacity_full', current_count, expired_count} +end diff --git a/packages/service-library/src/servicelib/redis/lua/count_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/count_semaphore.lua new file mode 100644 index 000000000000..a7c023bae89a --- /dev/null +++ b/packages/service-library/src/servicelib/redis/lua/count_semaphore.lua @@ -0,0 +1,23 @@ +-- Atomically count current semaphore holders (with cleanup) +-- KEYS[1]: semaphore_key (ZSET storing holders with timestamps) +-- ARGV[1]: ttl_seconds +-- +-- Returns: {current_count, expired_count} +-- current_count: number of active holders after cleanup +-- expired_count: number of expired entries cleaned up + +local semaphore_key = KEYS[1] +local ttl_seconds = tonumber(ARGV[1]) + +-- Get current Redis server time +local time_result = redis.call('TIME') +local current_time = tonumber(time_result[1]) + (tonumber(time_result[2]) / 1000000) + +-- Step 1: 
Clean up expired entries +local expiry_threshold = current_time - ttl_seconds +local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold) + +-- Step 2: Count remaining entries +local current_count = redis.call('ZCARD', semaphore_key) + +return {current_count, expired_count} diff --git a/packages/service-library/src/servicelib/redis/lua/release_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/release_semaphore.lua new file mode 100644 index 000000000000..a1411060a99a --- /dev/null +++ b/packages/service-library/src/servicelib/redis/lua/release_semaphore.lua @@ -0,0 +1,46 @@ +-- Atomically release a distributed semaphore +-- KEYS[1]: semaphore_key (ZSET storing holders with timestamps) +-- KEYS[2]: holder_key (individual holder TTL key) +-- ARGV[1]: instance_id +-- ARGV[2]: ttl_seconds +-- +-- Returns: {success, status, current_count, expired_count} +-- exit_code: 0 if released, 255 if failed +-- status: 'released', 'not_held', or 'already_expired' +-- current_count: number of holders after operation +-- expired_count: number of expired entries cleaned up + +local semaphore_key = KEYS[1] +local holder_key = KEYS[2] +local instance_id = ARGV[1] +local ttl_seconds = tonumber(ARGV[2]) + +-- Get current Redis server time +local time_result = redis.call('TIME') +local current_time = tonumber(time_result[1]) + (tonumber(time_result[2]) / 1000000) + +-- Step 1: Clean up expired entries +local expiry_threshold = current_time - ttl_seconds +local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold) + +-- Step 2: Check if this instance currently holds the semaphore +local score = redis.call('ZSCORE', semaphore_key, instance_id) + +if score == false then + -- Instance doesn't hold the semaphore + local current_count = redis.call('ZCARD', semaphore_key) + return {255, 'not_held', current_count, expired_count} +end + +-- Step 3: Remove the semaphore entry and holder key +local removed_from_zset = redis.call('ZREM', semaphore_key, instance_id) +local removed_holder = redis.call('DEL', holder_key) + +local current_count = redis.call('ZCARD', semaphore_key) + +if removed_from_zset == 1 then + return {0, 'released', current_count, expired_count} +else + -- This shouldn't happen since we checked ZSCORE above, but handle it + return {255, 'already_expired', current_count, expired_count} +end diff --git a/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua new file mode 100644 index 000000000000..231d9d421ea0 --- /dev/null +++ b/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua @@ -0,0 +1,49 @@ +-- Atomically renew a distributed semaphore holder's TTL +-- KEYS[1]: semaphore_key (ZSET storing holders with timestamps) +-- KEYS[2]: holder_key (individual holder TTL key) +-- ARGV[1]: instance_id +-- ARGV[2]: ttl_seconds +-- +-- Returns: {success, status, current_count, expired_count} +-- exit_code: 0 if renewed, 255 if failed +-- status: 'renewed', 'not_held', or 'expired' +-- current_count: number of holders after operation +-- expired_count: number of expired entries cleaned up + +local semaphore_key = KEYS[1] +local holder_key = KEYS[2] +local instance_id = ARGV[1] +local ttl_seconds = tonumber(ARGV[2]) + +-- Get current Redis server time +local time_result = redis.call('TIME') +local current_time = tonumber(time_result[1]) + (tonumber(time_result[2]) / 1000000) + +-- Step 1: Clean up expired entries +local 
expiry_threshold = current_time - ttl_seconds +local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold) + +-- Step 2: Check if this instance currently holds the semaphore +local score = redis.call('ZSCORE', semaphore_key, instance_id) + +if score == false then + -- Instance doesn't hold the semaphore + local current_count = redis.call('ZCARD', semaphore_key) + return {255, 'not_held', current_count, expired_count} +end + +-- Step 3: Check if the holder key still exists (not expired) +local exists = redis.call('EXISTS', holder_key) +if exists == 0 then + -- Holder key expired, remove from semaphore and fail renewal + redis.call('ZREM', semaphore_key, instance_id) + local current_count = redis.call('ZCARD', semaphore_key) + return {255, 'expired', current_count, expired_count + 1} +end + +-- Step 4: Renew both the semaphore entry and holder key +redis.call('ZADD', semaphore_key, current_time, instance_id) +redis.call('SETEX', holder_key, ttl_seconds, '1') + +local current_count = redis.call('ZCARD', semaphore_key) +return {0, 'renewed', current_count, expired_count} diff --git a/packages/service-library/tests/redis/conftest.py b/packages/service-library/tests/redis/conftest.py index c975dc1f4ad9..f29c76bdfb22 100644 --- a/packages/service-library/tests/redis/conftest.py +++ b/packages/service-library/tests/redis/conftest.py @@ -30,3 +30,18 @@ def with_short_default_redis_lock_ttl(mocker: MockerFixture) -> datetime.timedel short_ttl = datetime.timedelta(seconds=0.25) mocker.patch.object(redis_constants, "DEFAULT_LOCK_TTL", short_ttl) return short_ttl + + +@pytest.fixture +def semaphore_name(faker: Faker) -> str: + return faker.pystr() + + +@pytest.fixture +def semaphore_capacity() -> int: + return 3 + + +@pytest.fixture +def short_ttl() -> datetime.timedelta: + return datetime.timedelta(seconds=1) diff --git a/packages/service-library/tests/redis/test_semaphore.py b/packages/service-library/tests/redis/test_semaphore.py new file mode 100644 index 000000000000..14042589b5a0 --- /dev/null +++ b/packages/service-library/tests/redis/test_semaphore.py @@ -0,0 +1,347 @@ +# ruff: noqa: SLF001, EM101, TRY003, PT011, PLR0917 +# pylint: disable=no-value-for-parameter +# pylint: disable=protected-access +# pylint: disable=redefined-outer-name +# pylint: disable=unused-argument +# pylint: disable=unused-variable + +import asyncio +import datetime + +import pytest +from faker import Faker +from pytest_mock import MockerFixture +from servicelib.redis import RedisClientSDK +from servicelib.redis._constants import ( + DEFAULT_SEMAPHORE_TTL, + SEMAPHORE_HOLDER_KEY_PREFIX, + SEMAPHORE_KEY_PREFIX, +) +from servicelib.redis._errors import SemaphoreLostError +from servicelib.redis._semaphore import ( + DistributedSemaphore, + SemaphoreAcquisitionError, + SemaphoreNotAcquiredError, +) + +pytest_simcore_core_services_selection = [ + "redis", +] +pytest_simcore_ops_services_selection = [ + "redis-commander", +] + + +@pytest.fixture +def with_short_default_semaphore_ttl( + mocker: MockerFixture, +) -> datetime.timedelta: + short_ttl = datetime.timedelta(seconds=0.5) + mocker.patch( + "servicelib.redis._semaphore._DEFAULT_SEMAPHORE_TTL", + short_ttl, + ) + return short_ttl + + +async def test_semaphore_initialization( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, +): + semaphore = DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=semaphore_capacity + ) + + assert semaphore.key == semaphore_name + assert 
semaphore.capacity == semaphore_capacity + assert semaphore.ttl == DEFAULT_SEMAPHORE_TTL + assert semaphore.blocking is True + assert semaphore.instance_id is not None + assert semaphore.semaphore_key == f"{SEMAPHORE_KEY_PREFIX}{semaphore_name}" + assert semaphore.holder_key.startswith( + f"{SEMAPHORE_HOLDER_KEY_PREFIX}{semaphore_name}:" + ) + + +async def test_invalid_semaphore_initialization( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + with pytest.raises(ValueError, match="Input should be greater than 0"): + DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=0 + ) + + with pytest.raises(ValueError, match="Input should be greater than 0"): + DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=-1 + ) + + with pytest.raises(ValueError, match="TTL must be positive"): + DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=1, + ttl=datetime.timedelta(seconds=0), + ) + with pytest.raises(ValueError, match="Timeout must be positive"): + DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=1, + ttl=datetime.timedelta(seconds=10), + blocking=True, + blocking_timeout=datetime.timedelta(seconds=0), + ) + + +async def test_semaphore_acquire_release_single( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, +): + semaphore = DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=semaphore_capacity + ) + + # Initially not acquired + + # Acquire successfully + result = await semaphore.acquire() + assert result is True + + # Check Redis state + assert await semaphore.get_current_count() == 1 + assert await semaphore.get_available_count() == semaphore_capacity - 1 + + # Acquire again on same instance should return True immediately and keep the same count (reentrant) + result = await semaphore.acquire() + assert result is True + assert await semaphore.get_current_count() == 1 + assert await semaphore.get_available_count() == semaphore_capacity - 1 + + # reacquire should just work + await semaphore.reacquire() + assert await semaphore.get_current_count() == 1 + assert await semaphore.get_available_count() == semaphore_capacity - 1 + + # Release + await semaphore.release() + assert await semaphore.get_current_count() == 0 + assert await semaphore.get_available_count() == semaphore_capacity + + # reacquire after release should fail + with pytest.raises( + SemaphoreLostError, + match=f"Semaphore '{semaphore_name}' was lost by this instance", + ): + await semaphore.reacquire() + + +async def test_semaphore_context_manager( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, +): + async with DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=semaphore_capacity + ) as semaphore: + assert await semaphore.get_current_count() == 1 + + # Should be released after context + assert await semaphore.get_current_count() == 0 + + +async def test_semaphore_release_without_acquire_raises( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, +): + semaphore = DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=semaphore_capacity + ) + + with pytest.raises( + SemaphoreNotAcquiredError, + match=f"Semaphore '{semaphore_name}' was not acquired by this instance", + ): + await semaphore.release() + + +async def test_semaphore_multiple_instances_capacity_limit( + redis_client_sdk: 
RedisClientSDK, + semaphore_name: str, +): + capacity = 2 + semaphores = [ + DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=capacity + ) + for _ in range(4) + ] + + # Acquire first two should succeed + assert await semaphores[0].acquire() is True + assert await semaphores[1].acquire() is True + + # Third and fourth should fail in non-blocking mode + for semaphore in semaphores[2:]: + semaphore.blocking = False + assert await semaphore.acquire() is False + + # Check counts + assert await semaphores[0].get_current_count() == 2 + assert await semaphores[0].get_available_count() == 0 + + # Release one + await semaphores[0].release() + assert await semaphores[0].get_current_count() == 1 + assert await semaphores[0].get_available_count() == 1 + + # Now third can acquire + assert await semaphores[2].acquire() is True + + # Clean up + await semaphores[1].release() + await semaphores[2].release() + + +async def test_semaphore_blocking_timeout( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + capacity = 1 + timeout = datetime.timedelta(seconds=0.1) + + # First semaphore acquires + async with DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=capacity + ): + # Second semaphore should timeout + semaphore2 = DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=capacity, + blocking_timeout=timeout, + ) + + with pytest.raises( + SemaphoreAcquisitionError, + match=f"Could not acquire semaphore '{semaphore_name}' \\(capacity: {capacity}\\)", + ): + await semaphore2.acquire() + + +async def test_semaphore_blocking_acquire_waits( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + capacity = 1 + semaphore1 = DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=capacity + ) + semaphore2 = DistributedSemaphore( + redis_client=redis_client_sdk, key=semaphore_name, capacity=capacity + ) + + # First acquires immediately + await semaphore1.acquire() + + # Second will wait + async def delayed_release() -> None: + await asyncio.sleep(0.1) + await semaphore1.release() + + acquire_task = asyncio.create_task(semaphore2.acquire()) + release_task = asyncio.create_task(delayed_release()) + + # Both should complete successfully + results = await asyncio.gather(acquire_task, release_task) + assert results[0] is True # acquire succeeded + + await semaphore2.release() + + +async def test_semaphore_context_manager_with_exception( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, +): + captured_semaphore: DistributedSemaphore | None = None + + async def _raising_context(): + async with DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=semaphore_capacity, + ) as sem: + nonlocal captured_semaphore + captured_semaphore = sem + msg = "Test exception" + raise RuntimeError(msg) + + with pytest.raises(RuntimeError, match="Test exception"): + await _raising_context() + + # Should be released even after exception + assert captured_semaphore is not None + # captured_semaphore is guaranteed to be not None by the assert above + assert await captured_semaphore.get_current_count() == 0 + + +async def test_semaphore_ttl_cleanup( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, + short_ttl: datetime.timedelta, +): + # Create semaphore with explicit short TTL + semaphore = DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + 
capacity=semaphore_capacity, + ttl=short_ttl, + ) + + # Manually add an expired entry + expired_instance_id = "expired-instance" + current_time = asyncio.get_event_loop().time() + # Make sure it's definitely expired by using the short TTL + expired_time = current_time - short_ttl.total_seconds() - 1 + + await redis_client_sdk.redis.zadd( + semaphore.semaphore_key, {expired_instance_id: expired_time} + ) + + # Verify the entry was added + initial_count = await redis_client_sdk.redis.zcard(semaphore.semaphore_key) + assert initial_count == 1 + + # Current count should clean up expired entries + count = await semaphore.get_current_count() + assert count == 0 + + # Verify expired entry was removed + remaining = await redis_client_sdk.redis.zcard(semaphore.semaphore_key) + assert remaining == 0 + + +async def test_multiple_semaphores_different_keys( + redis_client_sdk: RedisClientSDK, + faker: Faker, +): + """Test that semaphores with different keys don't interfere""" + key1 = faker.pystr() + key2 = faker.pystr() + capacity = 1 + + async with ( + DistributedSemaphore( + redis_client=redis_client_sdk, key=key1, capacity=capacity + ), + DistributedSemaphore( + redis_client=redis_client_sdk, key=key2, capacity=capacity + ), + ): + ... diff --git a/packages/service-library/tests/redis/test_semaphore_decorator.py b/packages/service-library/tests/redis/test_semaphore_decorator.py new file mode 100644 index 000000000000..7cee29331806 --- /dev/null +++ b/packages/service-library/tests/redis/test_semaphore_decorator.py @@ -0,0 +1,691 @@ +# ruff: noqa: SLF001, EM101, TRY003, PT011, PLR0917 +# pylint: disable=no-value-for-parameter +# pylint: disable=protected-access +# pylint: disable=redefined-outer-name +# pylint: disable=unused-argument +# pylint: disable=unused-variable + +import asyncio +import datetime +import logging +from contextlib import asynccontextmanager +from typing import Literal + +import pytest +from pytest_mock import MockerFixture +from servicelib.redis import RedisClientSDK +from servicelib.redis._constants import ( + SEMAPHORE_HOLDER_KEY_PREFIX, +) +from servicelib.redis._errors import SemaphoreLostError +from servicelib.redis._semaphore import ( + DistributedSemaphore, + SemaphoreAcquisitionError, +) +from servicelib.redis._semaphore_decorator import ( + with_limited_concurrency, + with_limited_concurrency_cm, +) + +pytest_simcore_core_services_selection = [ + "redis", +] +pytest_simcore_ops_services_selection = [ + "redis-commander", +] + + +async def test_basic_functionality( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + call_count = 0 + + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=1, + ) + async def limited_function(): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.1) + return call_count + + # Multiple concurrent calls + tasks = [asyncio.create_task(limited_function()) for _ in range(3)] + results = await asyncio.gather(*tasks) + + # All should complete successfully + assert len(results) == 3 + assert all(isinstance(r, int) for r in results) + + +async def test_auto_renewal( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, + short_ttl: datetime.timedelta, +): + work_started = asyncio.Event() + work_completed = asyncio.Event() + + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=semaphore_capacity, + ttl=short_ttl, + ) + async def long_running_work() -> Literal["success"]: + work_started.set() + # Wait longer than TTL to ensure renewal works 
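+        # (the decorator's auto-renewal task runs every ttl/3, so sleeping for 2*ttl exercises several renewals)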
+ await asyncio.sleep(short_ttl.total_seconds() * 2) + work_completed.set() + return "success" + + task = asyncio.create_task(long_running_work()) + await work_started.wait() + + # Check that semaphore is being held + temp_semaphore = DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=semaphore_capacity, + ttl=short_ttl, + ) + assert await temp_semaphore.get_current_count() == 1 + assert await temp_semaphore.get_available_count() == semaphore_capacity - 1 + + # Wait for work to complete + result = await task + assert result == "success" + assert work_completed.is_set() + + # After completion, semaphore should be released + assert await temp_semaphore.get_current_count() == 0 + assert await temp_semaphore.get_available_count() == semaphore_capacity + + +async def test_auto_renewal_lose_semaphore_raises( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, + short_ttl: datetime.timedelta, +): + work_started = asyncio.Event() + + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=semaphore_capacity, + ttl=short_ttl, + ) + async def coro_that_should_fail() -> Literal["should not reach here"]: + work_started.set() + # Wait long enough for renewal to be attempted multiple times + await asyncio.sleep(short_ttl.total_seconds() * 100) + return "should not reach here" + + task = asyncio.create_task(coro_that_should_fail()) + await work_started.wait() + + # Wait for the first renewal interval to pass + renewal_interval = short_ttl / 3 + await asyncio.sleep(renewal_interval.total_seconds() * 1.5) + + # Find and delete all holder keys for this semaphore + holder_keys = await redis_client_sdk.redis.keys( + f"{SEMAPHORE_HOLDER_KEY_PREFIX}{semaphore_name}:*" + ) + assert holder_keys, "Holder keys should exist before deletion" + await redis_client_sdk.redis.delete(*holder_keys) + + # wait another renewal interval to ensure the renewal fails + await asyncio.sleep(renewal_interval.total_seconds() * 1.5) + + # it shall have raised already, do not wait too much + async with asyncio.timeout(renewal_interval.total_seconds()): + with pytest.raises(SemaphoreLostError): + await task + + +async def test_decorator_with_callable_parameters( + redis_client_sdk: RedisClientSDK, +): + executed_keys = [] + + def get_redis_client(*args, **kwargs): + return redis_client_sdk + + def get_key(user_id: str, resource: str) -> str: + return f"{user_id}-{resource}" + + def get_capacity(user_id: str, resource: str) -> int: + return 2 + + @with_limited_concurrency( + get_redis_client, + key=get_key, + capacity=get_capacity, + ) + async def process_user_resource(user_id: str, resource: str): + executed_keys.append(f"{user_id}-{resource}") + await asyncio.sleep(0.05) + + # Test with different parameters + await asyncio.gather( + process_user_resource("user1", "wallet1"), + process_user_resource("user1", "wallet2"), + process_user_resource("user2", "wallet1"), + ) + + assert len(executed_keys) == 3 + assert "user1-wallet1" in executed_keys + assert "user1-wallet2" in executed_keys + assert "user2-wallet1" in executed_keys + + +async def test_decorator_capacity_enforcement( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + concurrent_count = 0 + max_concurrent = 0 + + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=2, + ) + async def limited_function(): + nonlocal concurrent_count, max_concurrent + concurrent_count += 1 + max_concurrent = max(max_concurrent, concurrent_count) + await 
asyncio.sleep(0.1) + concurrent_count -= 1 + + # Start 5 concurrent tasks + tasks = [asyncio.create_task(limited_function()) for _ in range(5)] + await asyncio.gather(*tasks) + + # Should never exceed capacity of 2 + assert max_concurrent <= 2 + + +async def test_exception_handling( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=1, + ) + async def failing_function(): + raise RuntimeError("Test exception") + + with pytest.raises(RuntimeError, match="Test exception"): + await failing_function() + + # Semaphore should be released even after exception + # Test by trying to acquire again + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=1, + ) + async def success_function(): + return "success" + + result = await success_function() + assert result == "success" + + +async def test_non_blocking_behavior( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + # Test the blocking timeout behavior + started_event = asyncio.Event() + + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=1, + blocking=True, + blocking_timeout=datetime.timedelta(seconds=0.1), + ) + async def limited_function() -> None: + started_event.set() + await asyncio.sleep(2) + + # Start first task that will hold the semaphore + task1 = asyncio.create_task(limited_function()) + await started_event.wait() # Wait until semaphore is actually acquired + + # Second task should timeout and raise an exception + with pytest.raises(SemaphoreAcquisitionError): + await limited_function() + + await task1 + + # now doing the same with non-blocking should raise + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=1, + blocking=False, + blocking_timeout=datetime.timedelta(seconds=0.1), + ) + async def limited_function_non_blocking() -> None: + await asyncio.sleep(0.5) + + tasks = [asyncio.create_task(limited_function_non_blocking()) for _ in range(3)] + results = await asyncio.gather(*tasks, return_exceptions=True) + assert len(results) == 3 + assert any(isinstance(r, SemaphoreAcquisitionError) for r in results) + + +async def test_user_exceptions_properly_reraised( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, + short_ttl: datetime.timedelta, + mocker: MockerFixture, +): + class UserFunctionError(Exception): + """Custom exception to ensure we're catching the right exception""" + + work_started = asyncio.Event() + + # Track that auto-renewal is actually happening + from servicelib.redis._semaphore import DistributedSemaphore + + spied_renew_fct = mocker.spy(DistributedSemaphore, "reacquire") + + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=semaphore_capacity, + ttl=short_ttl, # Short TTL to ensure renewal happens + ) + async def failing_function(): + work_started.set() + # Wait long enough for at least one renewal to happen + await asyncio.sleep(short_ttl.total_seconds() * 0.8) + # Then raise our custom exception + raise UserFunctionError("User function failed intentionally") + + # Verify the exception is properly re-raised + with pytest.raises(UserFunctionError, match="User function failed intentionally"): + await failing_function() + + # Ensure work actually started + assert work_started.is_set() + + # Verify auto-renewal was working (at least one renewal should have happened) + assert ( + spied_renew_fct.call_count >= 1 + ), "Auto-renewal should have been called at least 
once" + + # Verify semaphore was properly released by trying to acquire it again + test_semaphore = DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=semaphore_capacity, + ttl=short_ttl, + ) + assert ( + await test_semaphore.get_current_count() == 0 + ), "Semaphore should be released after exception" + + +async def test_cancelled_error_preserved( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + semaphore_capacity: int, +): + """Test that CancelledError is properly preserved through the decorator""" + + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=semaphore_capacity, + ) + async def function_raising_cancelled_error(): + raise asyncio.CancelledError + + # Verify CancelledError is preserved + with pytest.raises(asyncio.CancelledError): + await function_raising_cancelled_error() + + +@pytest.mark.heavy_load +async def test_with_large_capacity( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + large_capacity = 100 + concurrent_count = 0 + max_concurrent = 0 + sleep_time_s = 5 + num_tasks = 1000 + + @with_limited_concurrency( + redis_client_sdk, + key=semaphore_name, + capacity=large_capacity, + blocking=True, + blocking_timeout=None, + ) + async def limited_function() -> None: + nonlocal concurrent_count, max_concurrent + concurrent_count += 1 + max_concurrent = max(max_concurrent, concurrent_count) + logging.info("Started task, current concurrent: %d", concurrent_count) + await asyncio.sleep(sleep_time_s) + logging.info("Done task, current concurrent: %d", concurrent_count) + concurrent_count -= 1 + + # Start tasks equal to the large capacity + tasks = [asyncio.create_task(limited_function()) for _ in range(num_tasks)] + done, pending = await asyncio.wait( + tasks, + timeout=float(num_tasks) / float(large_capacity) * 10.0 * float(sleep_time_s), + ) + assert not pending, f"Some tasks did not complete: {len(pending)} pending" + assert len(done) == num_tasks + + # Should never exceed the large capacity + assert max_concurrent <= large_capacity + + +async def test_context_manager_basic_functionality( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + call_count = 0 + + @with_limited_concurrency_cm( + redis_client_sdk, + key=semaphore_name, + capacity=1, + ) + @asynccontextmanager + async def limited_context_manager(): + nonlocal call_count + call_count += 1 + yield call_count + + # Multiple concurrent context managers + async def use_context_manager() -> int: + async with limited_context_manager() as value: + await asyncio.sleep(0.1) + return value + + tasks = [asyncio.create_task(use_context_manager()) for _ in range(3)] + results = await asyncio.gather(*tasks) + + # All should complete successfully + assert len(results) == 3 + assert all(isinstance(r, int) for r in results) + + +async def test_context_manager_capacity_enforcement( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + concurrent_count = 0 + max_concurrent = 0 + + @with_limited_concurrency_cm( + redis_client_sdk, + key=semaphore_name, + capacity=2, + ) + @asynccontextmanager + async def limited_context_manager(): + nonlocal concurrent_count, max_concurrent + concurrent_count += 1 + max_concurrent = max(max_concurrent, concurrent_count) + try: + yield + await asyncio.sleep(0.1) + finally: + concurrent_count -= 1 + + async def use_context_manager() -> None: + async with limited_context_manager(): + await asyncio.sleep(0.1) + + # Start concurrent context managers + tasks = 
+    tasks = [asyncio.create_task(use_context_manager()) for _ in range(20)]
+    await asyncio.gather(*tasks)
+
+    # Should never exceed capacity of 2
+    assert max_concurrent <= 2
+
+
+async def test_context_manager_exception_handling(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+):
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+    )
+    @asynccontextmanager
+    async def failing_context_manager():
+        yield
+        raise RuntimeError("Test exception")
+
+    with pytest.raises(RuntimeError, match="Test exception"):
+        async with failing_context_manager():
+            pass
+
+    # Semaphore should be released even after exception
+
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+    )
+    @asynccontextmanager
+    async def success_context_manager():
+        yield "success"
+
+    async with success_context_manager() as result:
+        assert result == "success"
+
+
+async def test_context_manager_auto_renewal(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+    semaphore_capacity: int,
+    short_ttl: datetime.timedelta,
+):
+    work_started = asyncio.Event()
+    work_completed = asyncio.Event()
+
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=semaphore_capacity,
+        ttl=short_ttl,
+    )
+    @asynccontextmanager
+    async def long_running_context_manager():
+        work_started.set()
+        yield "data"
+        # Wait longer than TTL to ensure renewal works
+        await asyncio.sleep(short_ttl.total_seconds() * 2)
+        work_completed.set()
+
+    async def use_long_running_cm():
+        async with long_running_context_manager() as data:
+            assert data == "data"
+            # Keep context manager active for longer than TTL
+            await asyncio.sleep(short_ttl.total_seconds() * 1.5)
+
+    task = asyncio.create_task(use_long_running_cm())
+    await work_started.wait()
+
+    # Check that semaphore is being held
+    temp_semaphore = DistributedSemaphore(
+        redis_client=redis_client_sdk,
+        key=semaphore_name,
+        capacity=semaphore_capacity,
+        ttl=short_ttl,
+    )
+    assert await temp_semaphore.get_current_count() == 1
+    assert await temp_semaphore.get_available_count() == semaphore_capacity - 1
+
+    # Wait for work to complete
+    await task
+    assert work_completed.is_set()
+
+    # After completion, semaphore should be released
+    assert await temp_semaphore.get_current_count() == 0
+    assert await temp_semaphore.get_available_count() == semaphore_capacity
+
+
+async def test_context_manager_with_callable_parameters(
+    redis_client_sdk: RedisClientSDK,
+):
+    executed_keys = []
+
+    def get_redis_client(*args, **kwargs):
+        return redis_client_sdk
+
+    def get_key(user_id: str, resource: str) -> str:
+        return f"{user_id}-{resource}"
+
+    def get_capacity(user_id: str, resource: str) -> int:
+        return 2
+
+    @with_limited_concurrency_cm(
+        get_redis_client,
+        key=get_key,
+        capacity=get_capacity,
+    )
+    @asynccontextmanager
+    async def process_user_resource_cm(user_id: str, resource: str):
+        executed_keys.append(f"{user_id}-{resource}")
+        yield f"processed-{user_id}-{resource}"
+        await asyncio.sleep(0.05)
+
+    async def use_cm(user_id: str, resource: str):
+        async with process_user_resource_cm(user_id, resource) as result:
+            return result
+
+    # Test with different parameters
+    results = await asyncio.gather(
+        use_cm("user1", "wallet1"),
+        use_cm("user1", "wallet2"),
+        use_cm("user2", "wallet1"),
+    )
+
+    assert len(executed_keys) == 3
+    assert "user1-wallet1" in executed_keys
+    assert "user1-wallet2" in executed_keys
+    assert "user2-wallet1" in executed_keys
+
+    assert len(results) == 3
+    assert "processed-user1-wallet1" in results
+    assert "processed-user1-wallet2" in results
+    assert "processed-user2-wallet1" in results
+
+
+async def test_context_manager_non_blocking_behavior(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+):
+    started_event = asyncio.Event()
+
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+        blocking=True,
+        blocking_timeout=datetime.timedelta(seconds=0.1),
+    )
+    @asynccontextmanager
+    async def limited_context_manager():
+        started_event.set()
+        yield
+        await asyncio.sleep(2)
+
+    # Start first context manager that will hold the semaphore
+    async def long_running_cm():
+        async with limited_context_manager():
+            await asyncio.sleep(2)
+
+    task1 = asyncio.create_task(long_running_cm())
+    await started_event.wait()  # Wait until semaphore is actually acquired
+
+    # Second context manager should timeout and raise an exception
+
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=1,
+        blocking=True,
+        blocking_timeout=datetime.timedelta(seconds=0.1),
+    )
+    @asynccontextmanager
+    async def timeout_context_manager():
+        yield
+
+    with pytest.raises(SemaphoreAcquisitionError):
+        async with timeout_context_manager():
+            pass
+
+    await task1
+
+
+async def test_context_manager_lose_semaphore_raises(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+    semaphore_capacity: int,
+    short_ttl: datetime.timedelta,
+):
+    work_started = asyncio.Event()
+
+    @with_limited_concurrency_cm(
+        redis_client_sdk,
+        key=semaphore_name,
+        capacity=semaphore_capacity,
+        ttl=short_ttl,
+    )
+    @asynccontextmanager
+    async def context_manager_that_should_fail():
+        yield "data"
+
+    async def use_failing_cm() -> None:
+        async with context_manager_that_should_fail() as data:
+            assert data == "data"
+            work_started.set()
+            # Wait long enough for renewal to be attempted multiple times
+            await asyncio.sleep(short_ttl.total_seconds() * 100)
+
+    task = asyncio.create_task(use_failing_cm())
+    await work_started.wait()
+
+    # Wait for the first renewal interval to pass
+    renewal_interval = short_ttl / 3
+    await asyncio.sleep(renewal_interval.total_seconds() + 1.5)
+
+    # Find and delete all holder keys for this semaphore
+    holder_keys = await redis_client_sdk.redis.keys(
+        f"{SEMAPHORE_HOLDER_KEY_PREFIX}{semaphore_name}:*"
+    )
+    assert holder_keys, "Holder keys should exist before deletion"
+    await redis_client_sdk.redis.delete(*holder_keys)
+
+    # wait another renewal interval to ensure the renewal fails
+    await asyncio.sleep(renewal_interval.total_seconds() * 1.5)
+
+    async with asyncio.timeout(renewal_interval.total_seconds()):
+        with pytest.raises(SemaphoreLostError):
+            await task
diff --git a/services/director-v2/src/simcore_service_director_v2/api/dependencies/database.py b/services/director-v2/src/simcore_service_director_v2/api/dependencies/database.py
index 949ef83bbdf6..806c122fcc8d 100644
--- a/services/director-v2/src/simcore_service_director_v2/api/dependencies/database.py
+++ b/services/director-v2/src/simcore_service_director_v2/api/dependencies/database.py
@@ -26,17 +26,11 @@ def get_base_repository(engine: AsyncEngine, repo_type: type[RepoType]) -> RepoT
     # now the current solution is to acquire connection when needed.
 
     # Get pool metrics
-    checkedin = engine.pool.checkedin()  # type: ignore  # connections available in pool
-    checkedout = engine.pool.checkedout()  # type: ignore  # connections in use
+    in_use = engine.pool.checkedout()  # type: ignore  # connections in use
     total_size = engine.pool.size()  # type: ignore  # current total connections
 
-    if (checkedin < 2) and (total_size > 1):  # noqa: PLR2004
-        logger.warning(
-            "Database connection pool near limits: total=%d, in_use=%d, available=%d",
-            total_size,
-            checkedout,
-            checkedin,
-        )
+    if (total_size > 1) and (in_use > (total_size - 2)):
+        logger.warning("Database connection pool near limits: %s", engine.pool.status())
 
     return repo_type(db_engine=engine)
diff --git a/services/director-v2/src/simcore_service_director_v2/core/settings.py b/services/director-v2/src/simcore_service_director_v2/core/settings.py
index 7ec136b65d51..26a7e0d3a514 100644
--- a/services/director-v2/src/simcore_service_director_v2/core/settings.py
+++ b/services/director-v2/src/simcore_service_director_v2/core/settings.py
@@ -59,6 +59,12 @@ class ComputationalBackendSettings(BaseCustomSettings):
         ),
     ] = 50
     COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED: bool = True
+    COMPUTATIONAL_BACKEND_PER_CLUSTER_MAX_DISTRIBUTED_CONCURRENT_CONNECTIONS: Annotated[
+        PositiveInt,
+        Field(
+            description="defines how many concurrent connections to each dask scheduler are allowed across all director-v2 replicas"
+        ),
+    ] = 20
     COMPUTATIONAL_BACKEND_DEFAULT_CLUSTER_URL: Annotated[
         AnyUrl,
         Field(
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
index 193b44eb871d..ea6366ca21db 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
@@ -24,8 +24,13 @@
 from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
 from servicelib.logging_errors import create_troubleshootting_log_kwargs
 from servicelib.logging_utils import log_catch, log_context
+from servicelib.redis._client import RedisClientSDK
+from servicelib.redis._semaphore_decorator import (
+    with_limited_concurrency_cm,
+)
 from servicelib.utils import limited_as_completed, limited_gather
+from ..._meta import APP_NAME
 from ...core.errors import (
     ComputationalBackendNotConnectedError,
     ComputationalBackendOnDemandNotReadyError,
@@ -67,6 +72,40 @@
 _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: Final[str] = "check_time"
 
+
+def _get_redis_client_from_scheduler(
+    _user_id: UserID,
+    scheduler: "DaskScheduler",
+    **kwargs,  # pylint: disable=unused-argument # noqa: ARG001
+) -> RedisClientSDK:
+    return scheduler.redis_client
+
+
+def _get_semaphore_cluster_redis_key(
+    user_id: UserID,
+    *args,  # pylint: disable=unused-argument # noqa: ARG001
+    run_metadata: RunMetadataDict,
+    **kwargs,  # pylint: disable=unused-argument # noqa: ARG001
+) -> str:
+    return f"{APP_NAME}-cluster-user_id_{user_id}-wallet_id_{run_metadata.get('wallet_id')}"
+
+
+def _get_semaphore_capacity_from_scheduler(
+    _user_id: UserID,
+    scheduler: "DaskScheduler",
+    **kwargs,  # pylint: disable=unused-argument # noqa: ARG001
+) -> int:
+    return (
+        scheduler.settings.COMPUTATIONAL_BACKEND_PER_CLUSTER_MAX_DISTRIBUTED_CONCURRENT_CONNECTIONS
+    )
+
+
+@with_limited_concurrency_cm(
+    _get_redis_client_from_scheduler,
+    key=_get_semaphore_cluster_redis_key,
+    capacity=_get_semaphore_capacity_from_scheduler,
+    blocking=True,
+    blocking_timeout=None,
+)
 @asynccontextmanager
 async def _cluster_dask_client(
     user_id: UserID,