55from asyncio import Task
66from collections .abc import AsyncIterator
77from dataclasses import dataclass , field
8- from typing import Final
98from uuid import uuid4
109
1110import redis .asyncio as aioredis
1211import redis .exceptions
13- from common_library .errors_classes import OsparcErrorMixin
14- from pydantic import NonNegativeFloat , NonNegativeInt
12+ from pydantic import NonNegativeFloat
1513from redis .asyncio .lock import Lock
1614from redis .asyncio .retry import Retry
1715from redis .backoff import ExponentialBackoff
18- from settings_library .redis import RedisDatabase , RedisSettings
1916from tenacity import retry
2017from yarl import URL
2118
22- from .background_task import periodic_task
23- from .logging_utils import log_catch , log_context
24- from .retry_policies import RedisRetryPolicyUponInitialization
25-
26- _DEFAULT_LOCK_TTL : Final [datetime .timedelta ] = datetime .timedelta (seconds = 10 )
27- _DEFAULT_SOCKET_TIMEOUT : Final [datetime .timedelta ] = datetime .timedelta (seconds = 30 )
28-
29-
30- _DEFAULT_DECODE_RESPONSES : Final [bool ] = True
31- _DEFAULT_HEALTH_CHECK_INTERVAL : Final [datetime .timedelta ] = datetime .timedelta (
32- seconds = 5
19+ from ..background_task import periodic_task
20+ from ..logging_utils import log_catch
21+ from ..retry_policies import RedisRetryPolicyUponInitialization
22+ from ._constants import (
23+ DEFAULT_DECODE_RESPONSES ,
24+ DEFAULT_HEALTH_CHECK_INTERVAL ,
25+ DEFAULT_LOCK_TTL ,
26+ DEFAULT_SOCKET_TIMEOUT ,
3327)
34- _SHUTDOWN_TIMEOUT_S : Final [ NonNegativeInt ] = 5
35-
28+ from . _errors import CouldNotAcquireLockError , CouldNotConnectToRedisError
29+ from . _utils import auto_extend_lock , cancel_or_warn
3630
3731_logger = logging .getLogger (__name__ )
3832
3933
40- class BaseRedisError (OsparcErrorMixin , RuntimeError ):
41- ...
42-
43-
44- class CouldNotAcquireLockError (BaseRedisError ):
45- msg_template : str = "Lock {lock.name} could not be acquired!"
46-
47-
48- class CouldNotConnectToRedisError (BaseRedisError ):
49- msg_template : str = "Connection to '{dsn}' failed"
50-
51-
52- class LockLostError (BaseRedisError ):
53- msg_template : str = "Lock {lock.name} has been lost"
54-
55-
56- async def _cancel_or_warn (task : Task ) -> None :
57- if not task .cancelled ():
58- task .cancel ()
59- _ , pending = await asyncio .wait ((task ,), timeout = _SHUTDOWN_TIMEOUT_S )
60- if pending :
61- task_name = task .get_name ()
62- _logger .warning ("Could not cancel task_name=%s pending=%s" , task_name , pending )
63-
64-
65- async def _auto_extend_lock (lock : Lock ) -> None :
66- try :
67- with log_context (_logger , logging .DEBUG , f"Autoextend lock { lock .name } " ):
68- # TODO: if we cannot re-acquire that means the lock is lost, and we are not anymore safe and should raise all the way to the caller
69- await lock .reacquire ()
70- except redis .exceptions .LockNotOwnedError as exc :
71- raise LockLostError (lock = lock ) from exc
72-
73-
7434@dataclass
7535class RedisClientSDK :
7636 redis_dsn : str
7737 client_name : str
78- decode_responses : bool = _DEFAULT_DECODE_RESPONSES
79- health_check_interval : datetime .timedelta = _DEFAULT_HEALTH_CHECK_INTERVAL
38+ decode_responses : bool = DEFAULT_DECODE_RESPONSES
39+ health_check_interval : datetime .timedelta = DEFAULT_HEALTH_CHECK_INTERVAL
8040
8141 _client : aioredis .Redis = field (init = False )
8242 _health_check_task : Task | None = None
@@ -97,8 +57,8 @@ def __post_init__(self):
9757 redis .exceptions .ConnectionError ,
9858 redis .exceptions .TimeoutError ,
9959 ],
100- socket_timeout = _DEFAULT_SOCKET_TIMEOUT .total_seconds (),
101- socket_connect_timeout = _DEFAULT_SOCKET_TIMEOUT .total_seconds (),
60+ socket_timeout = DEFAULT_SOCKET_TIMEOUT .total_seconds (),
61+ socket_connect_timeout = DEFAULT_SOCKET_TIMEOUT .total_seconds (),
10262 encoding = "utf-8" ,
10363 decode_responses = self .decode_responses ,
10464 client_name = self .client_name ,
@@ -126,7 +86,7 @@ async def setup(self) -> None:
12686 async def shutdown (self ) -> None :
12787 if self ._health_check_task :
12888 self ._continue_health_checking = False
129- await _cancel_or_warn (self ._health_check_task )
89+ await cancel_or_warn (self ._health_check_task )
13090 self ._health_check_task = None
13191
13292 await self ._client .aclose (close_connection_pool = True )
@@ -178,7 +138,7 @@ async def lock_context(
178138 2. `blocking==True` timeouts out while waiting for lock to be free (another entity holds the lock)
179139 """
180140
181- total_lock_duration : datetime .timedelta = _DEFAULT_LOCK_TTL
141+ total_lock_duration : datetime .timedelta = DEFAULT_LOCK_TTL
182142 lock_unique_id = f"lock_extender_{ lock_key } _{ uuid4 ()} "
183143
184144 ttl_lock : Lock = self ._client .lock (
@@ -193,7 +153,7 @@ async def lock_context(
193153
194154 try :
195155 async with periodic_task (
196- _auto_extend_lock ,
156+ auto_extend_lock ,
197157 interval = total_lock_duration / 2 ,
198158 task_name = lock_unique_id ,
199159 lock = ttl_lock ,
@@ -220,8 +180,6 @@ async def lock_context(
220180
221181 # Above implies that only one "task" `owns` and `extends` the lock at a time.
222182 # The issue appears to be related some timings (being too low).
223-
224- # TODO: Why are we silencing this because of a TEST?????
225183 try :
226184 await ttl_lock .release ()
227185 except redis .exceptions .LockNotOwnedError :
@@ -233,51 +191,3 @@ async def lock_context(
233191 async def lock_value (self , lock_name : str ) -> str | None :
234192 output : str | None = await self ._client .get (lock_name )
235193 return output
236-
237-
238- @dataclass (frozen = True )
239- class RedisManagerDBConfig :
240- database : RedisDatabase
241- decode_responses : bool = _DEFAULT_DECODE_RESPONSES
242- health_check_interval : datetime .timedelta = _DEFAULT_HEALTH_CHECK_INTERVAL
243-
244-
245- @dataclass
246- class RedisClientsManager :
247- """
248- Manages the lifetime of redis client sdk connections
249- """
250-
251- databases_configs : set [RedisManagerDBConfig ]
252- settings : RedisSettings
253- client_name : str
254-
255- _client_sdks : dict [RedisDatabase , RedisClientSDK ] = field (default_factory = dict )
256-
257- async def setup (self ) -> None :
258- for config in self .databases_configs :
259- self ._client_sdks [config .database ] = RedisClientSDK (
260- redis_dsn = self .settings .build_redis_dsn (config .database ),
261- decode_responses = config .decode_responses ,
262- health_check_interval = config .health_check_interval ,
263- client_name = f"{ self .client_name } " ,
264- )
265-
266- for client in self ._client_sdks .values ():
267- await client .setup ()
268-
269- async def shutdown (self ) -> None :
270- # NOTE: somehow using logged_gather is not an option
271- # doing so will make the shutdown procedure hang
272- for client in self ._client_sdks .values ():
273- await client .shutdown ()
274-
275- def client (self , database : RedisDatabase ) -> RedisClientSDK :
276- return self ._client_sdks [database ]
277-
278- async def __aenter__ (self ) -> "RedisClientsManager" :
279- await self .setup ()
280- return self
281-
282- async def __aexit__ (self , * args ):
283- await self .shutdown ()
0 commit comments