Commit 3982a20 - refactoring
1 parent 45269d2
8 files changed: +208 additions, -166 deletions

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from ._client import RedisClientSDK
+from ._errors import (
+    CouldNotAcquireLockError,
+    CouldNotConnectToRedisError,
+    LockLostError,
+)
+
+__all__: tuple[str, ...] = (
+    "CouldNotAcquireLockError",
+    "CouldNotConnectToRedisError",
+    "LockLostError",
+)
+
+# nopycln: file

packages/service-library/src/servicelib/redis.py renamed to packages/service-library/src/servicelib/redis/_client.py

Lines changed: 18 additions & 108 deletions
@@ -5,78 +5,38 @@
 from asyncio import Task
 from collections.abc import AsyncIterator
 from dataclasses import dataclass, field
-from typing import Final
 from uuid import uuid4
 
 import redis.asyncio as aioredis
 import redis.exceptions
-from common_library.errors_classes import OsparcErrorMixin
-from pydantic import NonNegativeFloat, NonNegativeInt
+from pydantic import NonNegativeFloat
 from redis.asyncio.lock import Lock
 from redis.asyncio.retry import Retry
 from redis.backoff import ExponentialBackoff
-from settings_library.redis import RedisDatabase, RedisSettings
 from tenacity import retry
 from yarl import URL
 
-from .background_task import periodic_task
-from .logging_utils import log_catch, log_context
-from .retry_policies import RedisRetryPolicyUponInitialization
-
-_DEFAULT_LOCK_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10)
-_DEFAULT_SOCKET_TIMEOUT: Final[datetime.timedelta] = datetime.timedelta(seconds=30)
-
-
-_DEFAULT_DECODE_RESPONSES: Final[bool] = True
-_DEFAULT_HEALTH_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(
-    seconds=5
+from ..background_task import periodic_task
+from ..logging_utils import log_catch
+from ..retry_policies import RedisRetryPolicyUponInitialization
+from ._constants import (
+    DEFAULT_DECODE_RESPONSES,
+    DEFAULT_HEALTH_CHECK_INTERVAL,
+    DEFAULT_LOCK_TTL,
+    DEFAULT_SOCKET_TIMEOUT,
 )
-_SHUTDOWN_TIMEOUT_S: Final[NonNegativeInt] = 5
-
+from ._errors import CouldNotAcquireLockError, CouldNotConnectToRedisError
+from ._utils import auto_extend_lock, cancel_or_warn
 
 _logger = logging.getLogger(__name__)
 
 
-class BaseRedisError(OsparcErrorMixin, RuntimeError):
-    ...
-
-
-class CouldNotAcquireLockError(BaseRedisError):
-    msg_template: str = "Lock {lock.name} could not be acquired!"
-
-
-class CouldNotConnectToRedisError(BaseRedisError):
-    msg_template: str = "Connection to '{dsn}' failed"
-
-
-class LockLostError(BaseRedisError):
-    msg_template: str = "Lock {lock.name} has been lost"
-
-
-async def _cancel_or_warn(task: Task) -> None:
-    if not task.cancelled():
-        task.cancel()
-    _, pending = await asyncio.wait((task,), timeout=_SHUTDOWN_TIMEOUT_S)
-    if pending:
-        task_name = task.get_name()
-        _logger.warning("Could not cancel task_name=%s pending=%s", task_name, pending)
-
-
-async def _auto_extend_lock(lock: Lock) -> None:
-    try:
-        with log_context(_logger, logging.DEBUG, f"Autoextend lock {lock.name}"):
-            # TODO: if we cannot re-acquire that means the lock is lost, and we are not anymore safe and should raise all the way to the caller
-            await lock.reacquire()
-    except redis.exceptions.LockNotOwnedError as exc:
-        raise LockLostError(lock=lock) from exc
-
-
 @dataclass
 class RedisClientSDK:
     redis_dsn: str
     client_name: str
-    decode_responses: bool = _DEFAULT_DECODE_RESPONSES
-    health_check_interval: datetime.timedelta = _DEFAULT_HEALTH_CHECK_INTERVAL
+    decode_responses: bool = DEFAULT_DECODE_RESPONSES
+    health_check_interval: datetime.timedelta = DEFAULT_HEALTH_CHECK_INTERVAL
 
     _client: aioredis.Redis = field(init=False)
     _health_check_task: Task | None = None
@@ -97,8 +57,8 @@ def __post_init__(self):
                 redis.exceptions.ConnectionError,
                 redis.exceptions.TimeoutError,
             ],
-            socket_timeout=_DEFAULT_SOCKET_TIMEOUT.total_seconds(),
-            socket_connect_timeout=_DEFAULT_SOCKET_TIMEOUT.total_seconds(),
+            socket_timeout=DEFAULT_SOCKET_TIMEOUT.total_seconds(),
+            socket_connect_timeout=DEFAULT_SOCKET_TIMEOUT.total_seconds(),
             encoding="utf-8",
             decode_responses=self.decode_responses,
             client_name=self.client_name,
@@ -126,7 +86,7 @@ async def setup(self) -> None:
     async def shutdown(self) -> None:
         if self._health_check_task:
             self._continue_health_checking = False
-            await _cancel_or_warn(self._health_check_task)
+            await cancel_or_warn(self._health_check_task)
            self._health_check_task = None
 
        await self._client.aclose(close_connection_pool=True)
@@ -178,7 +138,7 @@ async def lock_context(
         2. `blocking==True` timeouts out while waiting for lock to be free (another entity holds the lock)
         """
 
-        total_lock_duration: datetime.timedelta = _DEFAULT_LOCK_TTL
+        total_lock_duration: datetime.timedelta = DEFAULT_LOCK_TTL
         lock_unique_id = f"lock_extender_{lock_key}_{uuid4()}"
 
         ttl_lock: Lock = self._client.lock(
@@ -193,7 +153,7 @@ async def lock_context(
 
         try:
             async with periodic_task(
-                _auto_extend_lock,
+                auto_extend_lock,
                 interval=total_lock_duration / 2,
                 task_name=lock_unique_id,
                 lock=ttl_lock,
@@ -220,8 +180,6 @@ async def lock_context(
 
             # Above implies that only one "task" `owns` and `extends` the lock at a time.
             # The issue appears to be related some timings (being too low).
-
-            # TODO: Why are we silencing this because of a TEST?????
             try:
                 await ttl_lock.release()
             except redis.exceptions.LockNotOwnedError:
@@ -233,51 +191,3 @@ async def lock_context(
     async def lock_value(self, lock_name: str) -> str | None:
         output: str | None = await self._client.get(lock_name)
         return output
-
-
-@dataclass(frozen=True)
-class RedisManagerDBConfig:
-    database: RedisDatabase
-    decode_responses: bool = _DEFAULT_DECODE_RESPONSES
-    health_check_interval: datetime.timedelta = _DEFAULT_HEALTH_CHECK_INTERVAL
-
-
-@dataclass
-class RedisClientsManager:
-    """
-    Manages the lifetime of redis client sdk connections
-    """
-
-    databases_configs: set[RedisManagerDBConfig]
-    settings: RedisSettings
-    client_name: str
-
-    _client_sdks: dict[RedisDatabase, RedisClientSDK] = field(default_factory=dict)
-
-    async def setup(self) -> None:
-        for config in self.databases_configs:
-            self._client_sdks[config.database] = RedisClientSDK(
-                redis_dsn=self.settings.build_redis_dsn(config.database),
-                decode_responses=config.decode_responses,
-                health_check_interval=config.health_check_interval,
-                client_name=f"{self.client_name}",
-            )
-
-        for client in self._client_sdks.values():
-            await client.setup()
-
-    async def shutdown(self) -> None:
-        # NOTE: somehow using logged_gather is not an option
-        # doing so will make the shutdown procedure hang
-        for client in self._client_sdks.values():
-            await client.shutdown()
-
-    def client(self, database: RedisDatabase) -> RedisClientSDK:
-        return self._client_sdks[database]
-
-    async def __aenter__(self) -> "RedisClientsManager":
-        await self.setup()
-        return self
-
-    async def __aexit__(self, *args):
-        await self.shutdown()
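
Editor's note: a minimal usage sketch of the relocated RedisClientSDK, based only on the fields and methods visible in this diff. The DSN, client name, and lock key are made-up placeholders; the import path follows the rename above.

import asyncio

from servicelib.redis._client import RedisClientSDK


async def main() -> None:
    # placeholder DSN, client name and lock key, for illustration only
    client = RedisClientSDK(
        redis_dsn="redis://localhost:6379/0", client_name="example-client"
    )
    await client.setup()
    try:
        async with client.lock_context(lock_key="example:my-lock"):
            ...  # work that must not run concurrently across processes
    finally:
        await client.shutdown()


asyncio.run(main())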
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+from dataclasses import dataclass, field
+
+from settings_library.redis import RedisDatabase, RedisSettings
+
+from ._client import RedisClientSDK
+from ._models import RedisManagerDBConfig
+
+
+@dataclass
+class RedisClientsManager:
+    """
+    Manages the lifetime of redis client sdk connections
+    """
+
+    databases_configs: set[RedisManagerDBConfig]
+    settings: RedisSettings
+    client_name: str
+
+    _client_sdks: dict[RedisDatabase, RedisClientSDK] = field(default_factory=dict)
+
+    async def setup(self) -> None:
+        for config in self.databases_configs:
+            self._client_sdks[config.database] = RedisClientSDK(
+                redis_dsn=self.settings.build_redis_dsn(config.database),
+                decode_responses=config.decode_responses,
+                health_check_interval=config.health_check_interval,
+                client_name=f"{self.client_name}",
+            )
+
+        for client in self._client_sdks.values():
+            await client.setup()
+
+    async def shutdown(self) -> None:
+        # NOTE: somehow using logged_gather is not an option
+        # doing so will make the shutdown procedure hang
+        for client in self._client_sdks.values():
+            await client.shutdown()
+
+    def client(self, database: RedisDatabase) -> RedisClientSDK:
+        return self._client_sdks[database]
+
+    async def __aenter__(self) -> "RedisClientsManager":
+        await self.setup()
+        return self
+
+    async def __aexit__(self, *args) -> None:
+        await self.shutdown()
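
Editor's note: a possible lifecycle sketch for the extracted manager. The module path servicelib.redis._clients_manager is an assumption (the new file name is not shown here), as are RedisSettings.create_from_envs() and the RedisDatabase.LOCKS member.

from settings_library.redis import RedisDatabase, RedisSettings

from servicelib.redis._clients_manager import RedisClientsManager  # assumed module path
from servicelib.redis._models import RedisManagerDBConfig


async def run() -> None:
    # RedisDatabase.LOCKS and create_from_envs() are assumptions, not shown in this diff
    settings = RedisSettings.create_from_envs()
    async with RedisClientsManager(
        databases_configs={RedisManagerDBConfig(database=RedisDatabase.LOCKS)},
        settings=settings,
        client_name="example-service",
    ) as manager:
        client = manager.client(RedisDatabase.LOCKS)
        async with client.lock_context(lock_key="example:exclusive-job"):
            ...  # critical section guarded across processes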
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import datetime
+from typing import Final
+
+from pydantic import NonNegativeInt
+
+DEFAULT_LOCK_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10)
+DEFAULT_SOCKET_TIMEOUT: Final[datetime.timedelta] = datetime.timedelta(seconds=30)
+
+
+DEFAULT_DECODE_RESPONSES: Final[bool] = True
+DEFAULT_HEALTH_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5)
+SHUTDOWN_TIMEOUT_S: Final[NonNegativeInt] = 5
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import functools
+import logging
+from collections.abc import Awaitable, Callable
+from typing import ParamSpec, TypeVar
+
+from ._client import RedisClientSDK
+
+_logger = logging.getLogger(__file__)
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def exclusive(
+    redis: RedisClientSDK | Callable[..., RedisClientSDK],
+    *,
+    lock_key: str | Callable[..., str],
+    lock_value: bytes | str | None = None,
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
+    """
+    Define a method to run exclusively across
+    processes by leveraging a Redis Lock.
+
+    parameters:
+        redis: the redis client SDK
+        lock_key: a string as the name of the lock (good practice: app_name:lock_name)
+        lock_value: some additional data that can be retrieved by another client
+
+    Raises:
+        - ValueError if used incorrectly
+        - CouldNotAcquireLockError if the lock could not be acquired
+    """
+
+    if not lock_key:
+        msg = "lock_key cannot be empty string!"
+        raise ValueError(msg)
+
+    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
+        @functools.wraps(func)
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            redis_lock_key = (
+                lock_key(*args, **kwargs) if callable(lock_key) else lock_key
+            )
+            assert isinstance(redis_lock_key, str)  # nosec
+
+            redis_client = redis(*args, **kwargs) if callable(redis) else redis
+            assert isinstance(redis_client, RedisClientSDK)  # nosec
+
+            async with redis_client.lock_context(
+                lock_key=redis_lock_key, lock_value=lock_value
+            ):
+                return await func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
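
Editor's note: a short sketch of how the exclusive decorator could be applied. The module path servicelib.redis._decorators is an assumption (the new file name is not shown here); the client, DSN, and lock key are placeholders.

from servicelib.redis._client import RedisClientSDK
from servicelib.redis._decorators import exclusive  # assumed module path
from servicelib.redis._errors import CouldNotAcquireLockError

# placeholder client, for illustration only
_client = RedisClientSDK(redis_dsn="redis://localhost:6379/0", client_name="worker")


@exclusive(_client, lock_key="example:nightly-cleanup")
async def nightly_cleanup() -> None:
    ...  # body runs on at most one process at a time


async def run() -> None:
    await _client.setup()
    try:
        await nightly_cleanup()
    except CouldNotAcquireLockError:
        ...  # another process already holds the lock
    finally:
        await _client.shutdown()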
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from common_library.errors_classes import OsparcErrorMixin
+
+
+class BaseRedisError(OsparcErrorMixin, RuntimeError):
+    ...
+
+
+class CouldNotAcquireLockError(BaseRedisError):
+    msg_template: str = "Lock {lock.name} could not be acquired!"
+
+
+class CouldNotConnectToRedisError(BaseRedisError):
+    msg_template: str = "Connection to '{dsn}' failed"
+
+
+class LockLostError(BaseRedisError):
+    msg_template: str = "Lock {lock.name} has been lost"
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import datetime
+from dataclasses import dataclass
+
+from settings_library.redis import RedisDatabase
+
+from ._constants import DEFAULT_DECODE_RESPONSES, DEFAULT_HEALTH_CHECK_INTERVAL
+
+
+@dataclass(frozen=True, kw_only=True)
+class RedisManagerDBConfig:
+    database: RedisDatabase
+    decode_responses: bool = DEFAULT_DECODE_RESPONSES
+    health_check_interval: datetime.timedelta = DEFAULT_HEALTH_CHECK_INTERVAL
