Closed pull request - 31 commits:
488fa2b - Wrapp the Redis client with retry mechanism to improve disaster recovery (alex-zaikman, Mar 27, 2025)
9585ad3 - add-redis-connection-proxy (JKL98ISR, Mar 27, 2025)
9d35a43 - add-redis-connection-proxy (JKL98ISR, Mar 27, 2025)
300c114 - add-redis-connection-proxy (JKL98ISR, Mar 27, 2025)
9fd9b2b - add-redis-connection-proxy (JKL98ISR, Mar 27, 2025)
d9e9597 - add-redis-connection-proxy (JKL98ISR, Mar 27, 2025)
762e0f0 - add-redis-connection-proxy (JKL98ISR, Mar 27, 2025)
8b3a55d - add-redis-connection-proxy (JKL98ISR, Mar 27, 2025)
37b8c6f - fix-test (JKL98ISR, Apr 2, 2025)
439d952 - upgrade-redis (JKL98ISR, Apr 2, 2025)
6c67a73 - upgrade-redis (JKL98ISR, Apr 2, 2025)
162739b - upgrade-redis (JKL98ISR, Apr 2, 2025)
4f67441 - upgrade-redis (JKL98ISR, Apr 2, 2025)
bb7938b - Merge branch 'main' of https://github.com/deepchecks/monitoring into … (JKL98ISR, Apr 2, 2025)
ff78fe0 - upgrade-redis (JKL98ISR, Apr 2, 2025)
fc74f55 - upgrade-redis (JKL98ISR, Apr 2, 2025)
f9c3d43 - upgrade-redis (JKL98ISR, Apr 2, 2025)
1f3d02f - isort (JKL98ISR, Apr 2, 2025)
3e69188 - upgrade-redis (JKL98ISR, Apr 3, 2025)
302c5d6 - upgrade-redis (JKL98ISR, Apr 3, 2025)
a61dfbe - upgrade-redis (JKL98ISR, Apr 3, 2025)
5061536 - upgrade-redis (JKL98ISR, Apr 3, 2025)
66d80ab - upgrade-redis (JKL98ISR, Apr 3, 2025)
4b20d9f - upgrade-redis (JKL98ISR, Apr 3, 2025)
24fb0aa - Update backend/tests/unittests/test_monitor_alert_rules_executor.py (JKL98ISR, Apr 3, 2025)
902ee18 - upgrade-redis (JKL98ISR, Apr 3, 2025)
ef7c23c - Merge branch 'alex/mon-2671-add-redis-connection-proxy' of https://gi… (JKL98ISR, Apr 3, 2025)
918fa58 - upgrade-redis (JKL98ISR, Apr 3, 2025)
03b0a3e - upgrade-redis (JKL98ISR, Apr 3, 2025)
46c1b49 - upgrade-redis (JKL98ISR, Apr 3, 2025)
abf834e - upgrade-redis (JKL98ISR, Apr 3, 2025)
4 changes: 2 additions & 2 deletions backend/deepchecks_monitoring/api/v1/data_input.py
@@ -73,7 +73,7 @@ async def log_data_batch(
minute_rate = resources_provider.get_features_control(user).rows_per_minute

# Atomically getting the count and increasing in order to avoid race conditions
- curr_count = resources_provider.cache_functions.get_and_incr_user_rate_count(user, time, len(data))
+ curr_count = await resources_provider.cache_functions.get_and_incr_user_rate_count(user, time, len(data))
remains = minute_rate - curr_count

# Remains can be negative because we don't check the limit before incrementing
@@ -140,7 +140,7 @@ async def log_labels(
minute_rate = resources_provider.get_features_control(user).rows_per_minute

# Atomically getting the count and increasing in order to avoid race conditions
- curr_count = resources_provider.cache_functions.get_and_incr_user_rate_count(user, time, len(data), is_label=True)
+ curr_count = await resources_provider.cache_functions.get_and_incr_user_rate_count(user, time, len(data), is_label=True)
remains = minute_rate - curr_count

# Remains can be negative because we don't check the limit before incrementing
2 changes: 1 addition & 1 deletion backend/deepchecks_monitoring/bgtasks/alert_task.py
@@ -182,7 +182,7 @@ async def execute_monitor(
result_per_version = reduce_check_window(result_per_version, options)
# Save to cache
for version, result in result_per_version.items():
- resources_provider.cache_functions.set_monitor_cache(
+ await resources_provider.cache_functions.set_monitor_cache(
organization_id, version.id, monitor_id, start_time, end_time, result)

logger.debug('Check execution result: %s', result_per_version)
16 changes: 3 additions & 13 deletions backend/deepchecks_monitoring/bgtasks/tasks_queuer.py
@@ -33,6 +33,7 @@
from deepchecks_monitoring.logic.keys import GLOBAL_TASK_QUEUE
from deepchecks_monitoring.monitoring_utils import configure_logger
from deepchecks_monitoring.public_models.task import BackgroundWorker, Task
+ from deepchecks_monitoring.utils.redis_proxy import RedisProxy

try:
from deepchecks_monitoring import ee
@@ -50,7 +51,7 @@ class TasksQueuer:
def __init__(
self,
resource_provider: ResourcesProvider,
- redis_client: RedisCluster | Redis,
+ redis_client: RedisCluster | Redis | RedisProxy,
workers: t.List[BackgroundWorker],
logger: logging.Logger,
run_interval: int,
@@ -151,17 +152,6 @@ class Config:
env_file = '.env'
env_file_encoding = 'utf-8'

-
- async def init_async_redis(redis_uri):
- """Initialize redis connection."""
- try:
- redis = RedisCluster.from_url(redis_uri)
- await redis.ping()
- return redis
- except redis_exceptions.RedisClusterException:
- return Redis.from_url(redis_uri)
-
-

def execute_worker():
"""Execute worker."""

@@ -195,7 +185,7 @@ async def main():

async with ResourcesProvider(settings) as rp:
async with anyio.create_task_group() as g:
- async_redis = await init_async_redis(rp.redis_settings.redis_uri)
+ async_redis = RedisProxy(rp.redis_settings)
worker = tasks_queuer.TasksQueuer(rp, async_redis, workers, logger, settings.queuer_run_interval)
g.start_soon(worker.run)

14 changes: 2 additions & 12 deletions backend/deepchecks_monitoring/bgtasks/tasks_runner.py
@@ -29,6 +29,7 @@
from deepchecks_monitoring.logic.keys import GLOBAL_TASK_QUEUE, TASK_RUNNER_LOCK
from deepchecks_monitoring.monitoring_utils import configure_logger
from deepchecks_monitoring.public_models.task import BackgroundWorker, Task
+ from deepchecks_monitoring.utils.redis_proxy import RedisProxy

try:
from deepchecks_monitoring import ee
@@ -159,17 +160,6 @@ class WorkerSettings(BaseWorkerSettings, Settings):
"""Set of worker settings."""
pass

-
- async def init_async_redis(redis_uri):
- """Initialize redis connection."""
- try:
- redis = RedisCluster.from_url(redis_uri)
- await redis.ping()
- return redis
- except RedisClusterException:
- return Redis.from_url(redis_uri)
-
-
def execute_worker():
"""Execute worker."""

@@ -189,7 +179,7 @@ async def main():
from deepchecks_monitoring.bgtasks import tasks_runner # pylint: disable=import-outside-toplevel

async with ResourcesProvider(settings) as rp:
- async_redis = await init_async_redis(rp.redis_settings.redis_uri)
+ async_redis = RedisProxy(rp.redis_settings)

workers = [
ModelVersionCacheInvalidation(),
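Both workers now build the proxy synchronously instead of awaiting the removed init_async_redis helper. If a worker still wants to fail fast on an unreachable Redis at startup, a sketch along these lines could be used; it is not part of the PR and assumes ping is simply forwarded through RedisProxy.__getattr__ like any other client method.

# Sketch only: build the proxy and probe the connection once at startup.
# Assumes `ping` is forwarded by RedisProxy.__getattr__ to the underlying
# async client and therefore must be awaited by the caller.
from deepchecks_monitoring.config import RedisSettings
from deepchecks_monitoring.utils.redis_proxy import RedisProxy


async def connect_redis(redis_settings: RedisSettings) -> RedisProxy:
    proxy = RedisProxy(redis_settings)  # chooses a cluster or standalone client internally
    await proxy.ping()                  # raises a redis error if the server is unreachable
    return proxy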
2 changes: 2 additions & 0 deletions backend/deepchecks_monitoring/config.py
@@ -137,6 +137,8 @@ class RedisSettings(BaseDeepchecksSettings):
"""Redis settings."""

redis_uri: t.Optional[RedisDsn] = None
+ stop_after_retries: int = 3  # Number of retries before giving up
+ wait_between_retries: int = 3  # Time to wait between retries


class Settings(
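The two new fields sit next to redis_uri as ordinary pydantic settings, so they should be tunable per deployment without code changes. A hedged sketch follows; the environment-variable names are an assumption based on pydantic's default field-name mapping and no extra prefix on BaseDeepchecksSettings.

# Sketch only: overriding the retry knobs through the environment.
# Env-var names assume pydantic's default case-insensitive field-name
# mapping; adjust if BaseDeepchecksSettings configures a prefix.
import os

from deepchecks_monitoring.config import RedisSettings

os.environ["STOP_AFTER_RETRIES"] = "5"    # give each Redis call five attempts
os.environ["WAIT_BETWEEN_RETRIES"] = "1"  # wait one second between attempts

settings = RedisSettings(redis_uri="redis://localhost:6379/0")
assert settings.stop_after_retries == 5
assert settings.wait_between_retries == 1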
16 changes: 8 additions & 8 deletions backend/deepchecks_monitoring/logic/cache_functions.py
@@ -16,7 +16,7 @@

import pendulum as pdl
import redis.exceptions
- from redis.client import Redis
+ from redis.asyncio.client import Redis

from deepchecks_monitoring.logic.keys import build_monitor_cache_key, get_invalidation_set_key

@@ -36,10 +36,10 @@ class CacheFunctions:

def __init__(self, redis_client=None):
self.use_cache = redis_client is not None
- self.redis: Redis = redis_client
+ self.redis: Redis = redis_client
self.logger = logging.Logger("cache-functions")

- def get_monitor_cache(self, organization_id, model_version_id, monitor_id, start_time, end_time):
+ async def get_monitor_cache(self, organization_id, model_version_id, monitor_id, start_time, end_time):
"""Get result from cache if exists. We can cache values which are "None" therefore to distinguish between the \
situations we return CacheResult with 'found' property."""
if self.use_cache:
@@ -48,7 +48,7 @@ def get_monitor_cache(self, organization_id, model_version_id, monitor_id, start
p = self.redis.pipeline()
p.get(key)
p.expire(key, MONITOR_CACHE_EXPIRY_TIME)
- cache_value = p.execute()[0]
+ cache_value = (await p.execute())[0]
# If cache value is none it means the key was not found
if cache_value is not None:
return CacheResult(found=True, value=json.loads(cache_value))
@@ -58,7 +58,7 @@ def get_monitor_cache(self, organization_id, model_version_id, monitor_id, start
# Return no cache result
return CacheResult(found=False, value=None)

- def set_monitor_cache(self, organization_id, model_version_id, monitor_id, start_time, end_time, value):
+ async def set_monitor_cache(self, organization_id, model_version_id, monitor_id, start_time, end_time, value):
"""Set cache value for the properties given."""
if not self.use_cache:
return
@@ -68,7 +68,7 @@ def set_monitor_cache(self, organization_id, model_version_id, monitor_id, start
p = self.redis.pipeline()
p.set(key, cache_val)
p.expire(key, MONITOR_CACHE_EXPIRY_TIME)
- p.execute()
+ await p.execute()
except redis.exceptions.RedisError as e:
self.logger.exception(e)

@@ -97,15 +97,15 @@ def delete_key(self, key):
"""Remove a given key from the cache."""
self.redis.delete(key)

- def get_and_incr_user_rate_count(self, user, time, count_added, is_label=True):
+ async def get_and_incr_user_rate_count(self, user, time, count_added, is_label=True):
"""Get the user's organization samples count for the given minute, and increase by the given amount."""
key = f"rate-limit:{user.organization.id}:{time.minute}"
if is_label:
key += ":label"
p = self.redis.pipeline()
p.incr(key, count_added)
p.expire(key, 60)
- count_after_increase = p.execute()[0]
+ count_after_increase = (await p.execute())[0]
# Return the count before incrementing
return count_after_increase - count_added

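The switch from redis.client to redis.asyncio.client is what turns these helpers into coroutines: pipeline commands are still queued synchronously, but execute() now has to be awaited, which is why every caller in this PR gains an await. Below is a self-contained sketch of the same INCR + EXPIRE pattern; the URL and key are illustrative, not taken from the PR.

# Sketch only: the pipeline pattern used by get_and_incr_user_rate_count,
# shown against a local Redis. Commands are buffered synchronously and only
# execute() is awaited; the URL and key are placeholders.
import asyncio

from redis.asyncio.client import Redis


async def incr_rate_key(redis_uri: str = "redis://localhost:6379/0") -> int:
    async with Redis.from_url(redis_uri) as redis:
        p = redis.pipeline()
        p.incr("rate-limit:demo:0", 100)    # queue an increment of 100 samples
        p.expire("rate-limit:demo:0", 60)   # queue a 60-second TTL, as in the code above
        count_after, _ = await p.execute()  # run both commands and await the results
    return count_after - 100                # count before this increment


if __name__ == "__main__":
    print(asyncio.run(incr_rate_key()))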
2 changes: 1 addition & 1 deletion backend/deepchecks_monitoring/logic/check_logic.py
@@ -372,7 +372,7 @@ async def run_check_per_window_in_range(
result_value = reduce_check_result(result_value, monitor_options.additional_kwargs)
# If cache available and there is monitor id, save result to cache
if cache_funcs and monitor_id:
- cache_funcs.set_monitor_cache(organization_id, model_version.id, monitor_id, result_dict["start"],
+ await cache_funcs.set_monitor_cache(organization_id, model_version.id, monitor_id, result_dict["start"],
result_dict["end"], result_value)
reduce_results[model_version.name].append(result_value)

@@ -220,7 +220,7 @@ async def execute_check_per_window(

# TODO: consider caching results not only when a 'monitor_id' is provided
if cache_funcs and monitor_id:
- cache_funcs.set_monitor_cache(
+ await cache_funcs.set_monitor_cache(
organization_id,
result['model_version_id'],
monitor_id,
7 changes: 3 additions & 4 deletions backend/deepchecks_monitoring/resources.py
@@ -42,6 +42,8 @@

__all__ = ["ResourcesProvider"]

+ from deepchecks_monitoring.utils.redis_proxy import RedisProxy
+
logger: logging.Logger = configure_logger("server")


@@ -291,10 +293,7 @@ def get_kafka_admin(self) -> t.Generator[KafkaAdminClient, None, None]:
def redis_client(self) -> t.Optional[Redis]:
"""Return redis client if redis defined, else None."""
if self._redis_client is None and self.redis_settings.redis_uri:
- try:
- self._redis_client = RedisCluster.from_url(self.redis_settings.redis_uri)
- except RedisClusterException:
- self._redis_client = Redis.from_url(self.redis_settings.redis_uri)
+ self._redis_client = RedisProxy(self.redis_settings)
return self._redis_client

@property
71 changes: 71 additions & 0 deletions backend/deepchecks_monitoring/utils/redis_proxy.py
@@ -0,0 +1,71 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
"""A proxy for Redis client that handles connection errors."""

import asyncio

import redis.exceptions as redis_exceptions
from redis.asyncio.client import Redis
from redis.asyncio.cluster import RedisCluster
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisClusterException
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

from deepchecks_monitoring.config import RedisSettings

redis_exceptions_tuple = tuple(  # Get all exception classes from redis.exceptions
    cls for _, cls in vars(redis_exceptions).items()
    if isinstance(cls, type) and issubclass(cls, Exception)
)


class RedisProxy:
    def __init__(self, settings: RedisSettings):
        self.settings = settings
        self.client = self._connect(settings)

    @classmethod
    def _connect(cls, settings):
        """Connect to Redis."""
        try:
            client = RedisCluster.from_url(settings.redis_uri)
        except redis_exceptions_tuple:  # pylint: disable=catching-non-exception
            client = Redis.from_url(settings.redis_uri)

        return client

    def __getattr__(self, name):
        """Wrap the Redis client with a retry mechanism."""
        attr = getattr(self.client, name)
        _decorator = retry(stop=stop_after_attempt(self.settings.stop_after_retries),
                           wait=wait_fixed(self.settings.wait_between_retries),
                           retry=retry_if_exception_type(redis_exceptions_tuple),
                           reraise=True)
        if callable(attr):
            if asyncio.iscoroutinefunction(attr):
                @_decorator
                async def wrapped(*args, **kwargs):
                    try:
                        return await attr(*args, **kwargs)
                    except (RedisClusterException, RedisConnectionError):
                        self.client = self._connect(self.settings)
                        raise
            else:
                @_decorator
                def wrapped(*args, **kwargs):
                    try:
                        return attr(*args, **kwargs)
                    except (RedisClusterException, RedisConnectionError):
                        self.client = self._connect(self.settings)
                        raise

            return wrapped
        else:
            return attr
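Since __getattr__ resolves attributes on the wrapped client lazily, callers use the proxy exactly as they would an async Redis client. A hedged usage sketch, with an illustrative URI and keys:

# Sketch only: the proxy is called like a regular async Redis client.
# Attribute access is forwarded by __getattr__ and wrapped with the tenacity
# retry policy; when a wrapped call raises a cluster or connection error,
# the proxy rebuilds its client before re-raising.
import asyncio

from deepchecks_monitoring.config import RedisSettings
from deepchecks_monitoring.utils.redis_proxy import RedisProxy


async def demo() -> None:
    settings = RedisSettings(redis_uri="redis://localhost:6379/0")  # illustrative URI
    proxy = RedisProxy(settings)
    await proxy.set("greeting", "hello")
    print(await proxy.get("greeting"))


if __name__ == "__main__":
    asyncio.run(demo())

Building redis_exceptions_tuple by introspecting redis.exceptions keeps the retry predicate aligned with whatever exception classes the installed redis version ships.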
12 changes: 2 additions & 10 deletions backend/dev_utils/run_task_directly.py
@@ -20,21 +20,13 @@
from deepchecks_monitoring.ee.resources import ResourcesProvider
from deepchecks_monitoring.logic.keys import TASK_RUNNER_LOCK
from deepchecks_monitoring.public_models import Task
+ from deepchecks_monitoring.utils.redis_proxy import RedisProxy

# Task class you want to run
TASK_CLASS = ObjectStorageIngestor
# The task name you want to run (need to be exists in DB, we take the last one ordered by id desc)
BG_WORKER_TASK = 'object_storage_ingestion'

- async def init_async_redis(redis_uri):
- """Initialize redis connection."""
- try:
- redis = RedisCluster.from_url(redis_uri)
- await redis.ping()
- return redis
- except RedisClusterException:
- return Redis.from_url(redis_uri)
-
async def run_it():
if path := dotenv.find_dotenv(usecwd=True):
dotenv.load_dotenv(dotenv_path=path)
@@ -49,7 +41,7 @@ async def run_it():
async with rp.create_async_database_session() as session:

try:
- async_redis = await init_async_redis(rp.redis_settings.redis_uri)
+ async_redis = RedisProxy(rp.redis_settings)

lock_name = TASK_RUNNER_LOCK.format(1)
# By default, allow task 5 minutes before removes lock to allow another run. Inside the task itself we can
18 changes: 9 additions & 9 deletions backend/tests/logic/test_cache_functions.py
@@ -17,23 +17,23 @@ async def test_clear_monitor_cache(resources_provider):
for _ in range(0, 10_000, 100):
end_time = start_time.add(seconds=100)
# Should be deleted later
- cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=1,
+ await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=1,
start_time=start_time, end_time=end_time, value='some value')
# Should be deleted later
- cache_funcs.set_monitor_cache(organization_id=1, model_version_id=2, monitor_id=1,
+ await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=2, monitor_id=1,
start_time=start_time, end_time=end_time, value='some value')
# Should NOT be deleted later
- cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=7,
+ await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=7,
start_time=start_time, end_time=end_time, value='some value')
# Should NOT be deleted later
- cache_funcs.set_monitor_cache(organization_id=9, model_version_id=1, monitor_id=1,
+ await cache_funcs.set_monitor_cache(organization_id=9, model_version_id=1, monitor_id=1,
start_time=start_time, end_time=end_time, value='some value')
start_time = end_time

# Act
cache_funcs.clear_monitor_cache(organization_id=1, monitor_id=1)
# Assert
- assert len(cache_funcs.redis.keys()) == 200
+ assert len(await cache_funcs.redis.keys()) == 200


@pytest.mark.asyncio
@@ -45,13 +45,13 @@ async def test_delete_monitor_cache_by_timestamp(resources_provider, async_sessi
start_time = now
for _ in range(0, 10_000, 100):
end_time = start_time.add(seconds=100)
- cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=1,
+ await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=1,
start_time=start_time, end_time=end_time, value='some value')
- cache_funcs.set_monitor_cache(organization_id=1, model_version_id=2, monitor_id=1,
+ await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=2, monitor_id=1,
start_time=start_time, end_time=end_time, value='some value')
- cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=7,
+ await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=7,
start_time=start_time, end_time=end_time, value='some value')
- cache_funcs.set_monitor_cache(organization_id=9, model_version_id=1, monitor_id=1,
+ await cache_funcs.set_monitor_cache(organization_id=9, model_version_id=1, monitor_id=1,
start_time=start_time, end_time=end_time, value='some value')
start_time = end_time

@@ -340,7 +340,7 @@ async def test_monitor_executor_is_using_cache(
window_start = window_end - (monitor_frequency.to_pendulum_duration() * monitor["aggregation_window"])
cache_value = {"my special key": 1}

- resources_provider.cache_functions.set_monitor_cache(
+ await resources_provider.cache_functions.set_monitor_cache(
organization_id,
model_version["id"],
monitor["id"],