 from asyncio import Task
 from collections.abc import AsyncIterator
 from dataclasses import dataclass, field
-from typing import Final
 from uuid import uuid4

 import redis.asyncio as aioredis
 import redis.exceptions
-from common_library.errors_classes import OsparcErrorMixin
-from pydantic import NonNegativeFloat, NonNegativeInt
+from pydantic import NonNegativeFloat
 from redis.asyncio.lock import Lock
 from redis.asyncio.retry import Retry
 from redis.backoff import ExponentialBackoff
-from settings_library.redis import RedisDatabase, RedisSettings
 from tenacity import retry
 from yarl import URL

-from .background_task import periodic_task
-from .logging_utils import log_catch, log_context
-from .retry_policies import RedisRetryPolicyUponInitialization
-
-_DEFAULT_LOCK_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10)
-_DEFAULT_SOCKET_TIMEOUT: Final[datetime.timedelta] = datetime.timedelta(seconds=30)
-
-
-_DEFAULT_DECODE_RESPONSES: Final[bool] = True
-_DEFAULT_HEALTH_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(
-    seconds=5
+from ..background_task import periodic_task
+from ..logging_utils import log_catch
+from ..retry_policies import RedisRetryPolicyUponInitialization
+from ._constants import (
+    DEFAULT_DECODE_RESPONSES,
+    DEFAULT_HEALTH_CHECK_INTERVAL,
+    DEFAULT_LOCK_TTL,
+    DEFAULT_SOCKET_TIMEOUT,
 )
-_SHUTDOWN_TIMEOUT_S: Final[NonNegativeInt] = 5
-
+from ._errors import CouldNotAcquireLockError, CouldNotConnectToRedisError
+from ._utils import auto_extend_lock, cancel_or_warn

 _logger = logging.getLogger(__name__)


-class BaseRedisError(OsparcErrorMixin, RuntimeError):
-    ...
-
-
-class CouldNotAcquireLockError(BaseRedisError):
-    msg_template: str = "Lock {lock.name} could not be acquired!"
-
-
-class CouldNotConnectToRedisError(BaseRedisError):
-    msg_template: str = "Connection to '{dsn}' failed"
-
-
-async def _cancel_or_warn(task: Task) -> None:
-    if not task.cancelled():
-        task.cancel()
-    _, pending = await asyncio.wait((task,), timeout=_SHUTDOWN_TIMEOUT_S)
-    if pending:
-        task_name = task.get_name()
-        _logger.warning("Could not cancel task_name=%s pending=%s", task_name, pending)
-
-
 @dataclass
 class RedisClientSDK:
     redis_dsn: str
     client_name: str
-    decode_responses: bool = _DEFAULT_DECODE_RESPONSES
-    health_check_interval: datetime.timedelta = _DEFAULT_HEALTH_CHECK_INTERVAL
+    decode_responses: bool = DEFAULT_DECODE_RESPONSES
+    health_check_interval: datetime.timedelta = DEFAULT_HEALTH_CHECK_INTERVAL

     _client: aioredis.Redis = field(init=False)
     _health_check_task: Task | None = None
@@ -74,7 +47,7 @@ class RedisClientSDK:
     def redis(self) -> aioredis.Redis:
         return self._client

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         self._client = aioredis.from_url(
             self.redis_dsn,
             # Run 3 retries with exponential backoff strategy source: https://redis.readthedocs.io/en/stable/backoff.html
@@ -84,8 +57,8 @@ def __post_init__(self):
                 redis.exceptions.ConnectionError,
                 redis.exceptions.TimeoutError,
             ],
-            socket_timeout=_DEFAULT_SOCKET_TIMEOUT.total_seconds(),
-            socket_connect_timeout=_DEFAULT_SOCKET_TIMEOUT.total_seconds(),
+            socket_timeout=DEFAULT_SOCKET_TIMEOUT.total_seconds(),
+            socket_connect_timeout=DEFAULT_SOCKET_TIMEOUT.total_seconds(),
             encoding="utf-8",
             decode_responses=self.decode_responses,
             client_name=self.client_name,
@@ -113,7 +86,7 @@ async def setup(self) -> None:
     async def shutdown(self) -> None:
         if self._health_check_task:
             self._continue_health_checking = False
-            await _cancel_or_warn(self._health_check_task)
+            await cancel_or_warn(self._health_check_task)
             self._health_check_task = None

         await self._client.aclose(close_connection_pool=True)
@@ -165,7 +138,7 @@ async def lock_context(
         2. `blocking==True` timeouts out while waiting for lock to be free (another entity holds the lock)
         """

-        total_lock_duration: datetime.timedelta = _DEFAULT_LOCK_TTL
+        total_lock_duration: datetime.timedelta = DEFAULT_LOCK_TTL
         lock_unique_id = f"lock_extender_{lock_key}_{uuid4()}"

         ttl_lock: Lock = self._client.lock(
@@ -178,15 +151,9 @@ async def lock_context(
         if not await ttl_lock.acquire(token=lock_value):
             raise CouldNotAcquireLockError(lock=ttl_lock)

-        async def _extend_lock(lock: Lock) -> None:
-            with log_context(
-                _logger, logging.DEBUG, f"Extending lock {lock_unique_id}"
-            ), log_catch(_logger, reraise=False):
-                await lock.reacquire()
-
         try:
             async with periodic_task(
-                _extend_lock,
+                auto_extend_lock,
                 interval=total_lock_duration / 2,
                 task_name=lock_unique_id,
                 lock=ttl_lock,
@@ -224,51 +191,3 @@ async def _extend_lock(lock: Lock) -> None:
     async def lock_value(self, lock_name: str) -> str | None:
         output: str | None = await self._client.get(lock_name)
         return output
-
-
-@dataclass(frozen=True)
-class RedisManagerDBConfig:
-    database: RedisDatabase
-    decode_responses: bool = _DEFAULT_DECODE_RESPONSES
-    health_check_interval: datetime.timedelta = _DEFAULT_HEALTH_CHECK_INTERVAL
-
-
-@dataclass
-class RedisClientsManager:
-    """
-    Manages the lifetime of redis client sdk connections
-    """
-
-    databases_configs: set[RedisManagerDBConfig]
-    settings: RedisSettings
-    client_name: str
-
-    _client_sdks: dict[RedisDatabase, RedisClientSDK] = field(default_factory=dict)
-
-    async def setup(self) -> None:
-        for config in self.databases_configs:
-            self._client_sdks[config.database] = RedisClientSDK(
-                redis_dsn=self.settings.build_redis_dsn(config.database),
-                decode_responses=config.decode_responses,
-                health_check_interval=config.health_check_interval,
-                client_name=f"{self.client_name}",
-            )
-
-        for client in self._client_sdks.values():
-            await client.setup()
-
-    async def shutdown(self) -> None:
-        # NOTE: somehow using logged_gather is not an option
-        # doing so will make the shutdown procedure hang
-        for client in self._client_sdks.values():
-            await client.shutdown()
-
-    def client(self, database: RedisDatabase) -> RedisClientSDK:
-        return self._client_sdks[database]
-
-    async def __aenter__(self) -> "RedisClientsManager":
-        await self.setup()
-        return self
-
-    async def __aexit__(self, *args):
-        await self.shutdown()
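For orientation, here is a minimal usage sketch of RedisClientSDK as it appears in this diff. The import path servicelib.redis and the lock_context keyword argument lock_key are assumptions inferred from the diff, not confirmed API:

    import asyncio

    from servicelib.redis import RedisClientSDK  # import path is an assumption

    async def main() -> None:
        client = RedisClientSDK(
            redis_dsn="redis://localhost:6379", client_name="example-client"
        )
        await client.setup()
        try:
            # lock_context acquires a TTL lock and keeps extending it in the
            # background via periodic_task(auto_extend_lock, ...); it raises
            # CouldNotAcquireLockError if the lock is already held elsewhere
            async with client.lock_context(lock_key="my-distributed-lock"):
                ...  # critical section
        finally:
            await client.shutdown()

    asyncio.run(main())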