diff --git a/packages/celery-library/Makefile b/packages/celery-library/Makefile
index 04596d3d1249..ccce149f8530 100644
--- a/packages/celery-library/Makefile
+++ b/packages/celery-library/Makefile
@@ -27,6 +27,7 @@ tests: ## runs unit tests
		--durations=10 \
		--exitfirst \
		--failed-first \
+		--keep-docker-up \
		--pdb \
		-vv \
		$(CURDIR)/tests
@@ -41,6 +42,7 @@ tests-ci: ## runs unit tests
		--cov-report=term-missing \
		--cov-report=xml \
		--junitxml=junit.xml -o junit_family=legacy \
+		--keep-docker-up \
		--cov=celery_library \
		--durations=10 \
		--log-date-format="%Y-%m-%d %H:%M:%S" \
diff --git a/packages/pytest-simcore/src/pytest_simcore/redis_service.py b/packages/pytest-simcore/src/pytest_simcore/redis_service.py
index 824d61a57fb9..04177c7f9e28 100644
--- a/packages/pytest-simcore/src/pytest_simcore/redis_service.py
+++ b/packages/pytest-simcore/src/pytest_simcore/redis_service.py
@@ -4,7 +4,6 @@
 
 import logging
 from collections.abc import AsyncIterator
-from datetime import timedelta
 
 import pytest
 import tenacity
@@ -116,14 +115,6 @@ async def wait_till_redis_responsive(redis_url: URL | str) -> None:
         await client.aclose(close_connection_pool=True)
 
 
-@pytest.fixture
-def mock_redis_socket_timeout(mocker: MockerFixture) -> None:
-    # lowered to allow CI to properly shutdown RedisClientSDK instances
-    mocker.patch(
-        "servicelib.redis._client.DEFAULT_SOCKET_TIMEOUT", timedelta(seconds=0.25)
-    )
-
-
 @pytest.fixture
 async def use_in_memory_redis(mocker: MockerFixture) -> RedisSettings:
     mocker.patch("redis.asyncio.from_url", FakeAsyncRedis)
diff --git a/packages/service-library/src/servicelib/redis/_client.py b/packages/service-library/src/servicelib/redis/_client.py
index de407a74fe8d..ee4e9a2040e0 100644
--- a/packages/service-library/src/servicelib/redis/_client.py
+++ b/packages/service-library/src/servicelib/redis/_client.py
@@ -20,7 +20,6 @@
     DEFAULT_DECODE_RESPONSES,
     DEFAULT_HEALTH_CHECK_INTERVAL,
     DEFAULT_LOCK_TTL,
-    DEFAULT_SOCKET_TIMEOUT,
 )
 
 _logger = logging.getLogger(__name__)
@@ -49,6 +48,7 @@ class RedisClientSDK:
     _client: aioredis.Redis = field(init=False)
     _task_health_check: Task | None = None
     _started_event_task_health_check: asyncio.Event | None = None
+    _cancelled_event_task_health_check: asyncio.Event | None = None
     _is_healthy: bool = False
 
     @property
@@ -65,20 +65,24 @@ def __post_init__(self) -> None:
                 redis.exceptions.ConnectionError,
             ],
             retry_on_timeout=True,
-            socket_timeout=DEFAULT_SOCKET_TIMEOUT.total_seconds(),
+            socket_timeout=None,  # NOTE: setting a timeout here can lead to issues with long-running commands
             encoding="utf-8",
             decode_responses=self.decode_responses,
             client_name=self.client_name,
         )
         self._is_healthy = False
         self._started_event_task_health_check = asyncio.Event()
+        self._cancelled_event_task_health_check = asyncio.Event()
 
     async def setup(self) -> None:
         @periodic(interval=self.health_check_interval)
         async def _periodic_check_health() -> None:
             assert self._started_event_task_health_check  # nosec
+            assert self._cancelled_event_task_health_check  # nosec
             self._started_event_task_health_check.set()
             self._is_healthy = await self.ping()
+            if self._cancelled_event_task_health_check.is_set():
+                raise asyncio.CancelledError
 
         self._task_health_check = asyncio.create_task(
             _periodic_check_health(),
@@ -100,10 +104,9 @@ async def shutdown(self) -> None:
         if self._task_health_check:
             assert self._started_event_task_health_check  # nosec
             await self._started_event_task_health_check.wait()
-
-            await cancel_wait_task(
-                self._task_health_check, max_delay=_HEALTHCHECK_TIMEOUT_S
-            )
+            assert self._cancelled_event_task_health_check  # nosec
+            self._cancelled_event_task_health_check.set()
+            await cancel_wait_task(self._task_health_check, max_delay=None)
 
         await self._client.aclose(close_connection_pool=True)
diff --git a/packages/service-library/src/servicelib/redis/_constants.py b/packages/service-library/src/servicelib/redis/_constants.py
index e34befa1536b..845e70d7fa8b 100644
--- a/packages/service-library/src/servicelib/redis/_constants.py
+++ b/packages/service-library/src/servicelib/redis/_constants.py
@@ -7,11 +7,12 @@
     seconds=30
 )
 DEFAULT_LOCK_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10)
-DEFAULT_SOCKET_TIMEOUT: Final[datetime.timedelta] = datetime.timedelta(seconds=30)
+DEFAULT_SEMAPHORE_BLOCK_TIMEOUT: Final[datetime.timedelta] = datetime.timedelta(
+    seconds=30
+)
 
 DEFAULT_SEMAPHORE_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10)
 SEMAPHORE_KEY_PREFIX: Final[str] = "semaphores:"
-SEMAPHORE_HOLDER_KEY_PREFIX: Final[str] = "semaphores:holders:"
 
 DEFAULT_DECODE_RESPONSES: Final[bool] = True
 DEFAULT_HEALTH_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5)
diff --git a/packages/service-library/src/servicelib/redis/_errors.py b/packages/service-library/src/servicelib/redis/_errors.py
index e83b40e4ec62..2d09a3730869 100644
--- a/packages/service-library/src/servicelib/redis/_errors.py
+++ b/packages/service-library/src/servicelib/redis/_errors.py
@@ -26,13 +26,23 @@ class LockLostError(BaseRedisError):
 ProjectLockError: TypeAlias = redis.exceptions.LockError  # NOTE: backwards compatible
 
 
-class SemaphoreAcquisitionError(BaseRedisError):
-    msg_template: str = "Could not acquire semaphore '{name}' (capacity: {capacity})"
+class SemaphoreError(BaseRedisError):
+    msg_template: str = (
+        "Unexpected error with semaphore '{name}' by this instance `{instance_id}`"
+    )
+
+
+class SemaphoreAcquisitionError(SemaphoreError):
+    msg_template: str = (
+        "Could not acquire semaphore '{name}' by this instance `{instance_id}`"
+    )
 
 
-class SemaphoreNotAcquiredError(BaseRedisError):
-    msg_template: str = "Semaphore '{name}' was not acquired by this instance"
+class SemaphoreNotAcquiredError(SemaphoreError):
+    msg_template: str = (
+        "Semaphore '{name}' was not acquired by this instance `{instance_id}`"
+    )
 
 
-class SemaphoreLostError(BaseRedisError):
+class SemaphoreLostError(SemaphoreError):
     msg_template: str = "Semaphore '{name}' was lost by this instance `{instance_id}`"
diff --git a/packages/service-library/src/servicelib/redis/_semaphore.py b/packages/service-library/src/servicelib/redis/_semaphore.py
index b62fbc7d238d..3fab1b459e60 100644
--- a/packages/service-library/src/servicelib/redis/_semaphore.py
+++ b/packages/service-library/src/servicelib/redis/_semaphore.py
@@ -1,10 +1,17 @@
+import asyncio
+import contextlib
 import datetime
 import logging
+import socket
 import uuid
-from types import TracebackType
+from collections.abc import AsyncIterator
 from typing import Annotated, ClassVar
 
+import arrow
+import redis.exceptions
+from common_library.async_tools import cancel_wait_task
 from common_library.basic_types import DEFAULT_FACTORY
+from common_library.logging.logging_errors import create_troubleshooting_log_kwargs
 from pydantic import (
     BaseModel,
     Field,
@@ -14,35 +21,33 @@
 )
 from redis.commands.core import AsyncScript
 from tenacity import (
-    RetryError,
-    before_sleep_log,
     retry,
-    retry_if_not_result,
-    stop_after_delay,
-    stop_never,
+    retry_if_exception_type,
     wait_random_exponential,
 )
 
+from ..background_task import periodic
 from ._client import RedisClientSDK
 from ._constants import (
+    DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
     DEFAULT_SEMAPHORE_TTL,
-    DEFAULT_SOCKET_TIMEOUT,
-    SEMAPHORE_HOLDER_KEY_PREFIX,
     SEMAPHORE_KEY_PREFIX,
 )
 from ._errors import (
     SemaphoreAcquisitionError,
+    SemaphoreError,
     SemaphoreLostError,
     SemaphoreNotAcquiredError,
 )
 from ._semaphore_lua import (
     ACQUIRE_SEMAPHORE_SCRIPT,
-    COUNT_SEMAPHORE_SCRIPT,
+    REGISTER_SEMAPHORE_TOKEN_SCRIPT,
     RELEASE_SEMAPHORE_SCRIPT,
     RENEW_SEMAPHORE_SCRIPT,
     SCRIPT_BAD_EXIT_CODE,
     SCRIPT_OK_EXIT_CODE,
 )
+from ._utils import handle_redis_returns_union_types
 
 _logger = logging.getLogger(__name__)
@@ -91,7 +96,7 @@ class DistributedSemaphore(BaseModel):
     blocking_timeout: Annotated[
         datetime.timedelta | None,
         Field(description="Maximum time to wait when blocking"),
-    ] = DEFAULT_SOCKET_TIMEOUT
+    ] = None
     instance_id: Annotated[
         str,
         Field(
@@ -101,11 +106,13 @@
     ] = DEFAULT_FACTORY
 
     # Class and/or Private state attributes (not part of the model)
+    register_semaphore: ClassVar[AsyncScript | None] = None
     acquire_script: ClassVar[AsyncScript | None] = None
-    count_script: ClassVar[AsyncScript | None] = None
     release_script: ClassVar[AsyncScript | None] = None
     renew_script: ClassVar[AsyncScript | None] = None
 
+    _token: str | None = None  # currently held token, if any
+
     @classmethod
     def _register_scripts(cls, redis_client: RedisClientSDK) -> None:
         """Register Lua scripts with Redis if not already done.
@@ -113,12 +120,12 @@ def _register_scripts(cls, redis_client: RedisClientSDK) -> None:
         caches the script SHA, so this is efficient. Even if called
         multiple times, the script is only registered once."""
         if cls.acquire_script is None:
+            cls.register_semaphore = redis_client.redis.register_script(
+                REGISTER_SEMAPHORE_TOKEN_SCRIPT
+            )
             cls.acquire_script = redis_client.redis.register_script(
                 ACQUIRE_SEMAPHORE_SCRIPT
             )
-            cls.count_script = redis_client.redis.register_script(
-                COUNT_SEMAPHORE_SCRIPT
-            )
             cls.release_script = redis_client.redis.register_script(
                 RELEASE_SEMAPHORE_SCRIPT
             )
@@ -134,19 +141,42 @@ def __init__(self, **data) -> None:
     @property
     def semaphore_key(self) -> str:
         """Redis key for the semaphore sorted set."""
-        return f"{SEMAPHORE_KEY_PREFIX}{self.key}"
+        return f"{SEMAPHORE_KEY_PREFIX}{self.key}_cap{self.capacity}"
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def tokens_key(self) -> str:
+        """Redis key for the token pool LIST."""
+        return f"{self.semaphore_key}:tokens"
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def holders_set(self) -> str:
+        """Redis key for the holders SET."""
+        return f"{self.semaphore_key}:holders_set"
 
     @computed_field  # type: ignore[prop-decorator]
     @property
     def holder_key(self) -> str:
         """Redis key for this instance's holder entry."""
-        return f"{SEMAPHORE_HOLDER_KEY_PREFIX}{self.key}:{self.instance_id}"
+        return f"{self.semaphore_key}:holders:{self.instance_id}"
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def holders_set_ttl(self) -> datetime.timedelta:
+        """TTL for the holders SET"""
+        return self.ttl * 5
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def tokens_set_ttl(self) -> datetime.timedelta:
+        """TTL for the tokens LIST"""
+        return self.ttl * 5
 
-    # Additional validation
     @field_validator("ttl")
     @classmethod
     def validate_ttl(cls, v: datetime.timedelta) -> datetime.timedelta:
-        if v.total_seconds() <= 0:
+        if v.total_seconds() < 1:
             msg = "TTL must be positive"
             raise ValueError(msg)
         return v
@@ -161,6 +191,64 @@ def validate_timeout(
             raise ValueError(msg)
         return v
 
+    async def _ensure_semaphore_initialized(self) -> None:
+        """Initializes the semaphore in Redis if not already done."""
+        assert self.register_semaphore is not None  # nosec
+        await self.register_semaphore(  # pylint: disable=not-callable
+            keys=[self.tokens_key, self.holders_set],
+            args=[self.capacity, self.holders_set_ttl.total_seconds()],
+            client=self.redis_client.redis,
+        )
+
+    async def _blocking_acquire(self) -> str | None:
+        @retry(
+            wait=wait_random_exponential(min=0.1, max=0.5),
+            retry=retry_if_exception_type(redis.exceptions.TimeoutError),
+        )
+        async def _acquire_forever_on_socket_timeout() -> list[str] | None:
+            # NOTE: BRPOP returns None on timeout
+
+            tokens_key_token: list[str] | None = await handle_redis_returns_union_types(
+                self.redis_client.redis.brpop(
+                    [self.tokens_key],
+                    timeout=None,  # NOTE: we always block forever since tenacity takes care of timing out
+                )
+            )
+            return tokens_key_token
+
+        try:
+            # NOTE: redis-py raises a TimeoutError whenever its configured socket
+            # timeout triggers, even while BRPOP is still legitimately waiting,
+            # which defeats the purpose of BRPOP's own timeout. We therefore always
+            # block forever on BRPOP, let tenacity retry on socket timeouts, and
+            # use asyncio.timeout to enforce the blocking_timeout if one is defined
+            async with asyncio.timeout(
+                self.blocking_timeout.total_seconds() if self.blocking_timeout else None
+            ):
+                tokens_key_token = await _acquire_forever_on_socket_timeout()
+                assert tokens_key_token is not None  # nosec
+                assert len(tokens_key_token) == 2  # nosec # noqa: PLR2004
+                assert tokens_key_token[0] == self.tokens_key  # nosec
+                return tokens_key_token[1]
+        except TimeoutError as e:
+            raise SemaphoreAcquisitionError(
+                name=self.key, instance_id=self.instance_id
+            ) from e
+
+    async def _non_blocking_acquire(self) -> str | None:
+        token: str | list[str] | None = await handle_redis_returns_union_types(
+            self.redis_client.redis.rpop(self.tokens_key)
+        )
+        if token is None:
+            _logger.debug(
+                "Semaphore '%s' not acquired (no tokens available) (instance: %s)",
+                self.key,
+                self.instance_id,
+            )
+            return None
+
+        assert isinstance(token, str)  # nosec
+        return token
+
     async def acquire(self) -> bool:
         """
         Acquire the semaphore.
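[Editor's note] A minimal, self-contained sketch of the timeout pattern `_blocking_acquire` uses above: the inner call may block indefinitely and is retried on transient socket timeouts, while `asyncio.timeout` enforces the caller's overall deadline. `fake_brpop` and `FakeSocketTimeout` are stand-ins for the real BRPOP call and `redis.exceptions.TimeoutError`; this is not code from the PR.

```python
import asyncio

from tenacity import retry, retry_if_exception_type, wait_random_exponential


class FakeSocketTimeout(Exception):
    """Stand-in for redis.exceptions.TimeoutError."""


@retry(
    wait=wait_random_exponential(min=0.1, max=0.5),
    retry=retry_if_exception_type(FakeSocketTimeout),
)
async def fake_brpop() -> str:
    await asyncio.sleep(3600)  # would block until a token is pushed
    return "token"


async def acquire_with_deadline(deadline_s: float | None) -> str:
    # asyncio.timeout(None) never fires, mirroring blocking_timeout=None
    async with asyncio.timeout(deadline_s):
        return await fake_brpop()
```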
@@ -171,115 +259,102 @@ async def acquire(self) -> bool:
         Raises:
             SemaphoreAcquisitionError: If acquisition fails and blocking=True
         """
+        await self._ensure_semaphore_initialized()
 
-        if not self.blocking:
-            # Non-blocking: try once
-            return await self._try_acquire()
+        if await self.is_acquired():
+            _logger.debug(
+                "Semaphore '%s' already acquired by this instance (instance: %s)",
+                self.key,
+                self.instance_id,
+            )
+            return True
 
-        # Blocking
-        @retry(
-            wait=wait_random_exponential(min=0.1, max=2),
-            reraise=True,
-            stop=(
-                stop_after_delay(self.blocking_timeout.total_seconds())
-                if self.blocking_timeout
-                else stop_never
-            ),
-            retry=retry_if_not_result(lambda acquired: acquired),
-            before_sleep=before_sleep_log(_logger, logging.DEBUG),
+        if self.blocking is False:
+            self._token = await self._non_blocking_acquire()
+            if not self._token:
+                return False
+        else:
+            self._token = await self._blocking_acquire()
+
+        assert self._token is not None  # nosec
+        # set up the semaphore holder with a TTL
+        assert self.acquire_script is not None  # nosec
+        result = await self.acquire_script(  # pylint: disable=not-callable
+            keys=[self.holders_set, self.holder_key],
+            args=[
+                self._token,
+                self.instance_id,
+                self.ttl.total_seconds(),
+                self.holders_set_ttl.total_seconds(),
+            ],
+            client=self.redis_client.redis,
         )
-        async def _blocking_acquire() -> bool:
-            return await self._try_acquire()
 
-        try:
-            return await _blocking_acquire()
-        except RetryError as exc:
-            raise SemaphoreAcquisitionError(
-                name=self.key, capacity=self.capacity
-            ) from exc
+        # Lua script returns: [exit_code, status, token, current_count]
+        assert isinstance(result, list)  # nosec
+        exit_code, status, token, current_count = result
+
+        assert exit_code == SCRIPT_OK_EXIT_CODE  # nosec
+        assert status == "acquired"  # nosec
+
+        _logger.debug(
+            "Acquired semaphore '%s' with token %s (instance: %s, count: %s)",
+            self.key,
+            token,
+            self.instance_id,
+            current_count,
+        )
+        return True
 
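[Editor's note] The acquire/release flow above boils down to the following token-pool pattern. This is an illustrative sketch against a plain redis-py client (local Redis and key names assumed, not the PR's code): BRPOP blocks until a token is free, LPUSH returns it and wakes exactly one waiter.

```python
import asyncio

import redis.asyncio as aioredis


async def demo() -> None:
    client = aioredis.from_url("redis://localhost:6379", decode_responses=True)
    tokens_key = "semaphores:demo_cap2:tokens"

    # one-time registration of capacity=2 tokens (the PR does this atomically in Lua)
    if not await client.exists(tokens_key):
        await client.lpush(tokens_key, "token_1", "token_2")

    _key, token = await client.brpop([tokens_key], timeout=0)  # blocks if pool empty
    try:
        ...  # critical section: at most 2 concurrent holders
    finally:
        await client.lpush(tokens_key, token)  # unblocks one waiting BRPOP

    await client.aclose()


asyncio.run(demo())
```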
     async def release(self) -> None:
         """
-        Release the semaphore atomically using Lua script.
+        Release the semaphore.
 
         Raises:
             SemaphoreNotAcquiredError: If semaphore was not acquired by this instance
         """
-        ttl_seconds = int(self.ttl.total_seconds())
 
         # Execute the release Lua script atomically
-        cls = type(self)
-        assert cls.release_script is not None  # nosec
-        result = await cls.release_script(  # pylint: disable=not-callable
-            keys=(
-                self.semaphore_key,
-                self.holder_key,
-            ),
-            args=(
-                self.instance_id,
-                str(ttl_seconds),
-            ),
-            client=self.redis_client.redis,
-        )
-
-        assert isinstance(result, list)  # nosec
-        exit_code, status, current_count, expired_count = result
-        result = status
-
-        if result == "released":
-            assert exit_code == SCRIPT_OK_EXIT_CODE  # nosec
-            _logger.debug(
-                "Released semaphore '%s' (instance: %s, count: %s, expired: %s)",
-                self.key,
-                self.instance_id,
-                current_count,
-                expired_count,
-            )
-        else:
-            # Instance wasn't in the semaphore set - this shouldn't happen
-            # but let's handle it gracefully
-            assert exit_code == SCRIPT_BAD_EXIT_CODE  # nosec
-            raise SemaphoreNotAcquiredError(name=self.key)
-
-    async def _try_acquire(self) -> bool:
-        ttl_seconds = int(self.ttl.total_seconds())
-
-        # Execute the Lua script atomically
-        cls = type(self)
-        assert cls.acquire_script is not None  # nosec
-        result = await cls.acquire_script(  # pylint: disable=not-callable
-            keys=(self.semaphore_key, self.holder_key),
-            args=(self.instance_id, str(self.capacity), str(ttl_seconds)),
+        assert self.release_script is not None  # nosec
+        release_args = [self.instance_id]
+        if self._token is not None:
+            release_args.append(self._token)
+        result = await self.release_script(  # pylint: disable=not-callable
+            keys=[self.tokens_key, self.holders_set, self.holder_key],
+            args=release_args,
             client=self.redis_client.redis,
         )
+        self._token = None
 
-        # Lua script returns: [exit_code, status, current_count, expired_count]
         assert isinstance(result, list)  # nosec
-        exit_code, status, current_count, expired_count = result
-
+        exit_code, status, current_count = result
         if exit_code == SCRIPT_OK_EXIT_CODE:
+            assert status == "released"  # nosec
             _logger.debug(
-                "Acquired semaphore '%s' (instance: %s, count: %s, expired: %s)",
+                "Released semaphore '%s' (instance: %s, count: %s)",
                 self.key,
                 self.instance_id,
                 current_count,
-                expired_count,
             )
-            return True
+            return
 
-        _logger.debug(
-            "Failed to acquire semaphore '%s' - %s (count: %s, expired: %s)",
+        # Instance was already expired or not acquired
+        assert exit_code == SCRIPT_BAD_EXIT_CODE  # nosec
+        _logger.error(
+            "Failed to release semaphore '%s' - %s (instance: %s, count: %s)",
             self.key,
             status,
+            self.instance_id,
             current_count,
-            expired_count,
         )
-        return False
+        if status == "not_held":
+            raise SemaphoreNotAcquiredError(name=self.key, instance_id=self.instance_id)
+        assert status == "expired"  # nosec
+        raise SemaphoreLostError(name=self.key, instance_id=self.instance_id)
 
     async def reacquire(self) -> None:
         """
-        Atomically renew a semaphore entry using Lua script.
-
+        Re-acquire (renew) the semaphore holder entry.
         This function is intended to be called by decorators or external
         renewal mechanisms.
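[Editor's note] A sketch of the renewal loop that keeps a TTL'd holder key alive (names assumed, not the PR's code). The holder key stores the token with a TTL; renewing every `ttl/3` keeps it alive, and a vanished key means the semaphore was lost. The PR performs the EXISTS+SETEX pair atomically inside `renew_semaphore.lua`; here it is split for readability.

```python
import asyncio

import redis.asyncio as aioredis


async def renew_forever(
    client: aioredis.Redis, holder_key: str, token: str, ttl_s: int
) -> None:
    while True:
        await asyncio.sleep(ttl_s / 3)
        if not await client.exists(holder_key):
            # TTL elapsed: another client may already have recovered the token
            raise RuntimeError(f"semaphore holder {holder_key} lost")
        await client.setex(holder_key, ttl_s, token)  # refresh value and TTL
```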
@@ -287,93 +362,184 @@ async def reacquire(self) -> None:
             SemaphoreLostError: If the semaphore was lost or expired
         """
-        ttl_seconds = int(self.ttl.total_seconds())
+        ttl_seconds = self.ttl.total_seconds()
 
         # Execute the renewal Lua script atomically
-        cls = type(self)
-        assert cls.renew_script is not None  # nosec
-        result = await cls.renew_script(  # pylint: disable=not-callable
-            keys=(self.semaphore_key, self.holder_key),
-            args=(
+        assert self.renew_script is not None  # nosec
+        result = await self.renew_script(  # pylint: disable=not-callable
+            keys=[self.holders_set, self.holder_key, self.tokens_key],
+            args=[
                 self.instance_id,
-                str(ttl_seconds),
-            ),
+                ttl_seconds,
+                self.holders_set_ttl.total_seconds(),
+                self.tokens_set_ttl.total_seconds(),
+            ],
             client=self.redis_client.redis,
         )
 
         assert isinstance(result, list)  # nosec
-        exit_code, status, current_count, expired_count = result
+        exit_code, status, current_count = result
 
-        # Lua script returns: 'renewed' or status message
-        if status == "renewed":
-            assert exit_code == SCRIPT_OK_EXIT_CODE  # nosec
+        if exit_code == SCRIPT_OK_EXIT_CODE:
+            assert status == "renewed"  # nosec
             _logger.debug(
-                "Renewed semaphore '%s' (instance: %s, count: %s, expired: %s)",
+                "Renewed semaphore '%s' (instance: %s, count: %s)",
                 self.key,
                 self.instance_id,
                 current_count,
-                expired_count,
             )
-        else:
-            assert exit_code == SCRIPT_BAD_EXIT_CODE  # nosec
-            if status == "expired":
-                _logger.warning(
-                    "Semaphore '%s' holder key expired (instance: %s, count: %s, expired: %s)",
-                    self.key,
-                    self.instance_id,
-                    current_count,
-                    expired_count,
-                )
-            elif status == "not_held":
-                _logger.warning(
-                    "Semaphore '%s' not held (instance: %s, count: %s, expired: %s)",
-                    self.key,
-                    self.instance_id,
-                    current_count,
-                    expired_count,
-                )
+            return
 
-            raise SemaphoreLostError(name=self.key, instance_id=self.instance_id)
+        assert exit_code == SCRIPT_BAD_EXIT_CODE  # nosec
+        _logger.warning(
+            "Semaphore '%s' holder key was lost (instance: %s, status: %s, count: %s)",
+            self.key,
+            self.instance_id,
+            status,
+            current_count,
+        )
+        if status == "not_held":
+            raise SemaphoreNotAcquiredError(name=self.key, instance_id=self.instance_id)
+        assert status == "expired"  # nosec
+        raise SemaphoreLostError(name=self.key, instance_id=self.instance_id)
+
+    async def is_acquired(self) -> bool:
+        """Check if the semaphore is currently acquired by this instance."""
+        return bool(
+            await handle_redis_returns_union_types(
+                self.redis_client.redis.exists(self.holder_key)
+            )
+            == 1
+        )
 
-    async def get_current_count(self) -> int:
+    async def current_count(self) -> int:
         """Get the current number of semaphore holders"""
-        ttl_seconds = int(self.ttl.total_seconds())
-
-        # Execute the count Lua script atomically
-        cls = type(self)
-        assert cls.count_script is not None  # nosec
-        result = await cls.count_script(  # pylint: disable=not-callable
-            keys=(self.semaphore_key,),
-            args=(str(ttl_seconds),),
-            client=self.redis_client.redis,
+        return await handle_redis_returns_union_types(
+            self.redis_client.redis.scard(self.holders_set)
         )
 
-        assert isinstance(result, list)  # nosec
-        current_count, expired_count = result
-
-        if int(expired_count) > 0:
-            _logger.debug(
-                "Cleaned up %s expired entries from semaphore '%s'",
-                expired_count,
-                self.key,
-            )
-
-        return int(current_count)
+    async def available_tokens(self) -> int:
+        """Get the size of the semaphore (number of available tokens)"""
+        await self._ensure_semaphore_initialized()
+        return await handle_redis_returns_union_types(
+            self.redis_client.redis.llen(self.tokens_key)
+        )
 
-    async def get_available_count(self) -> int:
-        """Get the number of available semaphore slots"""
-        current_count = await self.get_current_count()
-        return max(0, self.capacity - current_count)
 
-    # Context manager support
-    async def __aenter__(self) -> "DistributedSemaphore":
-        await self.acquire()
-        return self
+@contextlib.asynccontextmanager
+async def distributed_semaphore(  # noqa: C901
+    redis_client: RedisClientSDK,
+    *,
+    key: str,
+    capacity: PositiveInt,
+    ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
+    blocking: bool = True,
+    blocking_timeout: datetime.timedelta | None = None,
+    expected_lock_overall_time: datetime.timedelta = DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
+) -> AsyncIterator[DistributedSemaphore]:
+    """
+    Async context manager for DistributedSemaphore.
 
-    async def __aexit__(
-        self,
-        exc_type: type[BaseException] | None,
-        exc_val: BaseException | None,
-        exc_tb: TracebackType | None,
+    Example:
+        async with distributed_semaphore(
+            redis_client, key="my_resource", capacity=3
+        ) as sem:
+            # Only 3 instances can execute this block concurrently
+            await do_limited_work()
+    """
+    semaphore = DistributedSemaphore(
+        redis_client=redis_client,
+        key=key,
+        capacity=capacity,
+        ttl=ttl,
+        blocking=blocking,
+        blocking_timeout=blocking_timeout,
+    )
+
+    @periodic(interval=semaphore.ttl / 3, raise_on_error=True)
+    async def _periodic_reacquisition(
+        semaphore: DistributedSemaphore,
+        started: asyncio.Event,
+        cancellation_event: asyncio.Event,
     ) -> None:
-        await self.release()
+        if cancellation_event.is_set():
+            raise asyncio.CancelledError
+        if not started.is_set():
+            started.set()
+        await semaphore.reacquire()
+
+    lock_acquisition_time = None
+    try:
+        if not await semaphore.acquire():
+            raise SemaphoreAcquisitionError(name=key, instance_id=semaphore.instance_id)
+
+        lock_acquisition_time = arrow.utcnow()
+
+        async with (
+            asyncio.TaskGroup() as tg
+        ):  # NOTE: using a task group ensures proper cancellation propagation to the parent task
+            auto_reacquisition_started = asyncio.Event()
+            cancellation_event = asyncio.Event()
+            auto_reacquisition_task = tg.create_task(
+                _periodic_reacquisition(
+                    semaphore, auto_reacquisition_started, cancellation_event
+                ),
+                name=f"semaphore/auto_reacquisition_task_{semaphore.key}_{semaphore.instance_id}",
+            )
+            await auto_reacquisition_started.wait()
+            try:
+                # NOTE: this try/finally ensures that cancellation_event is set when we exit the context,
+                # even in case of exceptions
+                yield semaphore
+            finally:
+                cancellation_event.set()  # NOTE: this ensures cancellation is effective
+                await cancel_wait_task(auto_reacquisition_task)
+    except BaseExceptionGroup as eg:
+        semaphore_errors, other_errors = eg.split(SemaphoreError)
+        if other_errors:
+            assert len(other_errors.exceptions) == 1  # nosec
+            raise other_errors.exceptions[0] from eg
+        assert semaphore_errors is not None  # nosec
+        assert len(semaphore_errors.exceptions) == 1  # nosec
+        raise semaphore_errors.exceptions[0] from eg
+    finally:
+        try:
+            await semaphore.release()
+        except SemaphoreNotAcquiredError as exc:
+            _logger.exception(
+                **create_troubleshooting_log_kwargs(
+                    f"Unexpected error while releasing semaphore '{semaphore.key}'",
+                    error=exc,
+                    error_context={
+                        "semaphore_key": semaphore.key,
+                        "semaphore_instance_id": semaphore.instance_id,
+                        "hostname": socket.gethostname(),
+                    },
+                    tip="This indicates a logic error in the code using the semaphore",
+                )
+            )
+        except SemaphoreLostError as exc:
+            _logger.exception(
+                **create_troubleshooting_log_kwargs(
+                    f"Unexpected error while releasing semaphore '{semaphore.key}'",
+                    error=exc,
+                    error_context={
+                        "semaphore_key": semaphore.key,
+                        "semaphore_instance_id": semaphore.instance_id,
+                        "hostname": socket.gethostname(),
+                    },
+                    tip="This indicates that the semaphore was lost or expired before release. "
+                    "Look for synchronous code or an overloaded event loop that prevents the renewal task from being scheduled.",
+                )
+            )
+        if lock_acquisition_time is not None:
+            lock_release_time = arrow.utcnow()
+            locking_time = lock_release_time - lock_acquisition_time
+            if locking_time > expected_lock_overall_time:
+                _logger.warning(
+                    "Semaphore '%s' was held for %s by %s which is longer than expected (%s). "
+                    "TIP: consider reducing the locking time by optimizing the code inside "
+                    "the critical section or increasing the default locking time",
+                    semaphore.key,
+                    locking_time,
+                    semaphore.instance_id,
+                    expected_lock_overall_time,
+                )
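[Editor's note] On the `asyncio.TaskGroup` choice used by `distributed_semaphore` above (and motivated in the comments removed from the decorator module below): a failure in a bare `create_task()` goes unnoticed until someone awaits it, while a TaskGroup cancels the body and re-raises promptly. A minimal, runnable illustration (not the PR's code):

```python
import asyncio


async def renewal() -> None:
    await asyncio.sleep(0.1)
    raise RuntimeError("renewal failed")


async def main() -> None:
    try:
        async with asyncio.TaskGroup() as tg:
            tg.create_task(renewal())
            await asyncio.sleep(10)  # cancelled promptly when renewal() raises
    except* RuntimeError:
        print("failure propagated instead of hanging the context manager")


asyncio.run(main())
```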
diff --git a/packages/service-library/src/servicelib/redis/_semaphore_decorator.py b/packages/service-library/src/servicelib/redis/_semaphore_decorator.py
index 529eb33fb22a..72e7fd9d309e 100644
--- a/packages/service-library/src/servicelib/redis/_semaphore_decorator.py
+++ b/packages/service-library/src/servicelib/redis/_semaphore_decorator.py
@@ -1,29 +1,16 @@
-import asyncio
 import datetime
 import functools
 import logging
-import socket
 from collections.abc import AsyncIterator, Callable, Coroutine
 from contextlib import AbstractAsyncContextManager, asynccontextmanager
 from typing import Any, ParamSpec, TypeVar
 
-import arrow
-from common_library.async_tools import cancel_wait_task
-from common_library.logging.logging_errors import create_troubleshooting_log_kwargs
-
-from ..background_task import periodic
 from ._client import RedisClientSDK
 from ._constants import (
     DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
     DEFAULT_SEMAPHORE_TTL,
-    DEFAULT_SOCKET_TIMEOUT,
-)
-from ._errors import (
-    SemaphoreAcquisitionError,
-    SemaphoreLostError,
-    SemaphoreNotAcquiredError,
 )
-from ._semaphore import DistributedSemaphore
+from ._semaphore import distributed_semaphore
 
 _logger = logging.getLogger(__name__)
 
@@ -32,124 +19,6 @@
 R = TypeVar("R")
 
 
-@asynccontextmanager
-async def _managed_semaphore_execution(
-    semaphore: DistributedSemaphore,
-    semaphore_key: str,
-    ttl: datetime.timedelta,
-    execution_context: str,
-    expected_lock_overall_time: datetime.timedelta,
-) -> AsyncIterator:
-    """Common semaphore management logic with auto-renewal."""
-    # Acquire the semaphore first
-    if not await semaphore.acquire():
-        raise SemaphoreAcquisitionError(name=semaphore_key, capacity=semaphore.capacity)
-
-    lock_acquisition_time = arrow.utcnow()
-    try:
-        # NOTE: Use TaskGroup for proper exception propagation, this ensures that in case of error the context manager will be properly exited
-        # and the semaphore released.
-        # If we use create_task() directly, exceptions in the task are not propagated to the parent task
-        # and the context manager may never exit, leading to semaphore leaks.
-        async with asyncio.TaskGroup() as tg:
-            started_event = asyncio.Event()
-
-            # Create auto-renewal task
-            @periodic(interval=ttl / 3, raise_on_error=True)
-            async def _periodic_renewer() -> None:
-                await semaphore.reacquire()
-                if not started_event.is_set():
-                    started_event.set()
-
-            # Start the renewal task
-            renewal_task = tg.create_task(
-                _periodic_renewer(),
-                name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
-            )
-            await started_event.wait()
-
-            yield
-
-            # NOTE: if we do not explicitly await the task inside the context manager
-            # it sometimes hangs forever (Python issue?)
-            await cancel_wait_task(renewal_task, max_delay=None)
-
-    except BaseExceptionGroup as eg:
-        semaphore_lost_errors, other_errors = eg.split(SemaphoreLostError)
-        # If there are any other errors, re-raise them
-        if other_errors:
-            assert len(other_errors.exceptions) == 1  # nosec
-            raise other_errors.exceptions[0] from eg
-
-        assert semaphore_lost_errors is not None  # nosec
-        assert len(semaphore_lost_errors.exceptions) == 1  # nosec
-        raise semaphore_lost_errors.exceptions[0] from eg
-
-    finally:
-        # Always attempt to release the semaphore
-        try:
-            await semaphore.release()
-        except SemaphoreNotAcquiredError as exc:
-            _logger.exception(
-                **create_troubleshooting_log_kwargs(
-                    f"Unexpected error while releasing semaphore '{semaphore_key}'",
-                    error=exc,
-                    error_context={
-                        "semaphore_key": semaphore_key,
-                        "client_name": semaphore.redis_client.client_name,
-                        "hostname": socket.gethostname(),
-                        "execution_context": execution_context,
-                    },
-                    tip="This might happen if the semaphore was lost before releasing it. "
-                    "Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.",
-                )
-            )
-        finally:
-            lock_release_time = arrow.utcnow()
-            locking_time = lock_release_time - lock_acquisition_time
-            if locking_time > expected_lock_overall_time:
-                _logger.warning(
-                    "Semaphore '%s' was held for %s which is longer than expected (%s). "
-                    "TIP: consider reducing the locking time by optimizing the code inside "
-                    "the critical section or increasing the default locking time",
-                    semaphore_key,
-                    locking_time,
-                    expected_lock_overall_time,
-                )
-
-
-def _create_semaphore(
-    redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
-    args: tuple[Any, ...],
-    *,
-    key: str | Callable[..., str],
-    capacity: int | Callable[..., int],
-    ttl: datetime.timedelta,
-    blocking: bool,
-    blocking_timeout: datetime.timedelta | None,
-    kwargs: dict[str, Any],
-) -> tuple[DistributedSemaphore, str]:
-    """Create and configure a distributed semaphore from callable or static parameters."""
-    semaphore_key = key(*args, **kwargs) if callable(key) else key
-    semaphore_capacity = capacity(*args, **kwargs) if callable(capacity) else capacity
-    client = redis_client(*args, **kwargs) if callable(redis_client) else redis_client
-
-    assert isinstance(semaphore_key, str)  # nosec
-    assert isinstance(semaphore_capacity, int)  # nosec
-    assert isinstance(client, RedisClientSDK)  # nosec
-
-    semaphore = DistributedSemaphore(
-        redis_client=client,
-        key=semaphore_key,
-        capacity=semaphore_capacity,
-        ttl=ttl,
-        blocking=blocking,
-        blocking_timeout=blocking_timeout,
-    )
-
-    return semaphore, semaphore_key
-
-
 def with_limited_concurrency(
     redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
     *,
@@ -157,7 +26,7 @@ def with_limited_concurrency(
     capacity: int | Callable[..., int],
     ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
     blocking: bool = True,
-    blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT,
+    blocking_timeout: datetime.timedelta | None = None,
     expected_lock_overall_time: datetime.timedelta = DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
 ) -> Callable[
     [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
@@ -200,23 +69,28 @@ def _decorator(
     ) -> Callable[P, Coroutine[Any, Any, R]]:
         @functools.wraps(coro)
         async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
-            semaphore, semaphore_key = _create_semaphore(
-                redis_client,
-                args,
-                key=key,
-                capacity=capacity,
+            semaphore_key = key(*args, **kwargs) if callable(key) else key
+            semaphore_capacity = (
+                capacity(*args, **kwargs) if callable(capacity) else capacity
+            )
+            client = (
+                redis_client(*args, **kwargs)
+                if callable(redis_client)
+                else redis_client
+            )
+
+            assert isinstance(semaphore_key, str)  # nosec
+            assert isinstance(semaphore_capacity, int)  # nosec
+            assert isinstance(client, RedisClientSDK)  # nosec
+
+            async with distributed_semaphore(
+                redis_client=client,
+                key=semaphore_key,
+                capacity=semaphore_capacity,
                 ttl=ttl,
                 blocking=blocking,
                 blocking_timeout=blocking_timeout,
-                kwargs=kwargs,
-            )
-
-            async with _managed_semaphore_execution(
-                semaphore,
-                semaphore_key,
-                ttl,
-                f"coroutine_{coro.__name__}",
-                expected_lock_overall_time,
+                expected_lock_overall_time=expected_lock_overall_time,
             ):
                 return await coro(*args, **kwargs)
 
@@ -232,7 +106,7 @@ def with_limited_concurrency_cm(
     capacity: int | Callable[..., int],
     ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
     blocking: bool = True,
-    blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT,
+    blocking_timeout: datetime.timedelta | None = None,
     expected_lock_overall_time: datetime.timedelta = DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
 ) -> Callable[
     [Callable[P, AbstractAsyncContextManager[R]]],
@@ -277,24 +151,29 @@ def _decorator(
         @functools.wraps(cm_func)
         @asynccontextmanager
         async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:
-            semaphore, semaphore_key = _create_semaphore(
-                redis_client,
-                args,
-                key=key,
-                capacity=capacity,
-                ttl=ttl,
-                blocking=blocking,
-                blocking_timeout=blocking_timeout,
-                kwargs=kwargs,
+            semaphore_key = key(*args, **kwargs) if callable(key) else key
+            semaphore_capacity = (
+                capacity(*args, **kwargs) if callable(capacity) else capacity
             )
+            client = (
+                redis_client(*args, **kwargs)
+                if callable(redis_client)
+                else redis_client
+            )
+
+            assert isinstance(semaphore_key, str)  # nosec
+            assert isinstance(semaphore_capacity, int)  # nosec
+            assert isinstance(client, RedisClientSDK)  # nosec
 
             async with (
-                _managed_semaphore_execution(
-                    semaphore,
-                    semaphore_key,
-                    ttl,
-                    f"context_manager_{cm_func.__name__}",
-                    expected_lock_overall_time,
+                distributed_semaphore(
+                    redis_client=client,
+                    key=semaphore_key,
+                    capacity=semaphore_capacity,
+                    ttl=ttl,
+                    blocking=blocking,
+                    blocking_timeout=blocking_timeout,
+                    expected_lock_overall_time=expected_lock_overall_time,
                 ),
                 cm_func(*args, **kwargs) as value,
             ):
diff --git a/packages/service-library/src/servicelib/redis/_semaphore_lua.py b/packages/service-library/src/servicelib/redis/_semaphore_lua.py
index 8bf685b30a86..71f29fa88817 100644
--- a/packages/service-library/src/servicelib/redis/_semaphore_lua.py
+++ b/packages/service-library/src/servicelib/redis/_semaphore_lua.py
@@ -26,10 +26,13 @@ def _load_script(script_name: str) -> str:
     return script_file.read_text(encoding="utf-8").strip()
 
 
+# fair semaphore scripts (token pool based)
+REGISTER_SEMAPHORE_TOKEN_SCRIPT: Final[str] = _load_script("register_semaphore_tokens")
 ACQUIRE_SEMAPHORE_SCRIPT: Final[str] = _load_script("acquire_semaphore")
 RELEASE_SEMAPHORE_SCRIPT: Final[str] = _load_script("release_semaphore")
+CLEANUP_SEMAPHORE_SCRIPT: Final[str] = _load_script("cleanup_semaphore")
 RENEW_SEMAPHORE_SCRIPT: Final[str] = _load_script("renew_semaphore")
-COUNT_SEMAPHORE_SCRIPT: Final[str] = _load_script("count_semaphore")
 
-SCRIPT_BAD_EXIT_CODE: Final[int] = 255
+
 SCRIPT_OK_EXIT_CODE: Final[int] = 0
+SCRIPT_BAD_EXIT_CODE: Final[int] = 255
diff --git a/packages/service-library/src/servicelib/redis/lua/acquire_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/acquire_semaphore.lua
index b73608677909..396a2ec34df4 100644
--- a/packages/service-library/src/servicelib/redis/lua/acquire_semaphore.lua
+++ b/packages/service-library/src/servicelib/redis/lua/acquire_semaphore.lua
@@ -1,40 +1,33 @@
--- Atomically acquire a distributed semaphore
--- KEYS[1]: semaphore_key (ZSET storing holders with timestamps)
--- KEYS[2]: holder_key (individual holder TTL key)
--- ARGV[1]: instance_id
--- ARGV[2]: capacity (max concurrent holders)
--- ARGV[3]: ttl_seconds
+-- Fair distributed semaphore using token pool (BRPOP-based)
+-- KEYS[1]: holders_key (SET of current holder instance IDs)
+-- KEYS[2]: holder_key (individual holder TTL key for this instance)
+--
+-- ARGV[1]: token (the token received from BRPOP)
+-- ARGV[2]: instance_id (the instance trying to acquire the semaphore)
+-- ARGV[3]: ttl_seconds (for the holder_key)
+-- ARGV[4]: holders_set_ttl_seconds (to set expiry on holders set)
 --
--- Returns: {exit_code, status, current_count, expired_count}
--- exit_code: 0 if acquired, 255 if failed
--- status: 'acquired' or 'capacity_full'
--- current_count: number of holders after operation
--- expired_count: number of expired entries cleaned up
+-- Returns: {exit_code, status, token, current_count}
+-- exit_code: 0 if acquired
+-- status: 'acquired'
 
-local semaphore_key = KEYS[1]
+local holders_key = KEYS[1]
 local holder_key = KEYS[2]
-local instance_id = ARGV[1]
-local capacity = tonumber(ARGV[2])
+
+local token = ARGV[1]
+local instance_id = ARGV[2]
 local ttl_seconds = tonumber(ARGV[3])
+local holders_set_ttl_seconds = tonumber(ARGV[4])
 
--- Get current Redis server time
-local time_result = redis.call('TIME')
-local current_time = tonumber(time_result[1]) + (tonumber(time_result[2]) / 1000000)
 
--- Step 1: Clean up expired entries
-local expiry_threshold = current_time - ttl_seconds
-local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold)
+-- Step 1: Register as holder
+redis.call('SADD', holders_key, instance_id)
+redis.call('SETEX', holder_key, ttl_seconds, token)
 
--- Step 2: Check current capacity after cleanup
-local current_count = redis.call('ZCARD', semaphore_key)
+-- Step 2: Set expiry on holders set to prevent infinite growth
+redis.call('EXPIRE', holders_key, holders_set_ttl_seconds)
 
--- Step 3: Try to acquire if under capacity
-if current_count < capacity then
-    -- Atomically add to semaphore and set holder key
-    redis.call('ZADD', semaphore_key, current_time, instance_id)
-    redis.call('SETEX', holder_key, ttl_seconds, '1')
+local current_count = redis.call('SCARD', holders_key)
 
-    return {0, 'acquired', current_count + 1, expired_count}
-else
-    return {255, 'capacity_full', current_count, expired_count}
-end
+return {0, 'acquired', token, current_count}
diff --git a/packages/service-library/src/servicelib/redis/lua/cleanup_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/cleanup_semaphore.lua
new file mode 100644
index 000000000000..34a3b87dd28d
--- /dev/null
+++ b/packages/service-library/src/servicelib/redis/lua/cleanup_semaphore.lua
@@ -0,0 +1,56 @@
+-- Cleanup orphaned tokens from crashed clients
+-- KEYS[1]: tokens_key (LIST of available tokens)
+-- KEYS[2]: holders_key (SET of current holders)
+-- KEYS[3]: holder_prefix (prefix for holder keys, e.g. "semaphores:<key>_cap<capacity>:holders:")
+-- ARGV[1]: capacity (total semaphore capacity)
+--
+-- Returns: {recovered_tokens, missing_tokens, excess_tokens}
+-- This script should be run periodically to recover tokens from crashed clients
+
+local tokens_key = KEYS[1]
+local holders_key = KEYS[2]
+local holder_prefix = KEYS[3]
+
+local capacity = tonumber(ARGV[1])
+
+-- Step 1: Get all current holders
+local current_holders = redis.call('SMEMBERS', holders_key)
+local recovered_tokens = 0
+local cleaned_holders = {}
+
+-- Step 2: Check each holder to see if their TTL key still exists
+for i = 1, #current_holders do
+    local holder_id = current_holders[i]
+    local holder_key = holder_prefix .. holder_id
+    local exists = redis.call('EXISTS', holder_key)
+
+    if exists == 0 then
+        -- Holder key doesn't exist but holder is in SET
+        -- This indicates a crashed client - clean up and recover token
+        redis.call('SREM', holders_key, holder_id)
+        redis.call('LPUSH', tokens_key, 'token_recovered_' .. holder_id)
+        recovered_tokens = recovered_tokens + 1
+        table.insert(cleaned_holders, holder_id)
+    end
+end
+
+-- Step 3: Ensure we have the correct total number of tokens
+local remaining_holders = redis.call('SCARD', holders_key)
+local available_tokens_count = redis.call('LLEN', tokens_key)
+local total_tokens = remaining_holders + available_tokens_count
+
+-- If we're missing tokens (due to crashes or Redis issues), add them back
+local missing_tokens = capacity - total_tokens
+for i = 1, missing_tokens do
+    redis.call('LPUSH', tokens_key, 'token_missing_' .. i)
+    recovered_tokens = recovered_tokens + 1
+end
+
+-- If we somehow have too many tokens (shouldn't happen), remove extras
+local excess_tokens = total_tokens - capacity
+for i = 1, excess_tokens do
+    redis.call('RPOP', tokens_key)
+end
+
+
+return {recovered_tokens, missing_tokens, excess_tokens}
diff --git a/packages/service-library/src/servicelib/redis/lua/count_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/count_semaphore.lua
deleted file mode 100644
index a7c023bae89a..000000000000
--- a/packages/service-library/src/servicelib/redis/lua/count_semaphore.lua
+++ /dev/null
@@ -1,23 +0,0 @@
--- Atomically count current semaphore holders (with cleanup)
--- KEYS[1]: semaphore_key (ZSET storing holders with timestamps)
--- ARGV[1]: ttl_seconds
---
--- Returns: {current_count, expired_count}
--- current_count: number of active holders after cleanup
--- expired_count: number of expired entries cleaned up
-
-local semaphore_key = KEYS[1]
-local ttl_seconds = tonumber(ARGV[1])
-
--- Get current Redis server time
-local time_result = redis.call('TIME')
-local current_time = tonumber(time_result[1]) + (tonumber(time_result[2]) / 1000000)
-
--- Step 1: Clean up expired entries
-local expiry_threshold = current_time - ttl_seconds
-local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold)
-
--- Step 2: Count remaining entries
-local current_count = redis.call('ZCARD', semaphore_key)
-
-return {current_count, expired_count}
diff --git a/packages/service-library/src/servicelib/redis/lua/register_semaphore_tokens.lua b/packages/service-library/src/servicelib/redis/lua/register_semaphore_tokens.lua
new file mode 100644
index 000000000000..c41c53ab9fc7
--- /dev/null
+++ b/packages/service-library/src/servicelib/redis/lua/register_semaphore_tokens.lua
@@ -0,0 +1,29 @@
+-- Simple token initialization and management for Python BRPOP
+-- KEYS[1]: tokens_key (LIST of available tokens)
+-- KEYS[2]: holders_key (SET of current holder instance IDs)
+--
+-- ARGV[1]: capacity (max concurrent holders)
+-- ARGV[2]: ttl_seconds
+--
+-- Returns: {exit_code}
+-- exit_code: 0 if registered successfully
+
+local tokens_key = KEYS[1]
+local holders_key = KEYS[2]
+
+local capacity = tonumber(ARGV[1])
+local ttl_seconds = tonumber(ARGV[2])
+
+-- Step 1: Initialize token pool if needed (first time setup)
+local tokens_exist = redis.call('EXISTS', tokens_key)
+local holders_exist = redis.call('EXISTS', holders_key)
+if tokens_exist == 0 and holders_exist == 0 then
+    -- Initialize with capacity number of tokens
+    for i = 1, capacity do
+        redis.call('LPUSH', tokens_key, 'token_' .. i)
+    end
+    -- Set expiry on tokens list to prevent infinite growth
+    redis.call('EXPIRE', tokens_key, ttl_seconds)
+end
+
+return 0
diff --git a/packages/service-library/src/servicelib/redis/lua/release_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/release_semaphore.lua
index a1411060a99a..088662c83842 100644
--- a/packages/service-library/src/servicelib/redis/lua/release_semaphore.lua
+++ b/packages/service-library/src/servicelib/redis/lua/release_semaphore.lua
@@ -1,46 +1,51 @@
--- Atomically release a distributed semaphore
--- KEYS[1]: semaphore_key (ZSET storing holders with timestamps)
--- KEYS[2]: holder_key (individual holder TTL key)
+-- Release fair semaphore and return token to pool
+-- KEYS[1]: tokens_key (LIST of available tokens)
+-- KEYS[2]: holders_key (SET of current holders)
+-- KEYS[3]: holder_key (individual holder TTL key for this instance)
+--
 -- ARGV[1]: instance_id
--- ARGV[2]: ttl_seconds
+-- ARGV[2]: passed_token (the token held by this instance or nil if unknown)
 --
--- Returns: {success, status, current_count, expired_count}
--- exit_code: 0 if released, 255 if failed
--- status: 'released', 'not_held', or 'already_expired'
--- current_count: number of holders after operation
--- expired_count: number of expired entries cleaned up
-
-local semaphore_key = KEYS[1]
-local holder_key = KEYS[2]
-local instance_id = ARGV[1]
-local ttl_seconds = tonumber(ARGV[2])
+-- Returns: {exit_code, status, current_count}
+-- exit_code: 0 if released, 255 if failed
+-- status: 'released', 'not_held', or 'expired'
 
--- Get current Redis server time
-local time_result = redis.call('TIME')
-local current_time = tonumber(time_result[1]) + (tonumber(time_result[2]) / 1000000)
+local tokens_key = KEYS[1]
+local holders_key = KEYS[2]
+local holder_key = KEYS[3]
 
--- Step 1: Clean up expired entries
-local expiry_threshold = current_time - ttl_seconds
-local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold)
+local instance_id = ARGV[1]
+local passed_token = ARGV[2]
 
--- Step 2: Check if this instance currently holds the semaphore
-local score = redis.call('ZSCORE', semaphore_key, instance_id)
+-- Step 1: Check if this instance is currently a holder
+local is_holder = redis.call('SISMEMBER', holders_key, instance_id)
+if is_holder == 0 then
+    -- Not in holders set - check if holder key exists
+    return {255, 'not_held', redis.call('SCARD', holders_key)}
+end
 
-if score == false then
-    -- Instance doesn't hold the semaphore
-    local current_count = redis.call('ZCARD', semaphore_key)
-    return {255, 'not_held', current_count, expired_count}
+-- Step 2: Get the token from holder key
+local token = redis.call('GET', holder_key)
+if not token then
+    -- the token expired but we are still in the holders set
+    -- this indicates a lost semaphore (e.g. due to TTL expiry)
+    -- remove from holders set and return error
+    redis.call('SREM', holders_key, instance_id)
+    -- if the token was passed, return it to the pool
+    if passed_token then
+        redis.call('LPUSH', tokens_key, passed_token)
+    end
+    -- NOTE: otherwise we do NOT push a recovered token since we don't know its state
+    return {255, 'expired', redis.call('SCARD', holders_key)}
 end
 
--- Step 3: Remove the semaphore entry and holder key
-local removed_from_zset = redis.call('ZREM', semaphore_key, instance_id)
-local removed_holder = redis.call('DEL', holder_key)
+-- Step 3: Release the semaphore
+redis.call('SREM', holders_key, instance_id)
+redis.call('DEL', holder_key)
 
-local current_count = redis.call('ZCARD', semaphore_key)
+-- Step 4: Return token to available pool
+-- This automatically unblocks any waiting BRPOP calls
+redis.call('LPUSH', tokens_key, token)
 
-if removed_from_zset == 1 then
-    return {0, 'released', current_count, expired_count}
-else
-    -- This shouldn't happen since we checked ZSCORE above, but handle it
-    return {255, 'already_expired', current_count, expired_count}
-end
+
+return {0, 'released', redis.call('SCARD', holders_key)}
diff --git a/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua
index 231d9d421ea0..ef229ef91b0b 100644
--- a/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua
+++ b/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua
@@ -1,49 +1,45 @@
--- Atomically renew a distributed semaphore holder's TTL
--- KEYS[1]: semaphore_key (ZSET storing holders with timestamps)
--- KEYS[2]: holder_key (individual holder TTL key)
+-- Renew semaphore holder TTL (simplified for token pool design)
+-- KEYS[1]: holders_key (SET of current holders)
+-- KEYS[2]: holder_key (individual holder TTL key for this instance)
+-- KEYS[3]: tokens_key (LIST of available tokens)
 -- ARGV[1]: instance_id
 -- ARGV[2]: ttl_seconds
+-- ARGV[3]: holders_ttl_seconds (to renew holders set)
+-- ARGV[4]: tokens_ttl_seconds (to renew tokens list)
 --
--- Returns: {success, status, current_count, expired_count}
--- exit_code: 0 if renewed, 255 if failed
--- status: 'renewed', 'not_held', or 'expired'
--- current_count: number of holders after operation
--- expired_count: number of expired entries cleaned up
+-- Returns: {exit_code, status, current_count}
+-- exit_code: 0 if renewed, 255 if failed
+-- status: 'renewed', 'not_held', or 'expired'
 
-local semaphore_key = KEYS[1]
+local holders_key = KEYS[1]
 local holder_key = KEYS[2]
+local tokens_key = KEYS[3]
+
 local instance_id = ARGV[1]
 local ttl_seconds = tonumber(ARGV[2])
-
--- Get current Redis server time
-local time_result = redis.call('TIME')
-local current_time = tonumber(time_result[1]) + (tonumber(time_result[2]) / 1000000)
-
--- Step 1: Clean up expired entries
-local expiry_threshold = current_time - ttl_seconds
-local expired_count = redis.call('ZREMRANGEBYSCORE', semaphore_key, '-inf', expiry_threshold)
-
--- Step 2: Check if this instance currently holds the semaphore
-local score = redis.call('ZSCORE', semaphore_key, instance_id)
-
-if score == false then
-    -- Instance doesn't hold the semaphore
-    local current_count = redis.call('ZCARD', semaphore_key)
-    return {255, 'not_held', current_count, expired_count}
+local holders_ttl_seconds = tonumber(ARGV[3])
+local tokens_ttl_seconds = tonumber(ARGV[4])
+
+-- Step 1: Check if this instance is currently a holder
+local is_holder = redis.call('SISMEMBER', holders_key, instance_id)
+if is_holder == 0 then
+    -- Not in holders set
+    return {255, 'not_held', redis.call('SCARD', holders_key)}
 end
 
--- Step 3: Check if the holder key still exists (not expired)
+-- Step 2: Check if holder key exists (to detect if it expired)
 local exists = redis.call('EXISTS', holder_key)
 if exists == 0 then
-    -- Holder key expired, remove from semaphore and fail renewal
-    redis.call('ZREM', semaphore_key, instance_id)
-    local current_count = redis.call('ZCARD', semaphore_key)
-    return {255, 'expired', current_count, expired_count + 1}
+    -- Holder key expired
+    return {255, 'expired', redis.call('SCARD', holders_key)}
 end
 
--- Step 4: Renew both the semaphore entry and holder key
-redis.call('ZADD', semaphore_key, current_time, instance_id)
-redis.call('SETEX', holder_key, ttl_seconds, '1')
+-- Step 3: Renew the holder key TTL
+local token = redis.call('GET', holder_key)
+redis.call('SETEX', holder_key, ttl_seconds, token)
+
+-- Step 4: Renew the holders set and tokens list TTLs to prevent infinite growth
+redis.call('EXPIRE', holders_key, holders_ttl_seconds)
+redis.call('EXPIRE', tokens_key, tokens_ttl_seconds)
 
-local current_count = redis.call('ZCARD', semaphore_key)
-return {0, 'renewed', current_count, expired_count}
+return {0, 'renewed', redis.call('SCARD', holders_key)}
diff --git a/packages/service-library/tests/conftest.py b/packages/service-library/tests/conftest.py
index c4f63a18a1ba..f06739423b20 100644
--- a/packages/service-library/tests/conftest.py
+++ b/packages/service-library/tests/conftest.py
@@ -112,7 +112,7 @@ async def _cleanup_redis_data(clients_manager: RedisClientsManager) -> None:
 
 @pytest.fixture
 async def get_redis_client_sdk(
-    mock_redis_socket_timeout: None, use_in_memory_redis: RedisSettings
+    use_in_memory_redis: RedisSettings,
 ) -> AsyncIterable[
     Callable[[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]]
 ]:
@@ -122,7 +122,7 @@ async def get_redis_client_sdk(
 
 @pytest.fixture
 async def get_in_process_redis_client_sdk(
-    mock_redis_socket_timeout: None, redis_service: RedisSettings
+    redis_service: RedisSettings,
 ) -> AsyncIterable[
     Callable[[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]]
 ]:
diff --git a/packages/service-library/tests/deferred_tasks/test_deferred_tasks.py b/packages/service-library/tests/deferred_tasks/test_deferred_tasks.py
index 7e60c71cb30e..0bb6254542e9 100644
--- a/packages/service-library/tests/deferred_tasks/test_deferred_tasks.py
+++ b/packages/service-library/tests/deferred_tasks/test_deferred_tasks.py
@@ -3,7 +3,6 @@
 
 import asyncio
 import contextlib
-import datetime
 import itertools
 import json
 import random
@@ -19,7 +18,6 @@
 from common_library.json_serialization import json_dumps
 from common_library.serialization import model_dump_with_secrets
 from pydantic import NonNegativeFloat, NonNegativeInt
-from pytest_mock import MockerFixture
 from servicelib.rabbitmq import RabbitMQClient
 from servicelib.redis import RedisClientSDK
 from servicelib.sequences_utils import partition_gen
@@ -385,19 +383,10 @@ async def pause_redis(self) -> AsyncIterator[None]:
         yield
 
 
-@pytest.fixture
-def mock_default_socket_timeout(mocker: MockerFixture) -> None:
-    mocker.patch(
-        "servicelib.redis._client.DEFAULT_SOCKET_TIMEOUT",
-        datetime.timedelta(seconds=0.25),
-    )
-
-
 @pytest.mark.parametrize("max_workers", [10])
 @pytest.mark.parametrize("deferred_tasks_to_start", [100])
 @pytest.mark.parametrize("service", ["rabbit", "redis"])
 async def test_workflow_with_third_party_services_outages(
-    mock_default_socket_timeout: None,
     paused_container: Callable[[str], AbstractAsyncContextManager[None]],
     redis_client_sdk_deferred_tasks: RedisClientSDK,
     rabbit_client: RabbitMQClient,
diff --git a/packages/service-library/tests/redis/test_client.py b/packages/service-library/tests/redis/test_client.py
index 91ff29e5f38e..580c47d0facb 100644
--- a/packages/service-library/tests/redis/test_client.py
+++ b/packages/service-library/tests/redis/test_client.py
@@ -104,9 +104,7 @@ async def test_redis_lock_with_ttl(
     assert not await ttl_lock.locked()
 
 
-async def test_redis_client_sdk_setup_shutdown(
-    mock_redis_socket_timeout: None, redis_service: RedisSettings
-):
+async def test_redis_client_sdk_setup_shutdown(redis_service: RedisSettings):
     # setup
     redis_resources_dns = redis_service.build_redis_dsn(RedisDatabase.RESOURCES)
     client = RedisClientSDK(redis_resources_dns, client_name="pytest")
@@ -130,7 +128,6 @@ async def test_redis_client_sdk_setup_shutdown(
 
 
 async def test_regression_fails_on_redis_service_outage(
-    mock_redis_socket_timeout: None,
     paused_container: Callable[[str], AbstractAsyncContextManager[None]],
     redis_client_sdk: RedisClientSDK,
 ):
diff --git a/packages/service-library/tests/redis/test_clients_manager.py b/packages/service-library/tests/redis/test_clients_manager.py
index eeb110557e33..4bf5bc454f46 100644
--- a/packages/service-library/tests/redis/test_clients_manager.py
+++ b/packages/service-library/tests/redis/test_clients_manager.py
@@ -16,7 +16,6 @@
 
 
 async def test_redis_client_sdks_manager(
-    mock_redis_socket_timeout: None,
     redis_service: RedisSettings,
 ):
     all_redis_configs: set[RedisManagerDBConfig] = {
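[Editor's note] For orientation in the semaphore tests below, this is the Redis key layout the PR derives, spelled out for an assumed `key="jobs"`, `capacity=3` (the `instance_id` is a generated UUID):

```python
# semaphore_key = "semaphores:jobs_cap3"                        # base key
# tokens_key    = "semaphores:jobs_cap3:tokens"                 # LIST of free tokens
# holders_set   = "semaphores:jobs_cap3:holders_set"            # SET of holder ids
# holder_key    = "semaphores:jobs_cap3:holders:<instance_id>"  # per-holder TTL key
```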
async def test_invalid_semaphore_initialization( @@ -85,6 +90,13 @@ async def test_invalid_semaphore_initialization( capacity=1, ttl=datetime.timedelta(seconds=0), ) + with pytest.raises(ValueError, match="TTL must be positive"): + DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=1, + ttl=datetime.timedelta(seconds=0.5), + ) with pytest.raises(ValueError, match="Timeout must be positive"): DistributedSemaphore( redis_client=redis_client_sdk, @@ -96,77 +108,188 @@ async def test_invalid_semaphore_initialization( ) -async def test_semaphore_acquire_release_single( +async def _assert_semaphore_redis_state( + redis_client_sdk: RedisClientSDK, + semaphore: DistributedSemaphore, + *, + expected_count: int, + expected_free_tokens: int, + expected_expired: bool = False, +): + """Helper to assert the internal Redis state of the semaphore""" + holders = await handle_redis_returns_union_types( + redis_client_sdk.redis.smembers(semaphore.holders_set) + ) + assert len(holders) == expected_count + if expected_count > 0: + assert semaphore.instance_id in holders + holder_key_exists = await redis_client_sdk.redis.exists(semaphore.holder_key) + if expected_expired: + assert holder_key_exists == 0 + else: + assert holder_key_exists == 1 + tokens = await handle_redis_returns_union_types( + redis_client_sdk.redis.lrange(semaphore.tokens_key, 0, -1) + ) + assert len(tokens) == expected_free_tokens + + +async def test_semaphore_acquire_release_basic( redis_client_sdk: RedisClientSDK, semaphore_name: str, semaphore_capacity: int, + with_short_default_semaphore_ttl: datetime.timedelta, ): semaphore = DistributedSemaphore( - redis_client=redis_client_sdk, key=semaphore_name, capacity=semaphore_capacity + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=semaphore_capacity, + ttl=with_short_default_semaphore_ttl, ) # Initially not acquired + assert await semaphore.current_count() == 0 + assert await semaphore.available_tokens() == semaphore_capacity + assert await semaphore.is_acquired() is False + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore, + expected_count=0, + expected_free_tokens=semaphore_capacity, + ) - # Acquire successfully + # Acquire result = await semaphore.acquire() assert result is True - - # Check Redis state - assert await semaphore.get_current_count() == 1 - assert await semaphore.get_available_count() == semaphore_capacity - 1 + assert await semaphore.current_count() == 1 + assert await semaphore.available_tokens() == semaphore_capacity - 1 + assert await semaphore.is_acquired() is True + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore, + expected_count=1, + expected_free_tokens=semaphore_capacity - 1, + ) # Acquire again on same instance should return True immediately and keep the same count (reentrant) result = await semaphore.acquire() assert result is True - assert await semaphore.get_current_count() == 1 - assert await semaphore.get_available_count() == semaphore_capacity - 1 + assert await semaphore.current_count() == 1 + assert await semaphore.available_tokens() == semaphore_capacity - 1 + assert await semaphore.is_acquired() is True + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore, + expected_count=1, + expected_free_tokens=semaphore_capacity - 1, + ) # reacquire should just work await semaphore.reacquire() - assert await semaphore.get_current_count() == 1 - assert await semaphore.get_available_count() == semaphore_capacity - 1 + assert await semaphore.current_count() == 
1
+    assert await semaphore.available_tokens() == semaphore_capacity - 1
+    assert await semaphore.is_acquired() is True
+    await _assert_semaphore_redis_state(
+        redis_client_sdk,
+        semaphore,
+        expected_count=1,
+        expected_free_tokens=semaphore_capacity - 1,
+    )

     # Release
     await semaphore.release()
-    assert await semaphore.get_current_count() == 0
-    assert await semaphore.get_available_count() == semaphore_capacity
+    assert await semaphore.current_count() == 0
+    assert await semaphore.available_tokens() == semaphore_capacity
+    assert await semaphore.is_acquired() is False
+    await _assert_semaphore_redis_state(
+        redis_client_sdk,
+        semaphore,
+        expected_count=0,
+        expected_free_tokens=semaphore_capacity,
+    )

     # reacquire after release should fail
-    with pytest.raises(
-        SemaphoreLostError,
-        match=f"Semaphore '{semaphore_name}' was lost by this instance",
-    ):
+    with pytest.raises(SemaphoreNotAcquiredError):
         await semaphore.reacquire()
+    await _assert_semaphore_redis_state(
+        redis_client_sdk,
+        semaphore,
+        expected_count=0,
+        expected_free_tokens=semaphore_capacity,
+    )

-
-async def test_semaphore_context_manager(
-    redis_client_sdk: RedisClientSDK,
-    semaphore_name: str,
-    semaphore_capacity: int,
-):
-    async with DistributedSemaphore(
-        redis_client=redis_client_sdk, key=semaphore_name, capacity=semaphore_capacity
-    ) as semaphore:
-        assert await semaphore.get_current_count() == 1
-
-    # Should be released after context
-    assert await semaphore.get_current_count() == 0
+    # so does releasing again
+    with pytest.raises(SemaphoreNotAcquiredError):
+        await semaphore.release()
+    await _assert_semaphore_redis_state(
+        redis_client_sdk,
+        semaphore,
+        expected_count=0,
+        expected_free_tokens=semaphore_capacity,
+    )

-async def test_semaphore_release_without_acquire_raises(
+async def test_semaphore_acquire_release_with_ttl_expiry(
     redis_client_sdk: RedisClientSDK,
     semaphore_name: str,
     semaphore_capacity: int,
+    with_short_default_semaphore_ttl: datetime.timedelta,
 ):
     semaphore = DistributedSemaphore(
-        redis_client=redis_client_sdk, key=semaphore_name, capacity=semaphore_capacity
+        redis_client=redis_client_sdk,
+        key=semaphore_name,
+        capacity=semaphore_capacity,
+        ttl=with_short_default_semaphore_ttl,
+    )
+    await semaphore.acquire()
+    assert await semaphore.current_count() == 1
+    assert await semaphore.available_tokens() == semaphore_capacity - 1
+    await _assert_semaphore_redis_state(
+        redis_client_sdk,
+        semaphore,
+        expected_count=1,
+        expected_free_tokens=semaphore_capacity - 1,
     )
-    with pytest.raises(
-        SemaphoreNotAcquiredError,
-        match=f"Semaphore '{semaphore_name}' was not acquired by this instance",
-    ):
+    # wait for TTL to expire
+    await asyncio.sleep(with_short_default_semaphore_ttl.total_seconds() + 0.1)
+    await _assert_semaphore_redis_state(
+        redis_client_sdk,
+        semaphore,
+        expected_count=1,
+        expected_free_tokens=semaphore_capacity - 1,
+        expected_expired=True,
+    )
+
+    # TTL expired, reacquire should fail
+    with pytest.raises(SemaphoreLostError):
+        await semaphore.reacquire()
+    await _assert_semaphore_redis_state(
+        redis_client_sdk,
+        semaphore,
+        expected_count=1,
+        expected_free_tokens=semaphore_capacity - 1,
+        expected_expired=True,
+    )
+    # and release should also fail
+    with pytest.raises(SemaphoreLostError):
+        await semaphore.release()
+    await _assert_semaphore_redis_state(
+        redis_client_sdk,
+        semaphore,
+        expected_count=0,
+        expected_free_tokens=semaphore_capacity,
+    )
+
+    # and releasing again should fail with a different error
+    with
pytest.raises(SemaphoreNotAcquiredError): await semaphore.release() + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore, + expected_count=0, + expected_free_tokens=semaphore_capacity, + ) async def test_semaphore_multiple_instances_capacity_limit( @@ -183,148 +306,295 @@ async def test_semaphore_multiple_instances_capacity_limit( # Acquire first two should succeed assert await semaphores[0].acquire() is True + assert await semaphores[0].is_acquired() is True + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphores[0], + expected_count=1, + expected_free_tokens=capacity - 1, + ) + assert await semaphores[1].is_acquired() is False + for sem in semaphores[:4]: + assert await sem.current_count() == 1 + assert await sem.available_tokens() == capacity - 1 + + # acquire second assert await semaphores[1].acquire() is True + for sem in semaphores[:2]: + assert await sem.is_acquired() is True + assert await sem.current_count() == 2 + assert await sem.available_tokens() == capacity - 2 + await _assert_semaphore_redis_state( + redis_client_sdk, + sem, + expected_count=2, + expected_free_tokens=capacity - 2, + ) # Third and fourth should fail in non-blocking mode - for semaphore in semaphores[2:]: - semaphore.blocking = False - assert await semaphore.acquire() is False - - # Check counts - assert await semaphores[0].get_current_count() == 2 - assert await semaphores[0].get_available_count() == 0 + for sem in semaphores[2:]: + sem.blocking = False + assert await sem.acquire() is False + assert await sem.is_acquired() is False + assert await sem.current_count() == 2 + assert await sem.available_tokens() == capacity - 2 # Release one await semaphores[0].release() - assert await semaphores[0].get_current_count() == 1 - assert await semaphores[0].get_available_count() == 1 + assert await semaphores[0].is_acquired() is False + for sem in semaphores[:4]: + assert await sem.current_count() == 1 + assert await sem.available_tokens() == capacity - 1 # Now third can acquire assert await semaphores[2].acquire() is True + for sem in semaphores[:4]: + assert await sem.current_count() == 2 + assert await sem.available_tokens() == capacity - 2 # Clean up await semaphores[1].release() await semaphores[2].release() -async def test_semaphore_blocking_timeout( +async def test_semaphore_with_timeout( redis_client_sdk: RedisClientSDK, semaphore_name: str, ): - capacity = 1 - timeout = datetime.timedelta(seconds=0.1) - - # First semaphore acquires - async with DistributedSemaphore( - redis_client=redis_client_sdk, key=semaphore_name, capacity=capacity - ): - # Second semaphore should timeout - semaphore2 = DistributedSemaphore( - redis_client=redis_client_sdk, - key=semaphore_name, - capacity=capacity, - blocking_timeout=timeout, - ) - - with pytest.raises( - SemaphoreAcquisitionError, - match=f"Could not acquire semaphore '{semaphore_name}' \\(capacity: {capacity}\\)", - ): - await semaphore2.acquire() + timeout = datetime.timedelta(seconds=1) + semaphore1 = DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=1, + blocking_timeout=timeout, + ) + assert await semaphore1.acquire() is True + assert await semaphore1.is_acquired() is True + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore1, + expected_count=1, + expected_free_tokens=0, + ) + semaphore2 = DistributedSemaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=1, + blocking_timeout=timeout, + ) + # Second should timeout + with 
pytest.raises(SemaphoreAcquisitionError): + await semaphore2.acquire() + assert await semaphore2.is_acquired() is False + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore1, + expected_count=1, + expected_free_tokens=0, + ) -async def test_semaphore_blocking_acquire_waits( +async def test_semaphore_context_manager( redis_client_sdk: RedisClientSDK, semaphore_name: str, ): - capacity = 1 - semaphore1 = DistributedSemaphore( - redis_client=redis_client_sdk, key=semaphore_name, capacity=capacity - ) - semaphore2 = DistributedSemaphore( - redis_client=redis_client_sdk, key=semaphore_name, capacity=capacity + async with distributed_semaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=1, + ) as semaphore1: + assert await semaphore1.is_acquired() is True + assert await semaphore1.current_count() == 1 + assert await semaphore1.available_tokens() == 0 + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore1, + expected_count=1, + expected_free_tokens=0, + ) + assert await semaphore1.is_acquired() is False + assert await semaphore1.current_count() == 0 + assert await semaphore1.available_tokens() == 1 + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore1, + expected_count=0, + expected_free_tokens=1, ) - # First acquires immediately - await semaphore1.acquire() - # Second will wait - async def delayed_release() -> None: - await asyncio.sleep(0.1) - await semaphore1.release() - - acquire_task = asyncio.create_task(semaphore2.acquire()) - release_task = asyncio.create_task(delayed_release()) +async def test_semaphore_context_manager_with_timeout( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, +): + capacity = 1 + timeout = datetime.timedelta(seconds=0.1) - # Both should complete successfully - results = await asyncio.gather(acquire_task, release_task) - assert results[0] is True # acquire succeeded + # First semaphore acquires + async with distributed_semaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=capacity, + ) as semaphore1: + assert await semaphore1.is_acquired() is True + assert await semaphore1.current_count() == 1 + assert await semaphore1.available_tokens() == 0 + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore1, + expected_count=1, + expected_free_tokens=0, + ) + # Second semaphore should raise on timeout + with pytest.raises(SemaphoreAcquisitionError): + async with distributed_semaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=capacity, + blocking=True, + blocking_timeout=timeout, + ): + ... + + # non-blocking should also raise when used with context manager + with pytest.raises(SemaphoreAcquisitionError): + async with distributed_semaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=capacity, + blocking=False, + ): + ... 
+    # using the semaphore directly in non-blocking mode should return False
+    semaphore2 = DistributedSemaphore(
+        redis_client=redis_client_sdk,
+        key=semaphore_name,
+        capacity=capacity,
+        blocking=False,
+    )
+    assert await semaphore2.acquire() is False

-    await semaphore2.release()
+    # now try infinite timeout
+    semaphore3 = DistributedSemaphore(
+        redis_client=redis_client_sdk,
+        key=semaphore_name,
+        capacity=capacity,
+        blocking_timeout=None,  # wait forever
+    )
+    acquire_task = asyncio.create_task(semaphore3.acquire())
+    await asyncio.sleep(5)  # give some time to start acquiring
+    assert not acquire_task.done()
+@pytest.mark.parametrize(
+    "exception",
+    [RuntimeError, asyncio.CancelledError],
+    ids=str,
+)
 async def test_semaphore_context_manager_with_exception(
     redis_client_sdk: RedisClientSDK,
     semaphore_name: str,
     semaphore_capacity: int,
+    exception: type[Exception | asyncio.CancelledError],
 ):
-    captured_semaphore: DistributedSemaphore | None = None
-
     async def _raising_context():
-        async with DistributedSemaphore(
+        async with distributed_semaphore(
             redis_client=redis_client_sdk,
             key=semaphore_name,
             capacity=semaphore_capacity,
-        ) as sem:
-            nonlocal captured_semaphore
-            captured_semaphore = sem
-            msg = "Test exception"
-            raise RuntimeError(msg)
+        ):
+            raise exception("Test")

-    with pytest.raises(RuntimeError, match="Test exception"):
+    with pytest.raises(exception, match="Test"):
         await _raising_context()

-    # Should be released even after exception
-    assert captured_semaphore is not None
-    # captured_semaphore is guaranteed to be not None by the assert above
-    assert await captured_semaphore.get_current_count() == 0
-
-async def test_semaphore_ttl_cleanup(
+async def test_semaphore_context_manager_lost_renewal(
     redis_client_sdk: RedisClientSDK,
     semaphore_name: str,
-    semaphore_capacity: int,
-    short_ttl: datetime.timedelta,
+    with_short_default_semaphore_ttl: datetime.timedelta,
 ):
-    # Create semaphore with explicit short TTL
-    semaphore = DistributedSemaphore(
+    with pytest.raises(SemaphoreLostError):  # noqa: PT012
+        async with distributed_semaphore(
+            redis_client=redis_client_sdk,
+            key=semaphore_name,
+            capacity=1,
+            ttl=with_short_default_semaphore_ttl,
+        ) as semaphore:
+            assert await semaphore.is_acquired() is True
+            assert await semaphore.current_count() == 1
+            assert await semaphore.available_tokens() == 0
+            await _assert_semaphore_redis_state(
+                redis_client_sdk,
+                semaphore,
+                expected_count=1,
+                expected_free_tokens=0,
+            )
+
+            # now simulate lost renewal by deleting the holder key
+            await redis_client_sdk.redis.delete(semaphore.holder_key)
+            # wait a bit to let the auto-renewal task detect the lost lock
+            # the sleep will be interrupted by the exception and the context manager will exit
+            with pytest.raises(asyncio.CancelledError):
+                await asyncio.sleep(
+                    with_short_default_semaphore_ttl.total_seconds() + 0.5
+                )
+            raise asyncio.CancelledError
+
+
+async def test_semaphore_context_manager_auto_renewal(
+    redis_client_sdk: RedisClientSDK,
+    semaphore_name: str,
+    with_short_default_semaphore_ttl: datetime.timedelta,
+):
+    async with distributed_semaphore(
         redis_client=redis_client_sdk,
         key=semaphore_name,
-        capacity=semaphore_capacity,
-        ttl=short_ttl,
-    )
-
-    # Manually add an expired entry
-    expired_instance_id = "expired-instance"
-    current_time = asyncio.get_event_loop().time()
-    # Make sure it's definitely expired by using the short TTL
-    expired_time = current_time - short_ttl.total_seconds() - 1
-
-    await redis_client_sdk.redis.zadd(
-
semaphore.semaphore_key, {expired_instance_id: expired_time} - ) + capacity=1, + ttl=with_short_default_semaphore_ttl, + ) as semaphore: + assert await semaphore.is_acquired() is True + assert await semaphore.current_count() == 1 + assert await semaphore.available_tokens() == 0 + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore, + expected_count=1, + expected_free_tokens=0, + ) - # Verify the entry was added - initial_count = await redis_client_sdk.redis.zcard(semaphore.semaphore_key) - assert initial_count == 1 + # wait for a few TTLs to ensure auto-renewal is working + total_wait = with_short_default_semaphore_ttl.total_seconds() * 3 + await asyncio.sleep(total_wait) + + # should still be acquired + assert await semaphore.is_acquired() is True + assert await semaphore.current_count() == 1 + assert await semaphore.available_tokens() == 0 + await _assert_semaphore_redis_state( + redis_client_sdk, + semaphore, + expected_count=1, + expected_free_tokens=0, + ) - # Current count should clean up expired entries - count = await semaphore.get_current_count() - assert count == 0 - # Verify expired entry was removed - remaining = await redis_client_sdk.redis.zcard(semaphore.semaphore_key) - assert remaining == 0 +async def test_semaphore_context_manager_logs_warning_when_hold_too_long( + redis_client_sdk: RedisClientSDK, + semaphore_name: str, + caplog: pytest.LogCaptureFixture, +): + """Test that a warning is logged when holding the semaphore for too long""" + with caplog.at_level(logging.WARNING): + async with distributed_semaphore( + redis_client=redis_client_sdk, + key=semaphore_name, + capacity=1, + expected_lock_overall_time=datetime.timedelta(milliseconds=200), + ): + await asyncio.sleep(0.3) + assert caplog.records + assert "longer than expected" in caplog.messages[-1] async def test_multiple_semaphores_different_keys( @@ -337,10 +607,10 @@ async def test_multiple_semaphores_different_keys( capacity = 1 async with ( - DistributedSemaphore( + distributed_semaphore( redis_client=redis_client_sdk, key=key1, capacity=capacity ), - DistributedSemaphore( + distributed_semaphore( redis_client=redis_client_sdk, key=key2, capacity=capacity ), ): diff --git a/packages/service-library/tests/redis/test_semaphore_decorator.py b/packages/service-library/tests/redis/test_semaphore_decorator.py index 21d8425caf43..7a8164bb36da 100644 --- a/packages/service-library/tests/redis/test_semaphore_decorator.py +++ b/packages/service-library/tests/redis/test_semaphore_decorator.py @@ -15,9 +15,7 @@ from pytest_mock import MockerFixture from pytest_simcore.helpers.logging_tools import log_context from servicelib.redis import RedisClientSDK -from servicelib.redis._constants import ( - SEMAPHORE_HOLDER_KEY_PREFIX, -) +from servicelib.redis._constants import SEMAPHORE_KEY_PREFIX from servicelib.redis._errors import SemaphoreLostError from servicelib.redis._semaphore import ( DistributedSemaphore, @@ -94,8 +92,8 @@ async def long_running_work() -> Literal["success"]: capacity=semaphore_capacity, ttl=short_ttl, ) - assert await temp_semaphore.get_current_count() == 1 - assert await temp_semaphore.get_available_count() == semaphore_capacity - 1 + assert await temp_semaphore.current_count() == 1 + assert await temp_semaphore.available_tokens() == semaphore_capacity - 1 # Wait for work to complete result = await task @@ -103,8 +101,8 @@ async def long_running_work() -> Literal["success"]: assert work_completed.is_set() # After completion, semaphore should be released - assert await 
temp_semaphore.get_current_count() == 0 - assert await temp_semaphore.get_available_count() == semaphore_capacity + assert await temp_semaphore.current_count() == 0 + assert await temp_semaphore.available_tokens() == semaphore_capacity async def test_auto_renewal_lose_semaphore_raises( @@ -136,7 +134,7 @@ async def coro_that_should_fail() -> Literal["should not reach here"]: # Find and delete all holder keys for this semaphore holder_keys = await redis_client_sdk.redis.keys( - f"{SEMAPHORE_HOLDER_KEY_PREFIX}{semaphore_name}:*" + f"{SEMAPHORE_KEY_PREFIX}{semaphore_name}_cap{semaphore_capacity}:holders:*" ) assert holder_keys, "Holder keys should exist before deletion" await redis_client_sdk.redis.delete(*holder_keys) @@ -337,7 +335,7 @@ async def failing_function(): ttl=short_ttl, ) assert ( - await test_semaphore.get_current_count() == 0 + await test_semaphore.current_count() == 0 ), "Semaphore should be released after exception" @@ -366,7 +364,7 @@ async def test_with_large_capacity( redis_client_sdk: RedisClientSDK, semaphore_name: str, ): - large_capacity = 20 + large_capacity = 100 concurrent_count = 0 max_concurrent = 0 sleep_time_s = 10 @@ -379,16 +377,18 @@ async def test_with_large_capacity( blocking=True, blocking_timeout=None, ) - async def limited_function() -> None: + async def limited_function(task_id: int) -> None: nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count) - with log_context(logging.INFO, f"task with {concurrent_count=}"): + with log_context(logging.INFO, f"{task_id=}") as ctx: + ctx.logger.info("started %s with %s", task_id, concurrent_count) await asyncio.sleep(sleep_time_s) + ctx.logger.info("done %s with %s", task_id, concurrent_count) concurrent_count -= 1 # Start tasks equal to the large capacity - tasks = [asyncio.create_task(limited_function()) for _ in range(num_tasks)] + tasks = [asyncio.create_task(limited_function(i)) for i in range(num_tasks)] done, pending = await asyncio.wait( tasks, timeout=float(num_tasks) / float(large_capacity) * 10.0 * float(sleep_time_s), @@ -404,7 +404,6 @@ async def test_long_locking_logs_warning( redis_client_sdk: RedisClientSDK, semaphore_name: str, caplog: pytest.LogCaptureFixture, - mocker: MockerFixture, ): @with_limited_concurrency( redis_client_sdk, @@ -424,7 +423,6 @@ async def limited_function() -> None: assert "longer than expected" in caplog.messages[-1] -@pytest.mark.skip async def test_semaphore_fair_queuing( redis_client_sdk: RedisClientSDK, semaphore_name: str, @@ -438,7 +436,7 @@ async def test_semaphore_fair_queuing( ) async def limited_function(call_id: int): entered_order.append(call_id) - await asyncio.sleep(0.1) + await asyncio.sleep(0.2) return call_id # Launch tasks in a specific order @@ -446,7 +444,7 @@ async def limited_function(call_id: int): tasks = [] for i in range(num_tasks): tasks.append(asyncio.create_task(limited_function(i))) - await asyncio.sleep(0.01) # Small delay to help preserve order + await asyncio.sleep(0.1) # Small delay to help preserve order results = await asyncio.gather(*tasks) # All should complete successfully and in order @@ -460,37 +458,6 @@ async def limited_function(call_id: int): async def test_context_manager_basic_functionality( redis_client_sdk: RedisClientSDK, semaphore_name: str, -): - call_count = 0 - - @with_limited_concurrency_cm( - redis_client_sdk, - key=semaphore_name, - capacity=1, - ) - @asynccontextmanager - async def limited_context_manager(): - nonlocal call_count - call_count += 1 - yield 
call_count - - # Multiple concurrent context managers - async def use_context_manager() -> int: - async with limited_context_manager() as value: - await asyncio.sleep(0.1) - return value - - tasks = [asyncio.create_task(use_context_manager()) for _ in range(3)] - results = await asyncio.gather(*tasks) - - # All should complete successfully - assert len(results) == 3 - assert all(isinstance(r, int) for r in results) - - -async def test_context_manager_capacity_enforcement( - redis_client_sdk: RedisClientSDK, - semaphore_name: str, ): concurrent_count = 0 max_concurrent = 0 @@ -512,13 +479,17 @@ async def limited_context_manager(): finally: concurrent_count -= 1 - async def use_context_manager() -> None: + async def use_context_manager() -> int: async with limited_context_manager(): await asyncio.sleep(0.1) + return 1 # Start concurrent context managers tasks = [asyncio.create_task(use_context_manager()) for _ in range(20)] - await asyncio.gather(*tasks) + results = await asyncio.gather(*tasks) + # All should complete successfully + assert len(results) == 20 + assert all(isinstance(r, int) for r in results) # Should never exceed capacity of 2 assert max_concurrent <= 2 @@ -596,16 +567,16 @@ async def use_long_running_cm(): capacity=semaphore_capacity, ttl=short_ttl, ) - assert await temp_semaphore.get_current_count() == 1 - assert await temp_semaphore.get_available_count() == semaphore_capacity - 1 + assert await temp_semaphore.current_count() == 1 + assert await temp_semaphore.available_tokens() == semaphore_capacity - 1 # Wait for work to complete await task assert work_completed.is_set() # After completion, semaphore should be released - assert await temp_semaphore.get_current_count() == 0 - assert await temp_semaphore.get_available_count() == semaphore_capacity + assert await temp_semaphore.current_count() == 0 + assert await temp_semaphore.available_tokens() == semaphore_capacity async def test_context_manager_with_callable_parameters( @@ -736,7 +707,7 @@ async def use_failing_cm() -> None: # Find and delete all holder keys for this semaphore holder_keys = await redis_client_sdk.redis.keys( - f"{SEMAPHORE_HOLDER_KEY_PREFIX}{semaphore_name}:*" + f"{SEMAPHORE_KEY_PREFIX}{semaphore_name}_cap{semaphore_capacity}:holders:*" ) assert holder_keys, "Holder keys should exist before deletion" await redis_client_sdk.redis.delete(*holder_keys) diff --git a/services/web/server/tests/integration/01/notifications/test_rabbitmq_consumers.py b/services/web/server/tests/integration/01/notifications/test_rabbitmq_consumers.py index 99ef45fdd1d1..d29c362f8a77 100644 --- a/services/web/server/tests/integration/01/notifications/test_rabbitmq_consumers.py +++ b/services/web/server/tests/integration/01/notifications/test_rabbitmq_consumers.py @@ -132,7 +132,6 @@ async def _assert_handler_called_with_json( @pytest.fixture async def client( docker_registry: str, - mock_redis_socket_timeout: None, aiohttp_client: Callable, app_config: dict[str, Any], rabbit_service: RabbitSettings,
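Taken together, the test changes above exercise the reworked public surface: the distributed_semaphore async context manager (with auto-renewal and fail-fast on lost renewal) replaces direct use of DistributedSemaphore as a context manager, and get_current_count/get_available_count become current_count/available_tokens. A minimal usage sketch (not part of the diff); the import and every keyword argument appear verbatim in the test call sites, while the key name, capacity, and durations here are illustrative:

import datetime

from servicelib.redis import RedisClientSDK
from servicelib.redis._semaphore import distributed_semaphore


async def bounded_work(client: RedisClientSDK) -> None:
    # At most `capacity` concurrent holders across all instances; the holder
    # key is auto-renewed until the block exits, and the block is aborted
    # with SemaphoreLostError if renewal is lost, as the tests above verify.
    async with distributed_semaphore(
        redis_client=client,
        key="my-resource",  # illustrative name
        capacity=2,
        blocking=True,
        blocking_timeout=datetime.timedelta(seconds=30),  # None = wait forever
        expected_lock_overall_time=datetime.timedelta(minutes=1),  # longer holds log a warning
    ):
        ...  # do the protected work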