diff --git a/.github/workflows/license-check.yml b/.github/workflows/license-check.yml index ff5f912bd..f56e0a1ea 100644 --- a/.github/workflows/license-check.yml +++ b/.github/workflows/license-check.yml @@ -50,7 +50,7 @@ jobs: with: requirements: "backend/requirements-all.txt" fail: "Copyleft,Other,Error" - exclude: '(category_encoders.*2\.7\..*|attrs.*25\.3\..*|referencing.*0\.36\..*|envier.*0\.5\.0|psycopg2.*2\.9\.3|fqdn.*1\.5\.1|pyzmq.*25\.1\.2|debugpy.*1\.6\.7|certifi.*2025\.1\.31|tqdm.*4\.67\..*|webencodings.*0\.5\.1|torch.*1\.10\.2.*|torch.*1\.11\.0.*|pytorch-ignite.*0\.4\.10.*|torchaudio.*0\.11\.0.*|torchvision.*0\.12\.0.*|terminado.*0\.15\.0|qudida.*0\.0\.4|expiringdict.*1\.2\.2|botocore.*1\.29\.80|orderedmultidict.*1\.0\.1|deepchecks.*)' + exclude: '(category_encoders.*2\.7\..*|attrs.*25\.3\..*|referencing.*0\.36\..*|envier.*0\.5\.0|psycopg2.*2\.9\.3|fqdn.*1\.5\.1|pyzmq.*25\.1\.2|debugpy.*1\.6\.7|certifi.*2025\.1\.31|tqdm.*4\.67\..*|webencodings.*0\.5\.1|torch.*1\.10\.2.*|torch.*1\.11\.0.*|pytorch-ignite.*0\.4\.10.*|torchaudio.*0\.11\.0.*|torchvision.*0\.12\.0.*|terminado.*0\.15\.0|qudida.*0\.0\.4|expiringdict.*1\.2\.2|botocore.*1\.29\.80|orderedmultidict.*1\.0\.1|termcolor.*3\.0\.1|deepchecks.*)' # psycopg2 is LGPL 2 # category_encoders is BSD https://github.com/scikit-learn-contrib/category_encoders/tree/master?tab=BSD-3-Clause-1-ov-file # attrs is MIT https://github.com/python-attrs/attrs/blob/main/LICENSE @@ -64,6 +64,7 @@ jobs: # torchvision is BSD https://github.com/pytorch/vision/blob/main/LICENSE # torchaudio is BSD https://github.com/pytorch/audio/blob/main/LICENSE # terminado is BSD https://github.com/jupyter/terminado/blob/main/LICENSE + # termcolor is MIT https://github.com/termcolor/termcolor/blob/main/COPYING.txt # orderedmultidict is freeley distributed https://github.com/gruns/orderedmultidict/blob/master/LICENSE.md - name: Print report if: ${{ always() }} diff --git a/backend/deepchecks_monitoring/api/v1/data_input.py b/backend/deepchecks_monitoring/api/v1/data_input.py index 1953b6d4b..1c4336efd 100644 --- a/backend/deepchecks_monitoring/api/v1/data_input.py +++ b/backend/deepchecks_monitoring/api/v1/data_input.py @@ -73,7 +73,8 @@ async def log_data_batch( minute_rate = resources_provider.get_features_control(user).rows_per_minute # Atomically getting the count and increasing in order to avoid race conditions - curr_count = resources_provider.cache_functions.get_and_incr_user_rate_count(user, time, len(data)) + async with resources_provider.cache_functions() as cache_functions: + curr_count = await cache_functions.get_and_incr_user_rate_count(user, time, len(data)) remains = minute_rate - curr_count # Remains can be negative because we don't check the limit before incrementing @@ -140,7 +141,10 @@ async def log_labels( minute_rate = resources_provider.get_features_control(user).rows_per_minute # Atomically getting the count and increasing in order to avoid race conditions - curr_count = resources_provider.cache_functions.get_and_incr_user_rate_count(user, time, len(data), is_label=True) + async with resources_provider.cache_functions() as cache_functions: + curr_count = await cache_functions.get_and_incr_user_rate_count( + user, time, len(data), is_label=True + ) remains = minute_rate - curr_count # Remains can be negative because we don't check the limit before incrementing diff --git a/backend/deepchecks_monitoring/api/v1/monitor.py b/backend/deepchecks_monitoring/api/v1/monitor.py index 2ba5e8a72..a99a87ebb 100644 --- 
a/backend/deepchecks_monitoring/api/v1/monitor.py +++ b/backend/deepchecks_monitoring/api/v1/monitor.py @@ -21,8 +21,7 @@ from deepchecks_monitoring.api.v1.alert_rule import AlertRuleSchema from deepchecks_monitoring.api.v1.check import CheckResultSchema, CheckSchema from deepchecks_monitoring.config import Settings, Tags -from deepchecks_monitoring.dependencies import AsyncSessionDep, CacheFunctionsDep, ResourcesProviderDep, SettingsDep -from deepchecks_monitoring.logic.cache_functions import CacheFunctions +from deepchecks_monitoring.dependencies import AsyncSessionDep, ResourcesProviderDep, SettingsDep from deepchecks_monitoring.logic.check_logic import CheckNotebookSchema, MonitorOptions, run_check_per_window_in_range from deepchecks_monitoring.monitoring_utils import (DataFilterList, ExtendedAsyncSession, IdResponse, MonitorCheckConfSchema, fetch_or_404, field_length) @@ -155,7 +154,6 @@ async def update_monitor( monitor_id: int, body: MonitorUpdateSchema, session: AsyncSession = AsyncSessionDep, - cache_funcs: CacheFunctions = CacheFunctionsDep, user: User = Depends(CurrentActiveUser()), resources_provider: ResourcesProvider = ResourcesProviderDep, ): @@ -220,7 +218,8 @@ async def update_monitor( ) # Delete cache - cache_funcs.clear_monitor_cache(user.organization_id, monitor_id) + async with resources_provider.cache_functions() as cache_funcs: + await cache_funcs.clear_monitor_cache(user.organization_id, monitor_id) update_dict["updated_by"] = user.id await Monitor.update(session, monitor_id, update_dict) return Response(status_code=status.HTTP_200_OK) @@ -231,12 +230,13 @@ async def delete_monitor( monitor_id: int, monitor: Monitor = Depends(Monitor.get_object_from_http_request), session: AsyncSession = AsyncSessionDep, - cache_funcs: CacheFunctions = CacheFunctionsDep, + resources_provider: ResourcesProvider = ResourcesProviderDep, user: User = Depends(CurrentActiveUser()) ): """Delete monitor by id.""" await session.delete(monitor) - cache_funcs.clear_monitor_cache(user.organization_id, monitor_id) + async with resources_provider.cache_functions() as cache_funcs: + await cache_funcs.clear_monitor_cache(user.organization_id, monitor_id) return Response(status_code=status.HTTP_200_OK) @@ -280,7 +280,6 @@ async def run_monitor_lookback( body: MonitorRunSchema, monitor: Monitor = Depends(Monitor.get_object_from_http_request), session: AsyncSession = AsyncSessionDep, - cache_funcs: CacheFunctions = CacheFunctionsDep, user: User = Depends(CurrentActiveUser()), resources_provider: ResourcesProvider = ResourcesProviderDep, ): @@ -326,11 +325,12 @@ async def run_monitor_lookback( organization_id=t.cast(int, user.organization_id) ) - return await run_check_per_window_in_range( - monitor.check_id, - session, - options, - monitor_id=monitor_id, - cache_funcs=cache_funcs, - organization_id=user.organization_id, - ) + async with resources_provider.cache_functions() as cache_funcs: + return await run_check_per_window_in_range( + monitor.check_id, + session, + options, + monitor_id=monitor_id, + cache_funcs=cache_funcs, + organization_id=user.organization_id, + ) diff --git a/backend/deepchecks_monitoring/bgtasks/alert_task.py b/backend/deepchecks_monitoring/bgtasks/alert_task.py index 396e75950..a27de813a 100644 --- a/backend/deepchecks_monitoring/bgtasks/alert_task.py +++ b/backend/deepchecks_monitoring/bgtasks/alert_task.py @@ -154,13 +154,14 @@ async def execute_monitor( # First looking for results in cache if already calculated cache_results = {} model_versions_without_cache = [] - for 
model_version in model_versions: - cache_result = resources_provider.cache_functions.get_monitor_cache( - organization_id, model_version.id, monitor_id, start_time, end_time) - if cache_result.found: - cache_results[model_version] = cache_result.value - else: - model_versions_without_cache.append(model_version) + async with resources_provider.cache_functions() as cache_functions: + for model_version in model_versions: + cache_result = await cache_functions.get_monitor_cache( + organization_id, model_version.id, monitor_id, start_time, end_time) + if cache_result.found: + cache_results[model_version] = cache_result.value + else: + model_versions_without_cache.append(model_version) logger.debug('Cache result: %s', cache_results) # For model versions without result in cache running calculation @@ -181,9 +182,10 @@ async def execute_monitor( result_per_version = reduce_check_window(result_per_version, options) # Save to cache - for version, result in result_per_version.items(): - resources_provider.cache_functions.set_monitor_cache( - organization_id, version.id, monitor_id, start_time, end_time, result) + async with resources_provider.cache_functions() as cache_functions: + for version, result in result_per_version.items(): + await cache_functions.set_monitor_cache( + organization_id, version.id, monitor_id, start_time, end_time, result) logger.debug('Check execution result: %s', result_per_version) else: diff --git a/backend/deepchecks_monitoring/bgtasks/model_version_cache_invalidation.py b/backend/deepchecks_monitoring/bgtasks/model_version_cache_invalidation.py index 4b6d6a32d..2781bcf05 100644 --- a/backend/deepchecks_monitoring/bgtasks/model_version_cache_invalidation.py +++ b/backend/deepchecks_monitoring/bgtasks/model_version_cache_invalidation.py @@ -48,38 +48,38 @@ async def run(self, task: 'Task', session: AsyncSession, resources_provider, loc self.logger.info({'message': 'starting job', 'worker name': str(type(self)), 'task': task.id, 'model version': model_version_id, 'org_id': org_id}) - redis = resources_provider.redis_client - invalidation_set_key = get_invalidation_set_key(org_id, model_version_id) - - # Query all timestamps - entries = redis.zrange(invalidation_set_key, start=0, end=-1, withscores=True) - if entries: - # Sort timestamps for faster search - invalidation_ts = sorted([int(x[0]) for x in entries]) - max_score = max((x[1] for x in entries)) - - # Iterate all monitors cache keys and check timestamps overlap - monitor_pattern = build_monitor_cache_key(org_id, model_version_id, None, None, None) - keys_to_delete = [] - for monitor_cache_key in redis.scan_iter(match=monitor_pattern): - splitted = monitor_cache_key.split(b':') - start_ts, end_ts = int(splitted[4]), int(splitted[5]) - # Get first timestamp equal or larger than start_ts - index = bisect.bisect_left(invalidation_ts, start_ts) - # If index is equal to list length, then all timestamps are smaller than start_ts - if index == len(invalidation_ts): - continue - if start_ts <= invalidation_ts[index] < end_ts: - keys_to_delete.append(monitor_cache_key) - - pipe = redis.pipeline() - for key in keys_to_delete: + async with resources_provider.get_redis_client() as redis: + invalidation_set_key = get_invalidation_set_key(org_id, model_version_id) + + # Query all timestamps + entries = await redis.zrange(invalidation_set_key, start=0, end=-1, withscores=True) + if entries: + # Sort timestamps for faster search + invalidation_ts = sorted([int(x[0]) for x in entries]) + max_score = max((x[1] for x in entries)) + + # 
Iterate all monitors cache keys and check timestamps overlap + monitor_pattern = build_monitor_cache_key(org_id, model_version_id, None, None, None) + keys_to_delete = [] + async for monitor_cache_key in redis.scan_iter(match=monitor_pattern): + splitted = monitor_cache_key.split(b':') + start_ts, end_ts = int(splitted[4]), int(splitted[5]) + # Get first timestamp equal or larger than start_ts + index = bisect.bisect_left(invalidation_ts, start_ts) + # If index is equal to list length, then all timestamps are smaller than start_ts + if index == len(invalidation_ts): + continue + if start_ts <= invalidation_ts[index] < end_ts: + keys_to_delete.append(monitor_cache_key) + + pipe = redis.pipeline() # Delete all cache keys - must do in separate deletes since RedisCluster does not support multi-delete - pipe.delete(key) - # Delete all invalidation timestamps by range. if timestamps were updated while running, - # then their score should be larger than max_score, and they won't be deleted - pipe.zremrangebyscore(invalidation_set_key, min=0, max=max_score) - pipe.execute() + for key in keys_to_delete: + await pipe.delete(key) + # Delete all invalidation timestamps by range. if timestamps were updated while running, + # then their score should be larger than max_score, and they won't be deleted + await pipe.zremrangebyscore(invalidation_set_key, min=0, max=max_score) + await pipe.execute() self.logger.info({'message': 'finished job', 'worker name': str(type(self)), 'task': task.id, 'model version': model_version_id, 'org_id': org_id}) diff --git a/backend/deepchecks_monitoring/bgtasks/tasks_queuer.py b/backend/deepchecks_monitoring/bgtasks/tasks_queuer.py index 0f1a6f225..b18c43d91 100644 --- a/backend/deepchecks_monitoring/bgtasks/tasks_queuer.py +++ b/backend/deepchecks_monitoring/bgtasks/tasks_queuer.py @@ -33,6 +33,7 @@ from deepchecks_monitoring.logic.keys import GLOBAL_TASK_QUEUE from deepchecks_monitoring.monitoring_utils import configure_logger from deepchecks_monitoring.public_models.task import BackgroundWorker, Task +from deepchecks_monitoring.utils.redis_proxy import RedisProxy try: from deepchecks_monitoring import ee @@ -50,7 +51,7 @@ class TasksQueuer: def __init__( self, resource_provider: ResourcesProvider, - redis_client: RedisCluster | Redis, + redis_client: RedisCluster | Redis | RedisProxy, workers: t.List[BackgroundWorker], logger: logging.Logger, run_interval: int, @@ -152,16 +153,6 @@ class Config: env_file_encoding = 'utf-8' -async def init_async_redis(redis_uri): - """Initialize redis connection.""" - try: - redis = RedisCluster.from_url(redis_uri) - await redis.ping() - return redis - except redis_exceptions.RedisClusterException: - return Redis.from_url(redis_uri) - - def execute_worker(): """Execute worker.""" @@ -195,7 +186,8 @@ async def main(): async with ResourcesProvider(settings) as rp: async with anyio.create_task_group() as g: - async_redis = await init_async_redis(rp.redis_settings.redis_uri) + async_redis = RedisProxy(rp.redis_settings) + await async_redis.init_conn_async() worker = tasks_queuer.TasksQueuer(rp, async_redis, workers, logger, settings.queuer_run_interval) g.start_soon(worker.run) diff --git a/backend/deepchecks_monitoring/bgtasks/tasks_runner.py b/backend/deepchecks_monitoring/bgtasks/tasks_runner.py index 2cf7bff91..11b879795 100644 --- a/backend/deepchecks_monitoring/bgtasks/tasks_runner.py +++ b/backend/deepchecks_monitoring/bgtasks/tasks_runner.py @@ -16,7 +16,7 @@ import pendulum as pdl import uvloop from redis.asyncio import Redis, 
RedisCluster -from redis.exceptions import LockNotOwnedError, RedisClusterException +from redis.exceptions import LockNotOwnedError from sqlalchemy import select from deepchecks_monitoring.bgtasks.alert_task import AlertsTask @@ -29,6 +29,7 @@ from deepchecks_monitoring.logic.keys import GLOBAL_TASK_QUEUE, TASK_RUNNER_LOCK from deepchecks_monitoring.monitoring_utils import configure_logger from deepchecks_monitoring.public_models.task import BackgroundWorker, Task +from deepchecks_monitoring.utils.redis_proxy import RedisProxy try: from deepchecks_monitoring import ee @@ -160,16 +161,6 @@ class WorkerSettings(BaseWorkerSettings, Settings): pass -async def init_async_redis(redis_uri): - """Initialize redis connection.""" - try: - redis = RedisCluster.from_url(redis_uri) - await redis.ping() - return redis - except RedisClusterException: - return Redis.from_url(redis_uri) - - def execute_worker(): """Execute worker.""" @@ -189,7 +180,8 @@ async def main(): from deepchecks_monitoring.bgtasks import tasks_runner # pylint: disable=import-outside-toplevel async with ResourcesProvider(settings) as rp: - async_redis = await init_async_redis(rp.redis_settings.redis_uri) + async_redis = RedisProxy(rp.redis_settings) + await async_redis.init_conn_async() workers = [ ModelVersionCacheInvalidation(), diff --git a/backend/deepchecks_monitoring/config.py b/backend/deepchecks_monitoring/config.py index d7c3a99d3..eb4f14b51 100644 --- a/backend/deepchecks_monitoring/config.py +++ b/backend/deepchecks_monitoring/config.py @@ -137,6 +137,8 @@ class RedisSettings(BaseDeepchecksSettings): """Redis settings.""" redis_uri: t.Optional[RedisDsn] = None + stop_after_retries: int = 3 # Number of retries before giving up + wait_between_retries: int = 3 # Time to wait between retries class Settings( diff --git a/backend/deepchecks_monitoring/dependencies.py b/backend/deepchecks_monitoring/dependencies.py index 1ac06a5f5..27574a710 100644 --- a/backend/deepchecks_monitoring/dependencies.py +++ b/backend/deepchecks_monitoring/dependencies.py @@ -28,7 +28,6 @@ "limit_request_size", "SettingsDep", "DataIngestionDep", - "CacheFunctionsDep", "ResourcesProviderDep" ] @@ -61,11 +60,6 @@ def get_data_ingestion_backend(request: fastapi.Request): return state.data_ingestion_backend -def get_cache_functions(request: fastapi.Request): - state = request.app.state - return state.resources_provider.cache_functions - - def get_host(request: fastapi.Request) -> str: settings = request.app.state.settings return settings.host @@ -78,7 +72,6 @@ def get_resources_provider(request: fastapi.Request) -> "ResourcesProvider": AsyncSessionDep = fastapi.Depends(get_async_session) SettingsDep = fastapi.Depends(get_settings) DataIngestionDep = fastapi.Depends(get_data_ingestion_backend) -CacheFunctionsDep = fastapi.Depends(get_cache_functions) ResourcesProviderDep = fastapi.Depends(get_resources_provider) diff --git a/backend/deepchecks_monitoring/logic/cache_functions.py b/backend/deepchecks_monitoring/logic/cache_functions.py index 20dc8d966..6311132ca 100644 --- a/backend/deepchecks_monitoring/logic/cache_functions.py +++ b/backend/deepchecks_monitoring/logic/cache_functions.py @@ -16,7 +16,7 @@ import pendulum as pdl import redis.exceptions -from redis.client import Redis +from redis.asyncio.client import Redis from deepchecks_monitoring.logic.keys import build_monitor_cache_key, get_invalidation_set_key @@ -39,16 +39,16 @@ def __init__(self, redis_client=None): self.redis: Redis = redis_client self.logger = logging.Logger("cache-functions") - 
def get_monitor_cache(self, organization_id, model_version_id, monitor_id, start_time, end_time): + async def get_monitor_cache(self, organization_id, model_version_id, monitor_id, start_time, end_time): """Get result from cache if exists. We can cache values which are "None" therefore to distinguish between the \ situations we return CacheResult with 'found' property.""" if self.use_cache: key = build_monitor_cache_key(organization_id, model_version_id, monitor_id, start_time, end_time) try: p = self.redis.pipeline() - p.get(key) - p.expire(key, MONITOR_CACHE_EXPIRY_TIME) - cache_value = p.execute()[0] + await p.get(key) + await p.expire(key, MONITOR_CACHE_EXPIRY_TIME) + cache_value = (await p.execute())[0] # If cache value is none it means the key was not found if cache_value is not None: return CacheResult(found=True, value=json.loads(cache_value)) @@ -58,7 +58,7 @@ def get_monitor_cache(self, organization_id, model_version_id, monitor_id, start # Return no cache result return CacheResult(found=False, value=None) - def set_monitor_cache(self, organization_id, model_version_id, monitor_id, start_time, end_time, value): + async def set_monitor_cache(self, organization_id, model_version_id, monitor_id, start_time, end_time, value): """Set cache value for the properties given.""" if not self.use_cache: return @@ -66,13 +66,13 @@ def set_monitor_cache(self, organization_id, model_version_id, monitor_id, start key = build_monitor_cache_key(organization_id, model_version_id, monitor_id, start_time, end_time) cache_val = json.dumps(value) p = self.redis.pipeline() - p.set(key, cache_val) - p.expire(key, MONITOR_CACHE_EXPIRY_TIME) - p.execute() + await p.set(key, cache_val) + await p.expire(key, MONITOR_CACHE_EXPIRY_TIME) + await p.execute() except redis.exceptions.RedisError as e: self.logger.exception(e) - def clear_monitor_cache(self, organization_id: int, monitor_id: int): + async def clear_monitor_cache(self, organization_id: int, monitor_id: int): """Clear entries from the cache. 
Parameters @@ -86,30 +86,26 @@ def clear_monitor_cache(self, organization_id: int, monitor_id: int): try: pattern = build_monitor_cache_key(organization_id, None, monitor_id, None, None) keys_to_delete = [] - for key in self.redis.scan_iter(match=pattern): + async for key in self.redis.scan_iter(match=pattern): keys_to_delete.append(key) if keys_to_delete: - self.redis.delete(*keys_to_delete) + await self.redis.delete(*keys_to_delete) except redis.exceptions.RedisError as e: self.logger.exception(e) - def delete_key(self, key): - """Remove a given key from the cache.""" - self.redis.delete(key) - - def get_and_incr_user_rate_count(self, user, time, count_added, is_label=True): + async def get_and_incr_user_rate_count(self, user, time, count_added, is_label=True): """Get the user's organization samples count for the given minute, and increase by the given amount.""" key = f"rate-limit:{user.organization.id}:{time.minute}" if is_label: key += ":label" p = self.redis.pipeline() - p.incr(key, count_added) - p.expire(key, 60) - count_after_increase = p.execute()[0] + await p.incr(key, count_added) + await p.expire(key, 60) + count_after_increase = (await p.execute())[0] # Return the count before incrementing return count_after_increase - count_added - def add_invalidation_timestamps(self, organization_id: int, model_version_id: int, timestamps: t.Set[int]): + async def add_invalidation_timestamps(self, organization_id: int, model_version_id: int, timestamps: t.Set[int]): key = get_invalidation_set_key(organization_id, model_version_id) now = pdl.now().timestamp() - self.redis.zadd(key, {ts: now for ts in timestamps}) + await self.redis.zadd(key, {ts: now for ts in timestamps}) diff --git a/backend/deepchecks_monitoring/logic/check_logic.py b/backend/deepchecks_monitoring/logic/check_logic.py index 5ff0d4934..82e974a98 100644 --- a/backend/deepchecks_monitoring/logic/check_logic.py +++ b/backend/deepchecks_monitoring/logic/check_logic.py @@ -323,7 +323,7 @@ async def run_check_per_window_in_range( curr_test_info = {"start": window_start, "end": window_end} test_info.append(curr_test_info) if monitor_id and cache_funcs: - cache_result = cache_funcs.get_monitor_cache( + cache_result = await cache_funcs.get_monitor_cache( organization_id, model_version.id, monitor_id, window_start, window_end) # If found the result in cache, skip querying if cache_result.found: @@ -372,8 +372,10 @@ async def run_check_per_window_in_range( result_value = reduce_check_result(result_value, monitor_options.additional_kwargs) # If cache available and there is monitor id, save result to cache if cache_funcs and monitor_id: - cache_funcs.set_monitor_cache(organization_id, model_version.id, monitor_id, result_dict["start"], - result_dict["end"], result_value) + await cache_funcs.set_monitor_cache( + organization_id, model_version.id, monitor_id, result_dict["start"], + result_dict["end"], result_value + ) reduce_results[model_version.name].append(result_value) return { diff --git a/backend/deepchecks_monitoring/logic/data_ingestion.py b/backend/deepchecks_monitoring/logic/data_ingestion.py index e28fc71ff..5d7857e72 100644 --- a/backend/deepchecks_monitoring/logic/data_ingestion.py +++ b/backend/deepchecks_monitoring/logic/data_ingestion.py @@ -425,8 +425,9 @@ async def log_samples( await self._send_with_retry(producer, topic_name, messages, len(messages)) else: - await log_data(model_version, data, session, [log_time] * len(data), self.logger, - organization_id, self.resources_provider.cache_functions) + async with 
self.resources_provider.cache_functions() as cache_functions: + await log_data(model_version, data, session, + [log_time] * len(data), self.logger, organization_id, cache_functions) async def log_labels( self, @@ -458,8 +459,8 @@ async def log_labels( await self._send_with_retry(producer, topic_name, messages, len(messages)) else: - await log_labels(model, data, session, organization_id, - self.resources_provider.cache_functions, self.logger) + async with self.resources_provider.cache_functions() as cache_functions: + await log_labels(model, data, session, organization_id, cache_functions, self.logger) async def run_data_consumer(self): """Create an endless-loop of consuming messages from kafka.""" @@ -494,8 +495,9 @@ async def _handle_data_messages(self, messages_data = [json.loads(m.value) for m in messages if m.offset > model_version.ingestion_offset] samples = [m["data"] for m in messages_data] log_times = [pdl.parse(m["log_time"]) for m in messages_data] - await log_data(model_version, samples, session, log_times, self.logger, organization_id, - self.resources_provider.cache_functions) + async with self.resources_provider.cache_functions() as cache_functions: + await log_data(model_version, samples, session, log_times, self.logger, organization_id, + cache_functions) model_version.ingestion_offset = messages[-1].offset if entity == "model": model: Model = (await session.execute(select(Model).where(Model.id == entity_id))).scalar() @@ -507,8 +509,8 @@ async def _handle_data_messages(self, # already ingested messages messages_data = [json.loads(m.value) for m in messages if m.offset > model.ingestion_offset] samples = [m["data"] for m in messages_data] - await log_labels(model, samples, session, organization_id, - self.resources_provider.cache_functions, self.logger) + async with self.resources_provider.cache_functions() as cache_functions: + await log_labels(model, samples, session, organization_id, cache_functions, self.logger) model.ingestion_offset = messages[-1].offset return True diff --git a/backend/deepchecks_monitoring/logic/parallel_check_executor.py b/backend/deepchecks_monitoring/logic/parallel_check_executor.py index 63efd2a0b..c1123c1f7 100644 --- a/backend/deepchecks_monitoring/logic/parallel_check_executor.py +++ b/backend/deepchecks_monitoring/logic/parallel_check_executor.py @@ -220,7 +220,7 @@ async def execute_check_per_window( # TODO: consider caching results not only when a 'monitor_id' is provided if cache_funcs and monitor_id: - cache_funcs.set_monitor_cache( + await cache_funcs.set_monitor_cache( organization_id, result['model_version_id'], monitor_id, diff --git a/backend/deepchecks_monitoring/resources.py b/backend/deepchecks_monitoring/resources.py index 88919e50f..b2b273479 100644 --- a/backend/deepchecks_monitoring/resources.py +++ b/backend/deepchecks_monitoring/resources.py @@ -20,9 +20,6 @@ from kafka import KafkaAdminClient from kafka.admin import NewTopic from kafka.errors import KafkaError, TopicAlreadyExistsError -from redis.client import Redis -from redis.cluster import RedisCluster -from redis.exceptions import RedisClusterException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from sqlalchemy.future.engine import Engine, create_engine @@ -39,9 +36,11 @@ from deepchecks_monitoring.utils import database from deepchecks_monitoring.utils.mixpanel import BaseEvent as BaseMixpanelEvent from deepchecks_monitoring.utils.mixpanel import MixpanelEventReporter +from deepchecks_monitoring.utils.redis_proxy 
import RedisProxy __all__ = ["ResourcesProvider"] + logger: logging.Logger = configure_logger("server") @@ -82,9 +81,6 @@ def __init__(self, settings: config.BaseSettings): self._settings = settings self._database_engine: t.Optional[Engine] = None self._async_database_engine: t.Optional[AsyncEngine] = None - self._kafka_admin: t.Optional[KafkaAdminClient] = None - self._redis_client: t.Optional[Redis] = None - self._cache_funcs: t.Optional[CacheFunctions] = None self._email_sender: t.Optional[EmailSender] = None self._oauth_client: t.Optional[OAuth] = None self._parallel_check_executors = None @@ -287,22 +283,24 @@ def get_kafka_admin(self) -> t.Generator[KafkaAdminClient, None, None]: finally: kafka_admin.close() - @property - def redis_client(self) -> t.Optional[Redis]: + @asynccontextmanager + async def get_redis_client(self) -> t.AsyncGenerator[t.Optional[RedisProxy], None]: """Return redis client if redis defined, else None.""" - if self._redis_client is None and self.redis_settings.redis_uri: + if self.redis_settings.redis_uri: + redis_proxy = RedisProxy(self.redis_settings) + await redis_proxy.init_conn_async() try: - self._redis_client = RedisCluster.from_url(self.redis_settings.redis_uri) - except RedisClusterException: - self._redis_client = Redis.from_url(self.redis_settings.redis_uri) - return self._redis_client + yield redis_proxy + finally: + await redis_proxy.aclose() + else: + yield None - @property - def cache_functions(self) -> t.Optional[CacheFunctions]: + @asynccontextmanager + async def cache_functions(self) -> t.AsyncGenerator[CacheFunctions, None]: """Return cache functions.""" - if self._cache_funcs is None: - self._cache_funcs = CacheFunctions(self.redis_client) - return self._cache_funcs + async with self.get_redis_client() as redis_client: + yield CacheFunctions(redis_client) @property def oauth_client(self): diff --git a/backend/deepchecks_monitoring/utils/redis_proxy.py b/backend/deepchecks_monitoring/utils/redis_proxy.py new file mode 100644 index 000000000..7e5ffb52c --- /dev/null +++ b/backend/deepchecks_monitoring/utils/redis_proxy.py @@ -0,0 +1,76 @@ +# ---------------------------------------------------------------------------- +# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com) +# +# This file is part of Deepchecks. +# Deepchecks is distributed under the terms of the GNU Affero General +# Public License (version 3 or later). +# You should have received a copy of the GNU Affero General Public License +# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>. +# ---------------------------------------------------------------------------- +"""A proxy for Redis client that handles connection errors.""" + +import asyncio + +import redis.exceptions as redis_exceptions +from redis.asyncio.client import Redis +from redis.asyncio.cluster import RedisCluster +from redis.exceptions import ConnectionError as RedisConnectionError +from redis.exceptions import RedisClusterException +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +from deepchecks_monitoring.config import RedisSettings + +redis_exceptions_tuple = tuple( # Get all exception classes from redis.exceptions + cls for _, cls in vars(redis_exceptions).items() + if isinstance(cls, type) and issubclass(cls, Exception) +) + + +class RedisProxy: + "A proxy for Redis client that handles connection errors."
+ + def __init__(self, settings: RedisSettings): + self.settings = settings + self.client = None + + @classmethod + async def _get_redis_client(cls, settings: RedisSettings): + try: + client = RedisCluster.from_url(settings.redis_uri) + await client.ping() + except redis_exceptions_tuple: # pylint: disable=catching-non-exception + client = Redis.from_url(settings.redis_uri) + return client + + async def init_conn_async(self): + """Connect to Redis.""" + @retry( + stop=stop_after_attempt(self.settings.stop_after_retries), + wait=wait_fixed(self.settings.wait_between_retries), + retry=retry_if_exception_type(redis_exceptions_tuple), + reraise=True + ) + async def connect_to_redis(): + self.client = await self._get_redis_client(self.settings) + await connect_to_redis() + + def __getattr__(self, name): + """Wrap the Redis client with a retry mechanism.""" + attr = getattr(self.client, name) + decorator = retry(stop=stop_after_attempt(self.settings.stop_after_retries), + wait=wait_fixed(self.settings.wait_between_retries), + retry=retry_if_exception_type(redis_exceptions_tuple), + reraise=True) + if callable(attr): + if asyncio.iscoroutinefunction(attr): + @decorator + async def wrapped(*args, **kwargs): + return await attr(*args, **kwargs) + else: + @decorator + def wrapped(*args, **kwargs): + return attr(*args, **kwargs) + + return wrapped + else: + return attr diff --git a/backend/dev-requirements.txt b/backend/dev-requirements.txt index 8fbb91b5c..06cfc8fc0 100644 --- a/backend/dev-requirements.txt +++ b/backend/dev-requirements.txt @@ -20,4 +20,4 @@ tox==3.25.1 faker pyOpenSSL aiosmtpd -fakeredis[lua]==2.9.2 \ No newline at end of file +fakeredis[lua]==2.28.0 \ No newline at end of file diff --git a/backend/dev_utils/run_task_directly.py b/backend/dev_utils/run_task_directly.py index 48021888b..2888e88d6 100644 --- a/backend/dev_utils/run_task_directly.py +++ b/backend/dev_utils/run_task_directly.py @@ -20,21 +20,13 @@ from deepchecks_monitoring.ee.resources import ResourcesProvider from deepchecks_monitoring.logic.keys import TASK_RUNNER_LOCK from deepchecks_monitoring.public_models import Task +from deepchecks_monitoring.utils.redis_proxy import RedisProxy # Task class you want to run TASK_CLASS = ObjectStorageIngestor # The task name you want to run (need to be exists in DB, we take the last one ordered by id desc) BG_WORKER_TASK = 'object_storage_ingestion' -async def init_async_redis(redis_uri): - """Initialize redis connection.""" - try: - redis = RedisCluster.from_url(redis_uri) - await redis.ping() - return redis - except RedisClusterException: - return Redis.from_url(redis_uri) - async def run_it(): if path := dotenv.find_dotenv(usecwd=True): dotenv.load_dotenv(dotenv_path=path) @@ -49,7 +41,8 @@ async def run_it(): async with rp.create_async_database_session() as session: try: - async_redis = await init_async_redis(rp.redis_settings.redis_uri) + async_redis = RedisProxy(rp.redis_settings) + await async_redis.init_conn_async() lock_name = TASK_RUNNER_LOCK.format(1) # By default, allow task 5 minutes before removes lock to allow another run.
Inside the task itself we can diff --git a/backend/requirements.txt b/backend/requirements.txt index df06960f7..8b7309b4b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -19,7 +19,7 @@ kafka-python==2.0.2 uvloop==0.17.0 nbformat>=5.4.0,<6 deepchecks@git+https://github.com/deepchecks/deepchecks.git@8c15865aaebc7f73faeb0939d07fa982d646b2c4 -redis[hiredis]~=4.6.0 +redis[hiredis]~=5.2.1 pandas~=2.1.0 pyjwt[crypto]==2.4.0 Authlib~=1.0.1 diff --git a/backend/tests/api/test_monitor.py b/backend/tests/api/test_monitor.py index 478b619fb..47ad8f1ea 100644 --- a/backend/tests/api/test_monitor.py +++ b/backend/tests/api/test_monitor.py @@ -11,7 +11,7 @@ import pendulum as pdl import pytest -from fakeredis import FakeRedis +from fakeredis.aioredis import FakeRedis from sqlalchemy.ext.asyncio import AsyncSession from deepchecks_monitoring.logic.keys import build_monitor_cache_key @@ -301,7 +301,7 @@ async def test_monitor_update_with_data( monitor = await async_session.get(Monitor, monitor["id"]) expected_schedule = round_up_datetime(daterange[0], monitor_frequency, "utc") - \ - monitor_frequency.to_pendulum_duration() + monitor_frequency.to_pendulum_duration() assert pdl.instance(monitor.latest_schedule) == expected_schedule # Act - Update only monitor name, and rest of the fields should be the same @@ -408,12 +408,13 @@ async def test_update_monitor_freq( assert latest_schedule == expected -def test_monitor_execution( +@pytest.mark.asyncio +async def test_monitor_execution( test_api: TestAPI, classification_model_check: Payload, classification_model_version: Payload, classification_model: Payload, - redis: FakeRedis, + async_redis: FakeRedis, user ): # Arrange @@ -434,12 +435,23 @@ def test_monitor_execution( }]} } )) + upload_classification_data( api=test_api, model_version_id=classification_model_version["id"], model_id=classification_model["id"], daterange=daterange ) + # assert cache empty before running monitor + count_cache = 0 + monitor_key = build_monitor_cache_key(user.organization_id, classification_model_version["id"], monitor["id"], + None, None) + + async for _ in async_redis.scan_iter(monitor_key): + count_cache += 1 + + assert count_cache == 0 + # Act result_without_cache = test_api.execute_monitor( monitor_id=monitor["id"], @@ -449,14 +461,12 @@ def test_monitor_execution( result_without_cache = t.cast(Payload, result_without_cache) # Assert cache is populated + async_redis.connection_pool.reset() # we need to reset due to fake redis being async count_cache = 0 - monitor_key = build_monitor_cache_key(user.organization_id, classification_model_version["id"], monitor["id"], - None, None) - for _ in redis.scan_iter(monitor_key): + async for _ in async_redis.scan_iter(monitor_key): count_cache += 1 assert count_cache == 8 - # Assert result is same with cache result_with_cache = test_api.execute_monitor( monitor_id=monitor["id"], diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 39b5312b5..15593807f 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -19,6 +19,7 @@ import dotenv import faker import fakeredis +import fakeredis.aioredis import pytest import pytest_asyncio import testing.postgresql @@ -38,6 +39,8 @@ from deepchecks_monitoring.monitoring_utils import ExtendedAsyncSession from deepchecks_monitoring.public_models.base import Base as PublicModelsBase from deepchecks_monitoring.schema_models import TaskType +from deepchecks_monitoring.utils.redis_proxy import RedisProxy + from tests.common import Payload, 
TestAPI, generate_user from tests.utils import TestDatabaseGenerator, create_dummy_smtp_server @@ -143,9 +146,13 @@ def smtp_server(): yield server -@pytest.fixture(scope="function") -def redis(): - yield fakeredis.FakeStrictRedis() +@pytest_asyncio.fixture(scope="function") +async def async_redis(): + try: + redis = fakeredis.aioredis.FakeRedis() + yield redis + finally: + await redis.aclose() @pytest.fixture(scope="function") @@ -168,7 +175,8 @@ def settings(async_engine, smtp_server): kafka_host=None, is_cloud=True, mixpanel_id="xxxxxx", - enable_analytics=True + enable_analytics=True, + redis_uri="redis://localhost/0", ) @@ -177,6 +185,7 @@ def _mock_mixpanel_client(): mixpanel_mock = MagicMock() with patch("deepchecks_monitoring.utils.mixpanel.Mixpanel", new=mixpanel_mock) as MockClass: instance = MockClass.return_value + def track(distinct_id, event_name, properties=None, meta=None): nonlocal instance, mixpanel_mock if properties: @@ -206,23 +215,26 @@ def mock_get_features_control(self, user): # pylint: disable=unused-argument return mock_get_features_control -@pytest.fixture(scope="function") -def resources_provider(settings, features_control_mock, redis): +@pytest_asyncio.fixture(scope="function") +async def resources_provider(settings, features_control_mock, async_redis): + async def _get_redis_client(*_): + async_redis.connection_pool.reset() + return async_redis patch.object(ResourcesProvider, "get_features_control", features_control_mock).start() - patch.object(ResourcesProvider, "redis_client", redis).start() + patch.object(RedisProxy, "_get_redis_client", _get_redis_client).start() yield ResourcesProvider(settings) -@pytest_asyncio.fixture(scope="function") +@ pytest_asyncio.fixture(scope="function") async def application( resources_provider: ResourcesProvider, settings: Settings ) -> FastAPI: """Create application instance.""" return create_application( - resources_provider=resources_provider, - settings=settings, - log_level="ERROR" + resources_provider = resources_provider, + settings = settings, + log_level = "ERROR" ) @@ -372,6 +384,7 @@ async def classification_model_version( ) return t.cast(t.Dict[str, t.Any], result) + @pytest_asyncio.fixture() async def classification_model_version_with_bool( test_api: TestAPI, diff --git a/backend/tests/logic/test_cache_functions.py b/backend/tests/logic/test_cache_functions.py index cac7b7548..4ae7dcc8d 100644 --- a/backend/tests/logic/test_cache_functions.py +++ b/backend/tests/logic/test_cache_functions.py @@ -4,66 +4,65 @@ from deepchecks_monitoring.bgtasks.model_version_cache_invalidation import ( ModelVersionCacheInvalidation, insert_model_version_cache_invalidation_task) -from deepchecks_monitoring.logic.cache_functions import CacheFunctions from deepchecks_monitoring.public_models import Task @pytest.mark.asyncio async def test_clear_monitor_cache(resources_provider): - cache_funcs: CacheFunctions = resources_provider.cache_functions + async with resources_provider.cache_functions() as cache_funcs: - # Arrange - Organization with 2 monitors and 2 model versions, and another organization with same monitor id. 
- start_time = pdl.now() - for _ in range(0, 10_000, 100): - end_time = start_time.add(seconds=100) - # Should be deleted later - cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=1, - start_time=start_time, end_time=end_time, value='some value') - # Should be deleted later - cache_funcs.set_monitor_cache(organization_id=1, model_version_id=2, monitor_id=1, - start_time=start_time, end_time=end_time, value='some value') - # Should NOT be deleted later - cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=7, - start_time=start_time, end_time=end_time, value='some value') - # Should NOT be deleted later - cache_funcs.set_monitor_cache(organization_id=9, model_version_id=1, monitor_id=1, - start_time=start_time, end_time=end_time, value='some value') - start_time = end_time + # Arrange - Organization with 2 monitors and 2 model versions, and another organization with same monitor id. + start_time = pdl.now() + for _ in range(0, 10_000, 100): + end_time = start_time.add(seconds=100) + # Should be deleted later + await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=1, + start_time=start_time, end_time=end_time, value='some value') + # Should be deleted later + await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=2, monitor_id=1, + start_time=start_time, end_time=end_time, value='some value') + # Should NOT be deleted later + await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=7, + start_time=start_time, end_time=end_time, value='some value') + # Should NOT be deleted later + await cache_funcs.set_monitor_cache(organization_id=9, model_version_id=1, monitor_id=1, + start_time=start_time, end_time=end_time, value='some value') + start_time = end_time - # Act - cache_funcs.clear_monitor_cache(organization_id=1, monitor_id=1) - # Assert - assert len(cache_funcs.redis.keys()) == 200 + # Act + await cache_funcs.clear_monitor_cache(organization_id=1, monitor_id=1) + # Assert + assert len(await cache_funcs.redis.keys()) == 200 @pytest.mark.asyncio async def test_delete_monitor_cache_by_timestamp(resources_provider, async_session): - cache_funcs: CacheFunctions = resources_provider.cache_functions + async with resources_provider.cache_functions() as cache_funcs: - # Arrange - Organization with 2 monitors and 2 model versions, and another organization with same monitor id. - now = pdl.now() - start_time = now - for _ in range(0, 10_000, 100): - end_time = start_time.add(seconds=100) - cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=1, - start_time=start_time, end_time=end_time, value='some value') - cache_funcs.set_monitor_cache(organization_id=1, model_version_id=2, monitor_id=1, - start_time=start_time, end_time=end_time, value='some value') - cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=7, - start_time=start_time, end_time=end_time, value='some value') - cache_funcs.set_monitor_cache(organization_id=9, model_version_id=1, monitor_id=1, - start_time=start_time, end_time=end_time, value='some value') - start_time = end_time + # Arrange - Organization with 2 monitors and 2 model versions, and another organization with same monitor id. 
+ now = pdl.now() + start_time = now + for _ in range(0, 10_000, 100): + end_time = start_time.add(seconds=100) + await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=1, + start_time=start_time, end_time=end_time, value='some value') + await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=2, monitor_id=1, + start_time=start_time, end_time=end_time, value='some value') + await cache_funcs.set_monitor_cache(organization_id=1, model_version_id=1, monitor_id=7, + start_time=start_time, end_time=end_time, value='some value') + await cache_funcs.set_monitor_cache(organization_id=9, model_version_id=1, monitor_id=1, + start_time=start_time, end_time=end_time, value='some value') + start_time = end_time - timestamps_to_invalidate = {now.add(seconds=140).int_timestamp, now.add(seconds=520).int_timestamp, - now.add(seconds=1000).int_timestamp} - cache_funcs.add_invalidation_timestamps(1, 1, timestamps_to_invalidate) + timestamps_to_invalidate = {now.add(seconds=140).int_timestamp, now.add(seconds=520).int_timestamp, + now.add(seconds=1000).int_timestamp} + await cache_funcs.add_invalidation_timestamps(1, 1, timestamps_to_invalidate) - # Act - run task - async with async_session as session: - task_id = await insert_model_version_cache_invalidation_task(1, 1, session=session) - task = await session.scalar(select(Task).where(Task.id == task_id)) - await ModelVersionCacheInvalidation().run(task, session, resources_provider, lock=None) + # Act - run task + async with async_session as session: + task_id = await insert_model_version_cache_invalidation_task(1, 1, session=session) + task = await session.scalar(select(Task).where(Task.id == task_id)) + await ModelVersionCacheInvalidation().run(task, session, resources_provider, lock=None) - # Assert - 2 monitors and 3 timestamps - assert len(cache_funcs.redis.keys()) == 400 - 2 * 3 + # Assert - 2 monitors and 3 timestamps + assert len(await cache_funcs.redis.keys()) == 400 - 2 * 3 diff --git a/backend/tests/unittests/test_monitor_alert_rules_executor.py b/backend/tests/unittests/test_monitor_alert_rules_executor.py index 10a69d801..a50269a1f 100644 --- a/backend/tests/unittests/test_monitor_alert_rules_executor.py +++ b/backend/tests/unittests/test_monitor_alert_rules_executor.py @@ -118,13 +118,14 @@ async def test_monitor_executor( window_end = now window_start = window_end - Frequency.DAY.to_pendulum_duration() - cache_value = resources_provider.cache_functions.get_monitor_cache( - user.organization.id, - versions[0]["id"], - monitor["id"], - window_start, - window_end - ) + async with resources_provider.cache_functions() as cache_functions: + cache_value = await cache_functions.get_monitor_cache( + user.organization.id, + versions[0]["id"], + monitor["id"], + window_start, + window_end + ) assert cache_value.found is True assert cache_value.value == {"accuracy": 0.2} @@ -340,14 +341,15 @@ async def test_monitor_executor_is_using_cache( window_start = window_end - (monitor_frequency.to_pendulum_duration() * monitor["aggregation_window"]) cache_value = {"my special key": 1} - resources_provider.cache_functions.set_monitor_cache( - organization_id, - model_version["id"], - monitor["id"], - window_start, - window_end, - cache_value - ) + async with resources_provider.cache_functions() as cache_functions: + await cache_functions.set_monitor_cache( + organization_id, + model_version["id"], + monitor["id"], + window_start, + window_end, + cache_value + ) result: t.List[Alert] = await execute_monitor( 
monitor_id=monitor["id"],
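
The pattern this diff introduces - an async context manager that hands out a short-lived, retry-wrapped Redis client per unit of work - can be summarized in a minimal, self-contained sketch. This is not the project's RedisProxy or ResourcesProvider code: it assumes a plain (non-cluster) Redis, redis-py 5.x (async API with aclose()), and tenacity; it retries only on redis.exceptions.RedisError instead of the full exception tuple built in redis_proxy.py; and RetryingRedisProxy, make_cache(), and the connection URI are hypothetical stand-ins for the names used in the diff.

import asyncio
from contextlib import asynccontextmanager

from redis.asyncio import Redis
from redis.exceptions import RedisError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

# Shared retry policy: 3 attempts, 3 seconds apart, re-raising the last error.
RETRY = retry(stop=stop_after_attempt(3), wait=wait_fixed(3),
              retry=retry_if_exception_type(RedisError), reraise=True)


class RetryingRedisProxy:
    """Delegate attribute access to an async Redis client, retrying failed calls."""

    def __init__(self, redis_uri: str):
        self.redis_uri = redis_uri
        self.client = None

    async def init_conn_async(self):
        @RETRY
        async def connect():
            # Retry the initial connection itself before anything else runs.
            self.client = Redis.from_url(self.redis_uri)
            await self.client.ping()

        await connect()

    def __getattr__(self, name):
        attr = getattr(self.client, name)
        if callable(attr) and asyncio.iscoroutinefunction(attr):
            @RETRY
            async def wrapped(*args, **kwargs):
                return await attr(*args, **kwargs)

            return wrapped
        return attr


@asynccontextmanager
async def make_cache(redis_uri: str):
    """Yield a retry-wrapped client for one unit of work, then dispose of it."""
    proxy = RetryingRedisProxy(redis_uri)
    await proxy.init_conn_async()
    try:
        yield proxy
    finally:
        await proxy.aclose()  # forwarded to Redis.aclose() through __getattr__


async def main():
    # Call sites mirror the diff: open the context, use the client, let it close.
    async with make_cache("redis://localhost:6379/0") as cache:
        await cache.set("rate-limit:1:30", 5)
        print(await cache.get("rate-limit:1:30"))


if __name__ == "__main__":
    asyncio.run(main())

Compared with the cached redis_client property this diff removes, each context creates and closes its own connection, trading a long-lived shared client for per-call retries and simpler lifecycle management.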