diff --git a/packages/service-library/src/servicelib/redis/_semaphore.py b/packages/service-library/src/servicelib/redis/_semaphore.py index 3fab1b459e60..a5ec957275f9 100644 --- a/packages/service-library/src/servicelib/redis/_semaphore.py +++ b/packages/service-library/src/servicelib/redis/_semaphore.py @@ -194,11 +194,15 @@ def validate_timeout( async def _ensure_semaphore_initialized(self) -> None: """Initializes the semaphore in Redis if not already done.""" assert self.register_semaphore is not None # nosec - await self.register_semaphore( # pylint: disable=not-callable + result = await self.register_semaphore( # pylint: disable=not-callable keys=[self.tokens_key, self.holders_set], args=[self.capacity, self.holders_set_ttl.total_seconds()], client=self.redis_client.redis, ) + assert isinstance(result, list) # nosec + exit_code, status = result + assert exit_code == SCRIPT_OK_EXIT_CODE # nosec + _logger.debug("Semaphore '%s' init status: %s", self.key, status) async def _blocking_acquire(self) -> str | None: @retry( diff --git a/packages/service-library/src/servicelib/redis/lua/register_semaphore_tokens.lua b/packages/service-library/src/servicelib/redis/lua/register_semaphore_tokens.lua index c41c53ab9fc7..36b2c769ef25 100644 --- a/packages/service-library/src/servicelib/redis/lua/register_semaphore_tokens.lua +++ b/packages/service-library/src/servicelib/redis/lua/register_semaphore_tokens.lua @@ -14,16 +14,25 @@ local holders_key = KEYS[2] local capacity = tonumber(ARGV[1]) local ttl_seconds = tonumber(ARGV[2]) --- Step 1: Initialize token pool if needed (first time setup) -local tokens_exist = redis.call('EXISTS', tokens_key) -local holders_exist = redis.call('EXISTS', holders_key) -if tokens_exist == 0 and holders_exist == 0 then +-- Use a persistent marker to track if semaphore was ever initialized +local init_marker_key = tokens_key .. ':initialized' + +-- Check if we've ever initialized this semaphore +local was_initialized = redis.call('EXISTS', init_marker_key) + +if was_initialized == 0 then + -- First time initialization - set the permanent marker + redis.call('SET', init_marker_key, '1') + redis.call('EXPIRE', init_marker_key, ttl_seconds) + -- Initialize with capacity number of tokens for i = 1, capacity do redis.call('LPUSH', tokens_key, 'token_' .. i) end - -- Set expiry on tokens list to prevent infinite growth + -- Set expiry on tokens list redis.call('EXPIRE', tokens_key, ttl_seconds) + return {0, 'initialized'} end -return 0 + +return {0, 'already_initialized'} diff --git a/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua b/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua index ef229ef91b0b..35b290b29d92 100644 --- a/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua +++ b/packages/service-library/src/servicelib/redis/lua/renew_semaphore.lua @@ -41,5 +41,7 @@ redis.call('SETEX', holder_key, ttl_seconds, token) -- Step 4: Renew the holders set and tokens list TTLs to prevent infinite growth redis.call('EXPIRE', holders_key, holders_ttl_seconds) redis.call('EXPIRE', tokens_key, tokens_ttl_seconds) +local init_marker_tokens_key = tokens_key .. ':initialized' +redis.call('EXPIRE', init_marker_tokens_key, tokens_ttl_seconds) return {0, 'renewed', redis.call('SCARD', holders_key)} diff --git a/packages/service-library/tests/redis/test_semaphore_decorator.py b/packages/service-library/tests/redis/test_semaphore_decorator.py index 7a8164bb36da..fa004da96fe3 100644 --- a/packages/service-library/tests/redis/test_semaphore_decorator.py +++ b/packages/service-library/tests/redis/test_semaphore_decorator.py @@ -153,7 +153,7 @@ async def test_decorator_with_callable_parameters( ): executed_keys = [] - def get_redis_client(*args, **kwargs): + def get_redis_client(*args, **kwargs) -> RedisClientSDK: return redis_client_sdk def get_key(user_id: str, resource: str) -> str: @@ -196,7 +196,7 @@ async def test_decorator_capacity_enforcement( key=semaphore_name, capacity=2, ) - async def limited_function(): + async def limited_function() -> None: nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count)