diff --git a/api/specs/web-server/_long_running_tasks_legacy.py b/api/specs/web-server/_long_running_tasks_legacy.py index d17b9cceeed..d5fc487301a 100644 --- a/api/specs/web-server/_long_running_tasks_legacy.py +++ b/api/specs/web-server/_long_running_tasks_legacy.py @@ -26,7 +26,7 @@ name="list_tasks", description="Lists all long running tasks", ) -def list_tasks(): ... +async def list_tasks(): ... @router.get( @@ -35,7 +35,7 @@ def list_tasks(): ... name="get_task_status", description="Retrieves the status of a task", ) -def get_task_status( +async def get_task_status( _path_params: Annotated[_PathParam, Depends()], ): ... @@ -46,7 +46,7 @@ def get_task_status( description="Cancels and deletes a task", status_code=status.HTTP_204_NO_CONTENT, ) -def cancel_and_delete_task( +async def cancel_and_delete_task( _path_params: Annotated[_PathParam, Depends()], ): ... @@ -57,6 +57,6 @@ def cancel_and_delete_task( response_model=Any, description="Retrieves the result of a task", ) -def get_task_result( +async def get_task_result( _path_params: Annotated[_PathParam, Depends()], ): ... diff --git a/packages/celery-library/requirements/_test.in b/packages/celery-library/requirements/_test.in index e85e3cb5177..e6d3bd92107 100644 --- a/packages/celery-library/requirements/_test.in +++ b/packages/celery-library/requirements/_test.in @@ -11,6 +11,7 @@ # testing coverage faker +fakeredis[lua] httpx pint pytest diff --git a/packages/celery-library/requirements/_test.txt b/packages/celery-library/requirements/_test.txt index e9eea29a558..b35e14b9401 100644 --- a/packages/celery-library/requirements/_test.txt +++ b/packages/celery-library/requirements/_test.txt @@ -56,6 +56,8 @@ docker==7.1.0 # pytest-docker-tools faker==37.3.0 # via -r requirements/_test.in +fakeredis==2.30.3 + # via -r requirements/_test.in flexcache==0.3 # via pint flexparser==0.4 @@ -83,6 +85,8 @@ kombu==5.5.3 # -c requirements/_base.txt # celery # pytest-celery +lupa==2.5 + # via fakeredis packaging==25.0 # via # -c requirements/_base.txt @@ -112,6 +116,10 @@ pygments==2.19.1 # via # -c requirements/_base.txt # pytest +pyjwt==2.9.0 + # via + # -c requirements/_base.txt + # redis pytest==8.4.1 # via # -r requirements/_test.in @@ -156,6 +164,11 @@ pyyaml==6.0.2 # -c requirements/../../../requirements/constraints.txt # -c requirements/_base.txt # -r requirements/_test.in +redis==5.3.0 + # via + # -c requirements/../../../requirements/constraints.txt + # -c requirements/_base.txt + # fakeredis requests==2.32.4 # via # -c requirements/_base.txt @@ -170,6 +183,8 @@ sniffio==1.3.1 # via # -c requirements/_base.txt # anyio +sortedcontainers==2.4.0 + # via fakeredis tenacity==9.1.2 # via # -c requirements/_base.txt diff --git a/packages/celery-library/src/celery_library/common.py b/packages/celery-library/src/celery_library/common.py index 3b7c9cd22ab..d50e75597c6 100644 --- a/packages/celery-library/src/celery_library/common.py +++ b/packages/celery-library/src/celery_library/common.py @@ -47,6 +47,8 @@ async def create_task_manager( ), client_name="celery_tasks", ) + await redis_client_sdk.setup() + # GCR please address https://github.com/ITISFoundation/osparc-simcore/issues/8159 return CeleryTaskManager( app, diff --git a/packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py b/packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py index 8007df5ef8a..d6e132c5361 100644 --- a/packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py +++ 
b/packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Awaitable, Callable from typing import Annotated, TypeAlias from pydantic import BaseModel, Field, field_validator, validate_call @@ -22,8 +23,16 @@ class TaskProgress(BaseModel): message: ProgressMessage = "" percent: ProgressPercent = 0.0 + # used to propagate progress updates internally + _update_callback: Callable[["TaskProgress"], Awaitable[None]] | None = None + + def set_update_callback( + self, callback: Callable[["TaskProgress"], Awaitable[None]] + ) -> None: + self._update_callback = callback + @validate_call - def update( + async def update( self, *, message: ProgressMessage | None = None, @@ -40,6 +49,16 @@ def update( _logger.debug("Progress update: %s", f"{self}") + if self._update_callback is not None: + try: + await self._update_callback(self) + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning( + "Error while calling progress update callback: %s", + exc, + stack_info=True, + ) + @classmethod def create(cls, task_id: TaskId | None = None) -> "TaskProgress": return cls(task_id=task_id) diff --git a/packages/pytest-simcore/src/pytest_simcore/redis_service.py b/packages/pytest-simcore/src/pytest_simcore/redis_service.py index 05aec86a234..824d61a57fb 100644 --- a/packages/pytest-simcore/src/pytest_simcore/redis_service.py +++ b/packages/pytest-simcore/src/pytest_simcore/redis_service.py @@ -8,6 +8,7 @@ import pytest import tenacity +from fakeredis import FakeAsyncRedis from pytest_mock import MockerFixture from redis.asyncio import Redis, from_url from settings_library.basic_types import PortInt @@ -121,3 +122,9 @@ def mock_redis_socket_timeout(mocker: MockerFixture) -> None: mocker.patch( "servicelib.redis._client.DEFAULT_SOCKET_TIMEOUT", timedelta(seconds=0.25) ) + + +@pytest.fixture +async def use_in_memory_redis(mocker: MockerFixture) -> RedisSettings: + mocker.patch("redis.asyncio.from_url", FakeAsyncRedis) + return RedisSettings() diff --git a/packages/service-library/requirements/_test.in b/packages/service-library/requirements/_test.in index 8df4e188617..5fd26efe1e9 100644 --- a/packages/service-library/requirements/_test.in +++ b/packages/service-library/requirements/_test.in @@ -17,6 +17,7 @@ botocore coverage docker faker +fakeredis[lua] flaky numpy openapi-spec-validator diff --git a/packages/service-library/requirements/_test.txt b/packages/service-library/requirements/_test.txt index b3ea35d716a..c8bd758d450 100644 --- a/packages/service-library/requirements/_test.txt +++ b/packages/service-library/requirements/_test.txt @@ -53,6 +53,8 @@ execnet==2.1.1 # via pytest-xdist faker==36.1.1 # via -r requirements/_test.in +fakeredis==2.30.3 + # via -r requirements/_test.in flaky==3.8.1 # via -r requirements/_test.in frozenlist==1.5.0 @@ -109,6 +111,8 @@ jsonschema-specifications==2024.10.1 # openapi-schema-validator lazy-object-proxy==1.10.0 # via openapi-spec-validator +lupa==2.5 + # via fakeredis multidict==6.1.0 # via # -c requirements/_aiohttp.txt @@ -211,6 +215,11 @@ pyyaml==6.0.2 # -c requirements/_base.txt # -c requirements/_fastapi.txt # jsonschema-path +redis==5.2.1 + # via + # -c requirements/../../../requirements/constraints.txt + # -c requirements/_base.txt + # fakeredis referencing==0.35.1 # via # -c requirements/../../../requirements/constraints.txt @@ -245,6 +254,8 @@ sniffio==1.3.1 # -c requirements/_fastapi.txt # anyio # asgi-lifespan +sortedcontainers==2.4.0 + # via fakeredis 
sqlalchemy==1.4.54 # via # -c requirements/../../../requirements/constraints.txt diff --git a/packages/service-library/src/servicelib/aiohttp/db_asyncpg_engine.py b/packages/service-library/src/servicelib/aiohttp/db_asyncpg_engine.py index 88b0338dadf..bfcf16b8437 100644 --- a/packages/service-library/src/servicelib/aiohttp/db_asyncpg_engine.py +++ b/packages/service-library/src/servicelib/aiohttp/db_asyncpg_engine.py @@ -8,7 +8,6 @@ from typing import Final from aiohttp import web -from servicelib.logging_utils import log_context from settings_library.postgres import PostgresSettings from simcore_postgres_database.utils_aiosqlalchemy import ( # type: ignore[import-not-found] # this on is unclear get_pg_engine_stateinfo, diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py index aa1fc4ea09e..f03945126bd 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py @@ -1,9 +1,11 @@ import datetime from aiohttp import web +from settings_library.redis import RedisSettings from ...long_running_tasks.base_long_running_manager import BaseLongRunningManager -from ...long_running_tasks.task import TaskContext, TasksManager +from ...long_running_tasks.models import TaskContext +from ...long_running_tasks.task import RedisNamespace, TasksManager from ._constants import APP_LONG_RUNNING_MANAGER_KEY from ._request import get_task_context @@ -14,11 +16,15 @@ def __init__( app: web.Application, stale_task_check_interval: datetime.timedelta, stale_task_detect_timeout: datetime.timedelta, + redis_settings: RedisSettings, + redis_namespace: RedisNamespace, ): self._app = app self._tasks_manager = TasksManager( stale_task_check_interval=stale_task_check_interval, stale_task_detect_timeout=stale_task_detect_timeout, + redis_settings=redis_settings, + redis_namespace=redis_namespace, ) @property diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py index ec036dcb1c5..cb735779901 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py @@ -2,10 +2,10 @@ from aiohttp import web from pydantic import BaseModel -from servicelib.aiohttp import status +from ...aiohttp import status from ...long_running_tasks import lrt_api -from ...long_running_tasks.models import TaskGet, TaskId, TaskStatus +from ...long_running_tasks.models import TaskGet, TaskId from ..requests_validation import parse_request_path_parameters_as from ..rest_responses import create_data_response from ._manager import get_long_running_manager @@ -28,7 +28,7 @@ async def list_tasks(request: web.Request) -> web.Response: result_href=f"{request.app.router['get_task_result'].url_for(task_id=t.task_id)}", abort_href=f"{request.app.router['cancel_and_delete_task'].url_for(task_id=t.task_id)}", ) - for t in lrt_api.list_tasks( + for t in await lrt_api.list_tasks( long_running_manager.tasks_manager, long_running_manager.get_task_context(request), ) @@ -41,7 +41,7 @@ async def get_task_status(request: web.Request) -> web.Response: path_params = parse_request_path_parameters_as(_PathParam, request) long_running_manager = get_long_running_manager(request.app) - task_status: TaskStatus = 
lrt_api.get_task_status( + task_status = await lrt_api.get_task_status( long_running_manager.tasks_manager, long_running_manager.get_task_context(request), path_params.task_id, diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py index e8b2efad1dc..68147fede02 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py @@ -5,17 +5,23 @@ from typing import Any from aiohttp import web +from aiohttp.web import HTTPException from common_library.json_serialization import json_dumps from pydantic import AnyHttpUrl, TypeAdapter +from settings_library.redis import RedisSettings from ...aiohttp import status from ...long_running_tasks import lrt_api +from ...long_running_tasks._redis_serialization import ( + BaseObjectSerializer, + register_custom_serialization, +) from ...long_running_tasks.constants import ( DEFAULT_STALE_TASK_CHECK_INTERVAL, DEFAULT_STALE_TASK_DETECT_TIMEOUT, ) -from ...long_running_tasks.models import TaskGet -from ...long_running_tasks.task import RegisteredTaskName, TaskContext +from ...long_running_tasks.models import TaskContext, TaskGet +from ...long_running_tasks.task import RedisNamespace, RegisteredTaskName from ..typing_extension import Handler from . import _routes from ._constants import ( @@ -89,11 +95,10 @@ async def start_long_running_task( dumps=json_dumps, ) except asyncio.CancelledError: - # cancel the task, the client has disconnected + # remove the task, the client was disconnected if task_id: - long_running_manager = get_long_running_manager(request_.app) - await long_running_manager.tasks_manager.cancel_task( - task_id, with_task_context=None + await lrt_api.remove_task( + long_running_manager.tasks_manager, task_context, task_id ) raise @@ -117,10 +122,28 @@ def _wrap_and_add_routes( ) +class AiohttpHTTPExceptionSerializer(BaseObjectSerializer[HTTPException]): + @classmethod + def get_init_kwargs_from_object(cls, obj: HTTPException) -> dict: + return { + "status_code": obj.status_code, + "reason": obj.reason, + "text": obj.text, + "headers": dict(obj.headers) if obj.headers else None, + } + + @classmethod + def prepare_object_init_kwargs(cls, data: dict) -> dict: + data.pop("status_code") + return data + + def setup( app: web.Application, *, router_prefix: str, + redis_settings: RedisSettings, + redis_namespace: RedisNamespace, handler_check_decorator: Callable = _no_ops_decorator, task_request_context_decorator: Callable = _no_task_context_decorator, stale_task_check_interval: datetime.timedelta = DEFAULT_STALE_TASK_CHECK_INTERVAL, @@ -137,6 +160,8 @@ def setup( """ async def on_cleanup_ctx(app: web.Application) -> AsyncGenerator[None, None]: + register_custom_serialization(HTTPException, AiohttpHTTPExceptionSerializer) + # add error handlers app.middlewares.append(base_long_running_error_handler) @@ -146,6 +171,8 @@ async def on_cleanup_ctx(app: web.Application) -> AsyncGenerator[None, None]: app=app, stale_task_check_interval=stale_task_check_interval, stale_task_detect_timeout=stale_task_detect_timeout, + redis_settings=redis_settings, + redis_namespace=redis_namespace, ) ) diff --git a/packages/service-library/src/servicelib/aiohttp/observer.py b/packages/service-library/src/servicelib/aiohttp/observer.py index a80b5d90e15..265ee74425a 100644 --- a/packages/service-library/src/servicelib/aiohttp/observer.py +++ 
b/packages/service-library/src/servicelib/aiohttp/observer.py @@ -8,8 +8,8 @@ from collections.abc import Callable from aiohttp import web -from servicelib.aiohttp.application_setup import ensure_single_setup +from ..aiohttp.application_setup import ensure_single_setup from ..utils import logged_gather log = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py b/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py index 89a2c56d39a..fe05ae9a55d 100644 --- a/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py +++ b/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py @@ -16,13 +16,12 @@ from common_library.user_messages import user_message from models_library.basic_types import IDStr from models_library.rest_error import ErrorGet, ErrorItemType, LogMessageType -from servicelib.rest_constants import RESPONSE_MODEL_POLICY -from servicelib.status_codes_utils import is_5xx_server_error from ..logging_errors import create_troubleshootting_log_kwargs from ..mimetype_constants import MIMETYPE_APPLICATION_JSON +from ..rest_constants import RESPONSE_MODEL_POLICY from ..rest_responses import is_enveloped_from_text -from ..status_codes_utils import get_code_description +from ..status_codes_utils import get_code_description, is_5xx_server_error from . import status from .rest_responses import ( create_data_response, diff --git a/packages/service-library/src/servicelib/aiohttp/tracing.py b/packages/service-library/src/servicelib/aiohttp/tracing.py index 7f26ebf27de..0d8fac83625 100644 --- a/packages/service-library/src/servicelib/aiohttp/tracing.py +++ b/packages/service-library/src/servicelib/aiohttp/tracing.py @@ -17,11 +17,12 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import SpanProcessor, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from servicelib.logging_utils import log_context -from servicelib.tracing import get_trace_id_header from settings_library.tracing import TracingSettings from yarl import URL +from ..logging_utils import log_context +from ..tracing import get_trace_id_header + _logger = logging.getLogger(__name__) try: from opentelemetry.instrumentation.botocore import ( # type: ignore[import-not-found] diff --git a/packages/service-library/src/servicelib/archiving_utils/_interface_7zip.py b/packages/service-library/src/servicelib/archiving_utils/_interface_7zip.py index 1e642895f1d..9fab723c1ed 100644 --- a/packages/service-library/src/servicelib/archiving_utils/_interface_7zip.py +++ b/packages/service-library/src/servicelib/archiving_utils/_interface_7zip.py @@ -11,10 +11,10 @@ import tqdm from pydantic import NonNegativeInt -from servicelib.logging_utils import log_catch from tqdm.contrib.logging import tqdm_logging_redirect from ..file_utils import shutil_move +from ..logging_utils import log_catch from ..progress_bar import ProgressBarData from ._errors import ( CouldNotFindValueError, diff --git a/packages/service-library/src/servicelib/background_task_utils.py b/packages/service-library/src/servicelib/background_task_utils.py index 45119649c26..bd70241b183 100644 --- a/packages/service-library/src/servicelib/background_task_utils.py +++ b/packages/service-library/src/servicelib/background_task_utils.py @@ -3,11 +3,10 @@ from collections.abc import Callable, Coroutine from typing import Any, ParamSpec, TypeVar -from servicelib.exception_utils import suppress_exceptions -from servicelib.redis._errors import 
CouldNotAcquireLockError - from .background_task import periodic +from .exception_utils import suppress_exceptions from .redis import RedisClientSDK, exclusive +from .redis._errors import CouldNotAcquireLockError P = ParamSpec("P") R = TypeVar("R") diff --git a/packages/service-library/src/servicelib/celery/app_server.py b/packages/service-library/src/servicelib/celery/app_server.py index dbddd9bfee2..9312497aa31 100644 --- a/packages/service-library/src/servicelib/celery/app_server.py +++ b/packages/service-library/src/servicelib/celery/app_server.py @@ -4,7 +4,7 @@ from asyncio import AbstractEventLoop from typing import Generic, TypeVar -from servicelib.celery.task_manager import TaskManager +from ..celery.task_manager import TaskManager T = TypeVar("T") diff --git a/packages/service-library/src/servicelib/deferred_tasks/_deferred_manager.py b/packages/service-library/src/servicelib/deferred_tasks/_deferred_manager.py index ab2fd15f20e..9c41d2e5f87 100644 --- a/packages/service-library/src/servicelib/deferred_tasks/_deferred_manager.py +++ b/packages/service-library/src/servicelib/deferred_tasks/_deferred_manager.py @@ -16,10 +16,10 @@ RabbitRouter, ) from pydantic import NonNegativeInt -from servicelib.logging_utils import log_catch, log_context -from servicelib.redis import RedisClientSDK from settings_library.rabbit import RabbitSettings +from ..logging_utils import log_catch, log_context +from ..redis import RedisClientSDK from ._base_deferred_handler import ( BaseDeferredHandler, DeferredContext, diff --git a/packages/service-library/src/servicelib/exception_utils.py b/packages/service-library/src/servicelib/exception_utils.py index 0af95d35e4d..6481a9746e9 100644 --- a/packages/service-library/src/servicelib/exception_utils.py +++ b/packages/service-library/src/servicelib/exception_utils.py @@ -6,7 +6,8 @@ from typing import Any, Final, ParamSpec, TypeVar from pydantic import BaseModel, Field, NonNegativeFloat, PrivateAttr -from servicelib.logging_errors import create_troubleshootting_log_kwargs + +from .logging_errors import create_troubleshootting_log_kwargs _logger = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/fastapi/http_client_thin.py b/packages/service-library/src/servicelib/fastapi/http_client_thin.py index e4806f88bcf..a62461f0009 100644 --- a/packages/service-library/src/servicelib/fastapi/http_client_thin.py +++ b/packages/service-library/src/servicelib/fastapi/http_client_thin.py @@ -8,7 +8,6 @@ from common_library.errors_classes import OsparcErrorMixin from httpx import AsyncClient, ConnectError, HTTPError, PoolTimeout, Response from httpx._types import TimeoutTypes, URLTypes -from servicelib.fastapi.tracing import setup_httpx_client_tracing from settings_library.tracing import TracingSettings from tenacity import RetryCallState from tenacity.asyncio import AsyncRetrying @@ -18,6 +17,7 @@ from tenacity.wait import wait_exponential from .http_client import BaseHTTPApi +from .tracing import setup_httpx_client_tracing _logger = logging.getLogger(__name__) @@ -128,7 +128,7 @@ def retry_on_errors( """ def decorator( - request_func: Callable[..., Awaitable[Response]] + request_func: Callable[..., Awaitable[Response]], ) -> Callable[..., Awaitable[Response]]: assert asyncio.iscoroutinefunction(request_func) @@ -178,7 +178,7 @@ def expect_status( """ def decorator( - request_func: Callable[..., Awaitable[Response]] + request_func: Callable[..., Awaitable[Response]], ) -> Callable[..., Awaitable[Response]]: assert 
asyncio.iscoroutinefunction(request_func) diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_context_manager.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_context_manager.py index 3c615c3fc5c..6a7ff58814d 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_context_manager.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_context_manager.py @@ -5,8 +5,8 @@ from typing import Any, Final from pydantic import PositiveFloat -from servicelib.logging_errors import create_troubleshootting_log_message +from ...logging_errors import create_troubleshootting_log_message from ...long_running_tasks.errors import TaskClientTimeoutError, TaskExceptionError from ...long_running_tasks.models import ( ProgressCallback, @@ -128,7 +128,6 @@ async def _wait_for_task_result() -> Any: exception=e, ) from e except Exception as e: - error = TaskExceptionError(task_id=task_id, exception=e, traceback="") _logger.warning( create_troubleshootting_log_message( user_error_msg=f"{task_id=} raised an exception", @@ -136,4 +135,4 @@ async def _wait_for_task_result() -> Any: tip=f"Check the logs of the service responding to '{client.base_url}'", ) ) - raise error from e + raise TaskExceptionError(task_id=task_id, exception=e, traceback="") from e diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py index bbc3e098a7e..6f37eb40825 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py @@ -1,9 +1,10 @@ import datetime from fastapi import FastAPI +from settings_library.redis import RedisSettings from ...long_running_tasks.base_long_running_manager import BaseLongRunningManager -from ...long_running_tasks.task import TasksManager +from ...long_running_tasks.task import RedisNamespace, TasksManager class FastAPILongRunningManager(BaseLongRunningManager): @@ -12,11 +13,15 @@ def __init__( app: FastAPI, stale_task_check_interval: datetime.timedelta, stale_task_detect_timeout: datetime.timedelta, + redis_settings: RedisSettings, + redis_namespace: RedisNamespace, ): self._app = app self._tasks_manager = TasksManager( stale_task_check_interval=stale_task_check_interval, stale_task_detect_timeout=stale_task_detect_timeout, + redis_settings=redis_settings, + redis_namespace=redis_namespace, ) @property diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py index 5f8a5e0a633..d3adbb1956d 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py @@ -29,8 +29,8 @@ async def list_tasks( request.url_for("cancel_and_delete_task", task_id=t.task_id) ), ) - for t in lrt_api.list_tasks( - long_running_manager.tasks_manager, task_context=None + for t in await lrt_api.list_tasks( + long_running_manager.tasks_manager, task_context={} ) ] @@ -51,8 +51,8 @@ async def get_task_status( ], ) -> TaskStatus: assert request # nosec - return lrt_api.get_task_status( - long_running_manager.tasks_manager, task_context=None, task_id=task_id + return await lrt_api.get_task_status( + long_running_manager.tasks_manager, task_context={}, task_id=task_id ) @@ -75,7 +75,7 @@ async def 
get_task_result( ) -> TaskResult | Any: assert request # nosec return await lrt_api.get_task_result( - long_running_manager.tasks_manager, task_context=None, task_id=task_id + long_running_manager.tasks_manager, task_context={}, task_id=task_id ) @@ -98,5 +98,5 @@ async def cancel_and_delete_task( ) -> None: assert request # nosec await lrt_api.remove_task( - long_running_manager.tasks_manager, task_context=None, task_id=task_id + long_running_manager.tasks_manager, task_context={}, task_id=task_id ) diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_server.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_server.py index 272250ae258..f00e6c8f521 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_server.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_server.py @@ -1,12 +1,14 @@ import datetime from fastapi import APIRouter, FastAPI +from settings_library.redis import RedisSettings from ...long_running_tasks.constants import ( DEFAULT_STALE_TASK_CHECK_INTERVAL, DEFAULT_STALE_TASK_DETECT_TIMEOUT, ) from ...long_running_tasks.errors import BaseLongRunningError +from ...long_running_tasks.task import RedisNamespace from ._error_handlers import base_long_running_error_handler from ._manager import FastAPILongRunningManager from ._routes import router @@ -16,6 +18,8 @@ def setup( app: FastAPI, *, router_prefix: str = "", + redis_settings: RedisSettings, + redis_namespace: RedisNamespace, stale_task_check_interval: datetime.timedelta = DEFAULT_STALE_TASK_CHECK_INTERVAL, stale_task_detect_timeout: datetime.timedelta = DEFAULT_STALE_TASK_DETECT_TIMEOUT, ) -> None: @@ -41,6 +45,8 @@ async def on_startup() -> None: app=app, stale_task_check_interval=stale_task_check_interval, stale_task_detect_timeout=stale_task_detect_timeout, + redis_settings=redis_settings, + redis_namespace=redis_namespace, ) ) await long_running_manager.setup() diff --git a/packages/service-library/src/servicelib/fastapi/monitoring.py b/packages/service-library/src/servicelib/fastapi/monitoring.py index 32dd26f53d6..a9c33f0d216 100644 --- a/packages/service-library/src/servicelib/fastapi/monitoring.py +++ b/packages/service-library/src/servicelib/fastapi/monitoring.py @@ -13,12 +13,6 @@ CONTENT_TYPE_LATEST, generate_latest, ) -from servicelib.prometheus_metrics import ( - PrometheusMetrics, - get_prometheus_metrics, - record_request_metrics, - record_response_metrics, -) from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.types import ASGIApp @@ -26,6 +20,12 @@ UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE, X_SIMCORE_USER_AGENT, ) +from ..prometheus_metrics import ( + PrometheusMetrics, + get_prometheus_metrics, + record_request_metrics, + record_response_metrics, +) _logger = logging.getLogger(__name__) _PROMETHEUS_METRICS = "prometheus_metrics" diff --git a/packages/service-library/src/servicelib/fastapi/profiler.py b/packages/service-library/src/servicelib/fastapi/profiler.py index cb3e7c5c084..9010c6296f0 100644 --- a/packages/service-library/src/servicelib/fastapi/profiler.py +++ b/packages/service-library/src/servicelib/fastapi/profiler.py @@ -1,11 +1,11 @@ from typing import Any, Final from fastapi import FastAPI -from servicelib.aiohttp import status -from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send +from ..aiohttp import status +from ..mimetype_constants 
import MIMETYPE_APPLICATION_JSON from ..utils_profiling_middleware import ( _is_profiling, _profiler, diff --git a/packages/service-library/src/servicelib/fastapi/redis_lifespan.py b/packages/service-library/src/servicelib/fastapi/redis_lifespan.py index b1ac98e9d6c..b8955d2c8ae 100644 --- a/packages/service-library/src/servicelib/fastapi/redis_lifespan.py +++ b/packages/service-library/src/servicelib/fastapi/redis_lifespan.py @@ -51,6 +51,7 @@ async def redis_client_sdk_lifespan(_: FastAPI, state: State) -> AsyncIterator[S redis_dsn_with_secrets, client_name=redis_state.REDIS_CLIENT_NAME, ) + await redis_client.setup() try: yield {"REDIS_CLIENT_SDK": redis_client, **called_state} diff --git a/packages/service-library/src/servicelib/fastapi/tracing.py b/packages/service-library/src/servicelib/fastapi/tracing.py index 9943ed81022..50c8aeab1d7 100644 --- a/packages/service-library/src/servicelib/fastapi/tracing.py +++ b/packages/service-library/src/servicelib/fastapi/tracing.py @@ -15,12 +15,13 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import SpanProcessor, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from servicelib.logging_utils import log_context -from servicelib.tracing import get_trace_id_header from settings_library.tracing import TracingSettings from starlette.middleware.base import BaseHTTPMiddleware from yarl import URL +from ..logging_utils import log_context +from ..tracing import get_trace_id_header + _logger = logging.getLogger(__name__) try: diff --git a/packages/service-library/src/servicelib/long_running_tasks/_redis_serialization.py b/packages/service-library/src/servicelib/long_running_tasks/_redis_serialization.py new file mode 100644 index 00000000000..ae7125147a1 --- /dev/null +++ b/packages/service-library/src/servicelib/long_running_tasks/_redis_serialization.py @@ -0,0 +1,85 @@ +import base64 +import logging +import pickle +from abc import ABC, abstractmethod +from typing import Any, Final, Generic, TypeVar + +_logger = logging.getLogger(__name__) + + +T = TypeVar("T") + + +class BaseObjectSerializer(ABC, Generic[T]): + + @classmethod + @abstractmethod + def get_init_kwargs_from_object(cls, obj: T) -> dict: + """dictionary representing the kwargs passed to the __init__ method""" + + @classmethod + @abstractmethod + def prepare_object_init_kwargs(cls, data: dict) -> dict: + """clean up data to be used as kwargs for the __init__ method if required""" + + +_SERIALIZERS: Final[dict[type, type[BaseObjectSerializer]]] = {} + + +def register_custom_serialization( + object_type: type, object_serializer: type[BaseObjectSerializer] +) -> None: + """Register a custom serializer for a specific object type.
+ + Arguments: + object_type -- the type or parent class of the object to be serialized + object_serializer -- custom implementation of BaseObjectSerializer for the object type + """ + _SERIALIZERS[object_type] = object_serializer + + +_TYPE_FIELD: Final[str] = "__pickle__type__field__" +_MODULE_FIELD: Final[str] = "__pickle__module__field__" + + +def object_to_string(e: Any) -> str: + """Serialize object to base64-encoded string.""" + to_serialize: Any | dict = e + object_class = type(e) + + for registered_class, object_serializer in _SERIALIZERS.items(): + if issubclass(object_class, registered_class): + to_serialize = { + _TYPE_FIELD: type(e).__name__, + _MODULE_FIELD: type(e).__module__, + **object_serializer.get_init_kwargs_from_object(e), + } + break + + return base64.b85encode(pickle.dumps(to_serialize)).decode("utf-8") + + +def string_to_object(error_str: str) -> Any: + """Deserialize object from base64-encoded string.""" + data = pickle.loads(base64.b85decode(error_str)) # noqa: S301 + + if isinstance(data, dict) and _TYPE_FIELD in data and _MODULE_FIELD in data: + try: + # Import the module and get the exception class + module = __import__(data[_MODULE_FIELD], fromlist=[data[_TYPE_FIELD]]) + exception_class = getattr(module, data[_TYPE_FIELD]) + + for registered_class, object_serializer in _SERIALIZERS.items(): + if issubclass(exception_class, registered_class): + # remove the serialization metadata fields, which are not constructor kwargs + data.pop(_TYPE_FIELD) + data.pop(_MODULE_FIELD) + + return exception_class( + **object_serializer.prepare_object_init_kwargs(data) + ) + except (ImportError, AttributeError, TypeError) as e: + msg = f"Could not reconstruct object from data: {data}" + raise ValueError(msg) from e + + return data diff --git a/packages/service-library/src/servicelib/long_running_tasks/_store/__init__.py b/packages/service-library/src/servicelib/long_running_tasks/_store/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/service-library/src/servicelib/long_running_tasks/_store/base.py b/packages/service-library/src/servicelib/long_running_tasks/_store/base.py new file mode 100644 index 00000000000..20829c01ca6 --- /dev/null +++ b/packages/service-library/src/servicelib/long_running_tasks/_store/base.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod + +from ..models import TaskContext, TaskData, TaskId + + +class BaseStore(ABC): + + @abstractmethod + async def get_task_data(self, task_id: TaskId) -> TaskData | None: + """Retrieve a tracked task""" + + @abstractmethod + async def set_task_data(self, task_id: TaskId, value: TaskData) -> None: + """Set a tracked task's data""" + + @abstractmethod + async def list_tasks_data(self) -> list[TaskData]: + """List all tracked tasks.""" + + @abstractmethod + async def delete_task_data(self, task_id: TaskId) -> None: + """Delete a tracked task.""" + + @abstractmethod + async def set_as_cancelled( + self, task_id: TaskId, with_task_context: TaskContext + ) -> None: + """Mark a tracked task as cancelled.""" + + @abstractmethod + async def get_cancelled(self) -> dict[TaskId, TaskContext]: + """Get cancelled tasks.""" + + @abstractmethod + async def setup(self) -> None: + """Set up the store, if needed.""" + + @abstractmethod + async def shutdown(self) -> None: + """Shut down the store, if needed.""" diff --git a/packages/service-library/src/servicelib/long_running_tasks/_store/redis.py b/packages/service-library/src/servicelib/long_running_tasks/_store/redis.py new file mode 100644 index 00000000000..3ac1314e11c --- /dev/null +++
b/packages/service-library/src/servicelib/long_running_tasks/_store/redis.py @@ -0,0 +1,87 @@ +from typing import Any, Final + +import redis.asyncio as aioredis +from common_library.json_serialization import json_dumps, json_loads +from pydantic import TypeAdapter +from settings_library.redis import RedisDatabase, RedisSettings + +from ...redis._client import RedisClientSDK +from ...redis._utils import handle_redis_returns_union_types +from ..models import TaskContext, TaskData, TaskId +from .base import BaseStore + +_STORE_TYPE_TASK_DATA: Final[str] = "TD" +_STORE_TYPE_CANCELLED_TASKS: Final[str] = "CT" + + +class RedisStore(BaseStore): + def __init__(self, redis_settings: RedisSettings, namespace: str): + self.redis_settings = redis_settings + self.namespace = namespace.upper() + + self._client: RedisClientSDK | None = None + + async def setup(self) -> None: + self._client = RedisClientSDK( + self.redis_settings.build_redis_dsn(RedisDatabase.LONG_RUNNING_TASKS), + client_name=f"long_running_tasks_store_{self.namespace}", + ) + await self._client.setup() + + async def shutdown(self) -> None: + if self._client: + await self._client.shutdown() + + @property + def _redis(self) -> aioredis.Redis: + assert self._client # nosec + return self._client.redis + + def _get_redis_hash_key(self, store_type: str) -> str: + return f"{self.namespace}:{store_type}" + + def _get_key(self, store_type: str, name: str) -> str: + return f"{self.namespace}:{store_type}:{name}" + + async def get_task_data(self, task_id: TaskId) -> TaskData | None: + result: Any | None = await handle_redis_returns_union_types( + self._redis.hget(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), task_id) + ) + return TypeAdapter(TaskData).validate_json(result) if result else None + + async def set_task_data(self, task_id: TaskId, value: TaskData) -> None: + await handle_redis_returns_union_types( + self._redis.hset( + self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), + task_id, + value.model_dump_json(), + ) + ) + + async def list_tasks_data(self) -> list[TaskData]: + result: list[Any] = await handle_redis_returns_union_types( + self._redis.hvals(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA)) + ) + return [TypeAdapter(TaskData).validate_json(item) for item in result] + + async def delete_task_data(self, task_id: TaskId) -> None: + await handle_redis_returns_union_types( + self._redis.hdel(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), task_id) + ) + + async def set_as_cancelled( + self, task_id: TaskId, with_task_context: TaskContext + ) -> None: + await handle_redis_returns_union_types( + self._redis.hset( + self._get_redis_hash_key(_STORE_TYPE_CANCELLED_TASKS), + task_id, + json_dumps(with_task_context), + ) + ) + + async def get_cancelled(self) -> dict[TaskId, TaskContext]: + result: dict[str, str | None] = await handle_redis_returns_union_types( + self._redis.hgetall(self._get_redis_hash_key(_STORE_TYPE_CANCELLED_TASKS)) + ) + return {task_id: json_loads(context) for task_id, context in result.items()} diff --git a/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py b/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py index 5e8c9191f7e..6f732d49e49 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py +++ b/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py @@ -2,11 +2,11 @@ from typing import Any from common_library.error_codes import create_error_code -from servicelib.logging_errors import create_troubleshootting_log_kwargs +from 
..logging_errors import create_troubleshootting_log_kwargs from .errors import TaskNotCompletedError, TaskNotFoundError -from .models import TaskBase, TaskId, TaskStatus -from .task import RegisteredTaskName, TaskContext, TasksManager +from .models import TaskBase, TaskContext, TaskId, TaskStatus +from .task import RegisteredTaskName, TasksManager _logger = logging.getLogger(__name__) @@ -46,7 +46,7 @@ async def start_task( Returns: TaskId: the task unique identifier """ - return tasks_manager.start_task( + return await tasks_manager.start_task( registered_task_name, unique=unique, task_context=task_context, @@ -56,26 +56,26 @@ async def start_task( ) -def list_tasks( - tasks_manager: TasksManager, task_context: TaskContext | None +async def list_tasks( + tasks_manager: TasksManager, task_context: TaskContext ) -> list[TaskBase]: - return tasks_manager.list_tasks(with_task_context=task_context) + return await tasks_manager.list_tasks(with_task_context=task_context) -def get_task_status( - tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId +async def get_task_status( + tasks_manager: TasksManager, task_context: TaskContext, task_id: TaskId ) -> TaskStatus: """returns the status of a task""" - return tasks_manager.get_task_status( + return await tasks_manager.get_task_status( task_id=task_id, with_task_context=task_context ) async def get_task_result( - tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId + tasks_manager: TasksManager, task_context: TaskContext, task_id: TaskId ) -> Any: try: - task_result = tasks_manager.get_task_result( + task_result = await tasks_manager.get_task_result( task_id, with_task_context=task_context ) await tasks_manager.remove_task( @@ -101,7 +101,7 @@ async def get_task_result( async def remove_task( - tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId + tasks_manager: TasksManager, task_context: TaskContext, task_id: TaskId ) -> None: - """removes / cancels a task""" + """cancels and removes the task""" await tasks_manager.remove_task(task_id, with_task_context=task_context) diff --git a/packages/service-library/src/servicelib/long_running_tasks/models.py b/packages/service-library/src/servicelib/long_running_tasks/models.py index 15ab97515af..a8c626714c1 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/models.py +++ b/packages/service-library/src/servicelib/long_running_tasks/models.py @@ -1,10 +1,10 @@ # mypy: disable-error-code=truthy-function -from asyncio import Task from collections.abc import Awaitable, Callable, Coroutine from dataclasses import dataclass from datetime import UTC, datetime -from typing import Any, TypeAlias +from typing import Annotated, Any, TypeAlias +from common_library.basic_types import DEFAULT_FACTORY from models_library.api_schemas_long_running_tasks.base import ( ProgressMessage, ProgressPercent, @@ -17,7 +17,7 @@ TaskResult, TaskStatus, ) -from pydantic import BaseModel, ConfigDict, Field, PositiveFloat +from pydantic import BaseModel, ConfigDict, Field, PositiveFloat, model_validator TaskType: TypeAlias = Callable[..., Coroutine[Any, Any, Any]] @@ -26,29 +26,70 @@ ] RequestBody: TypeAlias = Any +TaskContext: TypeAlias = dict[str, Any] -class TrackedTask(BaseModel): +class ResultField(BaseModel): + result: str | None = None + error: str | None = None + + @model_validator(mode="after") + def validate_mutually_exclusive(self) -> "ResultField": + if self.result is not None and self.error is not None: + msg = "Cannot set both 
'result' and 'error' - they are mutually exclusive" + raise ValueError(msg) + return self + + +class TaskData(BaseModel): task_id: str - task: Task task_progress: TaskProgress # NOTE: this context lifetime is with the tracked task (similar to aiohttp storage concept) task_context: dict[str, Any] - fire_and_forget: bool = Field( - ..., - description="if True then the task will not be auto-cancelled if no one enquires of its status", - ) - - started: datetime = Field(default_factory=lambda: datetime.now(UTC)) - last_status_check: datetime | None = Field( - default=None, - description=( - "used to detect when if the task is not actively " - "polled by the client who created it" + fire_and_forget: Annotated[ + bool, + Field( + description="if True then the task will not be auto-cancelled if no one enquires of its status" ), + ] + + started: Annotated[datetime, Field(default_factory=lambda: datetime.now(UTC))] = ( + DEFAULT_FACTORY + ) + last_status_check: Annotated[ + datetime | None, + Field( + description=( + "used to detect if the task is not actively " + "polled by the client who created it" + ) + ), + ] = None + + is_done: Annotated[ + bool, + Field(description="True when the task finished running with or without errors"), + ] = False + result_field: Annotated[ + ResultField | None, Field(description="the result of the task") + ] = None + model_config = ConfigDict( arbitrary_types_allowed=True, + json_schema_extra={ + "examples": [ + { + "task_id": "1a119618-7186-4bc1-b8de-7e3ff314cb7e", + "task_name": "running-task", + "task_status": "running", + "task_progress": { + "task_id": "1a119618-7186-4bc1-b8de-7e3ff314cb7e" + }, + "task_context": {"key": "value"}, + "fire_and_forget": False, + } + ] + }, ) @@ -75,8 +116,8 @@ async def result(self) -> Any: __all__: tuple[str, ...] 
= ( "ProgressMessage", "ProgressPercent", - "TaskGet", "TaskBase", + "TaskGet", "TaskId", "TaskProgress", "TaskResult", diff --git a/packages/service-library/src/servicelib/long_running_tasks/task.py b/packages/service-library/src/servicelib/long_running_tasks/task.py index b563c3db1b1..ca22d9ff6af 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/task.py +++ b/packages/service-library/src/servicelib/long_running_tasks/task.py @@ -1,45 +1,49 @@ import asyncio import datetime +import functools import inspect import logging -import traceback import urllib.parse -from collections import deque -from contextlib import suppress from typing import Any, ClassVar, Final, Protocol, TypeAlias from uuid import uuid4 from common_library.async_tools import cancel_wait_task from models_library.api_schemas_long_running_tasks.base import TaskProgress -from pydantic import PositiveFloat +from pydantic import NonNegativeFloat, PositiveFloat +from settings_library.redis import RedisDatabase, RedisSettings +from tenacity import ( + AsyncRetrying, + TryAgain, + retry_if_exception_type, + stop_after_delay, + wait_exponential, +) from ..background_task import create_periodic_task -from ..logging_utils import log_catch +from ..redis import RedisClientSDK, exclusive +from ._redis_serialization import object_to_string, string_to_object +from ._store.base import BaseStore +from ._store.redis import RedisStore from .errors import ( TaskAlreadyRunningError, TaskCancelledError, - TaskExceptionError, TaskNotCompletedError, TaskNotFoundError, TaskNotRegisteredError, ) -from .models import TaskBase, TaskId, TaskStatus, TrackedTask +from .models import ResultField, TaskBase, TaskContext, TaskData, TaskId, TaskStatus _logger = logging.getLogger(__name__) -# NOTE: for now only this one is used, in future it will be unqiue per service -# and this default will be removed and become mandatory -_DEFAULT_NAMESPACE: Final[str] = "lrt" +_CANCEL_TASKS_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5) +_STATUS_UPDATE_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=1) + +_TASK_REMOVAL_MAX_WAIT: Final[NonNegativeFloat] = 60 -_CANCEL_TASK_TIMEOUT: Final[PositiveFloat] = datetime.timedelta( seconds=10 # NOTE: 1 second is too short to cleanup a task ).total_seconds() RegisteredTaskName: TypeAlias = str -Namespace: TypeAlias = str -TrackedTaskGroupDict: TypeAlias = dict[TaskId, TrackedTask] -TaskContext: TypeAlias = dict[str, Any] +RedisNamespace: TypeAlias = str class TaskProtocol(Protocol): @@ -64,19 +68,15 @@ def unregister(cls, task: TaskProtocol) -> None: del cls.REGISTERED_TASKS[task.__name__] -async def _await_task(task: asyncio.Task) -> None: - await task - - -def _get_tasks_to_remove( - tracked_tasks: TrackedTaskGroupDict, +async def _get_tasks_to_remove( + tracked_tasks: BaseStore, stale_task_detect_timeout_s: PositiveFloat, -) -> list[TaskId]: +) -> list[tuple[TaskId, TaskContext]]: utc_now = datetime.datetime.now(tz=datetime.UTC) - tasks_to_remove: list[TaskId] = [] + tasks_to_remove: list[tuple[TaskId, TaskContext]] = [] - for task_id, tracked_task in tracked_tasks.items(): + for tracked_task in await tracked_tasks.list_tasks_data(): if tracked_task.fire_and_forget: continue @@ -84,61 +84,127 @@ def _get_tasks_to_remove( # the task just added or never received a poll request elapsed_from_start = (utc_now - tracked_task.started).seconds if elapsed_from_start > stale_task_detect_timeout_s: - tasks_to_remove.append(task_id) + tasks_to_remove.append( + 
(tracked_task.task_id, tracked_task.task_context) + ) else: # the task status was already queried by the client elapsed_from_last_poll = (utc_now - tracked_task.last_status_check).seconds if elapsed_from_last_poll > stale_task_detect_timeout_s: - tasks_to_remove.append(task_id) + tasks_to_remove.append( + (tracked_task.task_id, tracked_task.task_context) + ) return tasks_to_remove -class TasksManager: +class TasksManager: # pylint:disable=too-many-instance-attributes """ Monitors execution and results retrieval of a collection of asyncio.Tasks """ def __init__( self, + redis_settings: RedisSettings, stale_task_check_interval: datetime.timedelta, stale_task_detect_timeout: datetime.timedelta, - namespace: Namespace = _DEFAULT_NAMESPACE, + redis_namespace: RedisNamespace, ): - self.namespace = namespace # Task groups: Every taskname maps to multiple asyncio.Task within TrackedTask model - self._tracked_tasks: TrackedTaskGroupDict = {} + self._tasks_data: BaseStore = RedisStore(redis_settings, redis_namespace) + self._created_tasks: dict[TaskId, asyncio.Task] = {} self.stale_task_check_interval = stale_task_check_interval self.stale_task_detect_timeout_s: PositiveFloat = ( stale_task_detect_timeout.total_seconds() ) + self.redis_namespace = redis_namespace + self.redis_settings = redis_settings - self._stale_tasks_monitor_task: asyncio.Task | None = None + self.locks_redis_client_sdk: RedisClientSDK | None = None + + # stale_tasks_monitor + self._task_stale_tasks_monitor: asyncio.Task | None = None + self._started_event_task_stale_tasks_monitor = asyncio.Event() + + # cancelled_tasks_removal + self._task_cancelled_tasks_removal: asyncio.Task | None = None + self._started_event_task_cancelled_tasks_removal = asyncio.Event() + + # status_update + self._task_status_update: asyncio.Task | None = None + self._started_event_task_status_update = asyncio.Event() async def setup(self) -> None: - self._stale_tasks_monitor_task = create_periodic_task( - task=self._stale_tasks_monitor_worker, + await self._tasks_data.setup() + + self.locks_redis_client_sdk = RedisClientSDK( + self.redis_settings.build_redis_dsn(RedisDatabase.LOCKS), + client_name=f"long_running_tasks_store_{self.redis_namespace}_lock", + ) + await self.locks_redis_client_sdk.setup() + + # stale_tasks_monitor + self._task_stale_tasks_monitor = create_periodic_task( + task=exclusive( + self.locks_redis_client_sdk, + lock_key=f"{__name__}_{self.redis_namespace}_stale_tasks_monitor", + )(self._stale_tasks_monitor), interval=self.stale_task_check_interval, - task_name=f"{__name__}.{self._stale_tasks_monitor_worker.__name__}", + task_name=f"{__name__}.{self._stale_tasks_monitor.__name__}", ) + await self._started_event_task_stale_tasks_monitor.wait() + + # cancelled_tasks_removal + self._task_cancelled_tasks_removal = create_periodic_task( + task=self._cancelled_tasks_removal, + interval=_CANCEL_TASKS_CHECK_INTERVAL, + task_name=f"{__name__}.{self._cancelled_tasks_removal.__name__}", + ) + await self._started_event_task_cancelled_tasks_removal.wait() + + # status_update + self._task_status_update = create_periodic_task( + task=self._status_update, + interval=_STATUS_UPDATE_CHECK_INTERVAL, + task_name=f"{__name__}.{self._status_update.__name__}", + ) + await self._started_event_task_status_update.wait() async def teardown(self) -> None: - task_ids_to_remove: deque[TaskId] = deque() + # ensure all created tasks are cancelled + for tracked_task in await self._tasks_data.list_tasks_data(): + await self.remove_task( + tracked_task.task_id, + 
tracked_task.task_context, + # when closing we do not care about pending errors + reraise_errors=False, + ) - for tracked_task in self._tracked_tasks.values(): - task_ids_to_remove.append(tracked_task.task_id) + for task in self._created_tasks.values(): + _logger.warning( + "Task %s was not completed before shutdown, cancelling it", + task.get_name(), + ) + await cancel_wait_task(task) - for task_id in task_ids_to_remove: - # when closing we do not care about pending errors - await self.remove_task(task_id, None, reraise_errors=False) + # stale_tasks_monitor + if self._task_stale_tasks_monitor: + await cancel_wait_task(self._task_stale_tasks_monitor) - if self._stale_tasks_monitor_task: - with log_catch(_logger, reraise=False): - await cancel_wait_task( - self._stale_tasks_monitor_task, max_delay=_CANCEL_TASK_TIMEOUT - ) + # cancelled_tasks_removal + if self._task_cancelled_tasks_removal: + await cancel_wait_task(self._task_cancelled_tasks_removal) + + # status_update + if self._task_status_update: + await cancel_wait_task(self._task_status_update) - async def _stale_tasks_monitor_worker(self) -> None: + if self.locks_redis_client_sdk is not None: + await self.locks_redis_client_sdk.shutdown() + + await self._tasks_data.shutdown() + + async def _stale_tasks_monitor(self) -> None: """ A task is considered stale, if the task status is not queried in the last `stale_task_detect_timeout_s` and it is not a fire and forget type of task. @@ -155,12 +221,14 @@ async def _stale_tasks_monitor_worker(self) -> None: # Since we own the client, we assume (for now) this # will not be the case. - tasks_to_remove = _get_tasks_to_remove( - self._tracked_tasks, self.stale_task_detect_timeout_s + self._started_event_task_stale_tasks_monitor.set() + + tasks_to_remove = await _get_tasks_to_remove( + self._tasks_data, self.stale_task_detect_timeout_s ) # finally remove tasks and warn - for task_id in tasks_to_remove: + for task_id, task_context in tasks_to_remove: # NOTE: task can be in the following cases: # - still ongoing # - finished with a result @@ -169,59 +237,95 @@ async def _stale_tasks_monitor_worker(self) -> None: _logger.warning( "Removing stale task '%s' with status '%s'", task_id, - self.get_task_status(task_id, with_task_context=None).model_dump_json(), + ( + await self.get_task_status(task_id, with_task_context=task_context) + ).model_dump_json(), ) await self.remove_task( - task_id, with_task_context=None, reraise_errors=False + task_id, with_task_context=task_context, reraise_errors=False ) - def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBase]: + async def _cancelled_tasks_removal(self) -> None: + """ + A task can be cancelled by the client, which implies it may not be + running in the same process as the one handling the request. + + This is a periodic task that ensures the cancellation occurred. + """ + self._started_event_task_cancelled_tasks_removal.set() + + cancelled_tasks = await self._tasks_data.get_cancelled() + for task_id, task_context in cancelled_tasks.items(): + await self.remove_task(task_id, task_context, reraise_errors=False) + + async def _status_update(self) -> None: + """ + A task which monitors locally running tasks and updates their status + in the Redis store when they are done. 
+ """ + self._started_event_task_status_update.set() + task_id: TaskId + for task_id in set(self._created_tasks.keys()): + if task := self._created_tasks.get(task_id, None): + is_done = task.done() + if not is_done: + # task is still running, do not update + continue + + # write to redis only when done + task_data = await self._tasks_data.get_task_data(task_id) + if task_data is None or task_data.is_done: + # already done and data updated in redis + continue + + # update and store in Redis + task_data.is_done = is_done + + # get task result + try: + task_data.result_field = ResultField( + result=object_to_string(task.result()) + ) + except asyncio.InvalidStateError: + # task was not completed; try again next time and see if it is done + continue + except asyncio.CancelledError: + task_data.result_field = ResultField( + error=object_to_string(TaskCancelledError(task_id=task_id)) + ) + except Exception as e: # pylint:disable=broad-except + task_data.result_field = ResultField(error=object_to_string(e)) + + await self._tasks_data.set_task_data(task_id, task_data) + + async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBase]: if not with_task_context: return [ - TaskBase(task_id=task.task_id) for task in self._tracked_tasks.values() + TaskBase(task_id=task.task_id) + for task in (await self._tasks_data.list_tasks_data()) ] return [ TaskBase(task_id=task.task_id) - for task in self._tracked_tasks.values() + for task in (await self._tasks_data.list_tasks_data()) if task.task_context == with_task_context ] - def _add_task( - self, - task: asyncio.Task, - task_progress: TaskProgress, - task_context: TaskContext, - task_id: TaskId, - *, - fire_and_forget: bool, - ) -> TrackedTask: - tracked_task = TrackedTask( - task_id=task_id, - task=task, - task_progress=task_progress, - task_context=task_context, - fire_and_forget=fire_and_forget, - ) - self._tracked_tasks[task_id] = tracked_task + async def _get_tracked_task( + self, task_id: TaskId, with_task_context: TaskContext + ) -> TaskData: + task_data = await self._tasks_data.get_task_data(task_id) - return tracked_task - - def _get_tracked_task( - self, task_id: TaskId, with_task_context: TaskContext | None - ) -> TrackedTask: - if task_id not in self._tracked_tasks: + if task_data is None: raise TaskNotFoundError(task_id=task_id) - task = self._tracked_tasks[task_id] - - if with_task_context and task.task_context != with_task_context: + if with_task_context and task_data.task_context != with_task_context: raise TaskNotFoundError(task_id=task_id) - return task + return task_data - def get_task_status( - self, task_id: TaskId, with_task_context: TaskContext | None + async def get_task_status( + self, task_id: TaskId, with_task_context: TaskContext ) -> TaskStatus: """ returns: the status of the task, along with updates raises TaskNotFoundError if the task cannot be found """ - tracked_task: TrackedTask = self._get_tracked_task(task_id, with_task_context) - tracked_task.last_status_check = datetime.datetime.now(tz=datetime.UTC) - - task = tracked_task.task - done = task.done() + task_data: TaskData = await self._get_tracked_task(task_id, with_task_context) + task_data.last_status_check = datetime.datetime.now(tz=datetime.UTC) + await self._tasks_data.set_task_data(task_id, task_data) return TaskStatus.model_validate( { - "task_progress": tracked_task.task_progress, - "done": done, - "started": tracked_task.started, + "task_progress": task_data.task_progress, + "done": task_data.is_done, 
+ "started": task_data.started, } ) - def get_task_result( - self, task_id: TaskId, with_task_context: TaskContext | None + async def get_task_result( + self, task_id: TaskId, with_task_context: TaskContext ) -> Any: """ returns: the result of the task @@ -253,86 +355,96 @@ def get_task_result( raises TaskCancelledError if the task was cancelled raises TaskNotCompletedError if the task is not completed """ - tracked_task = self._get_tracked_task(task_id, with_task_context) + tracked_task = await self._get_tracked_task(task_id, with_task_context) - try: - return tracked_task.task.result() - except asyncio.InvalidStateError as exc: - # the task is not ready - raise TaskNotCompletedError(task_id=task_id) from exc - except asyncio.CancelledError as exc: - # the task was cancelled - raise TaskCancelledError(task_id=task_id) from exc - - async def cancel_task( - self, task_id: TaskId, with_task_context: TaskContext | None - ) -> None: - """ - cancels the task + if not tracked_task.is_done or tracked_task.result_field is None: + raise TaskNotCompletedError(task_id=task_id) - raises TaskNotFoundError if the task cannot be found - """ - tracked_task = self._get_tracked_task(task_id, with_task_context) - await self._cancel_tracked_task(tracked_task.task, task_id, reraise_errors=True) + if tracked_task.result_field.error is not None: + raise string_to_object(tracked_task.result_field.error) - @staticmethod - async def _cancel_asyncio_task( - task: asyncio.Task, reference: str, *, reraise_errors: bool - ) -> None: - task.cancel() - with suppress(asyncio.CancelledError): - try: - try: - await asyncio.wait_for( - _await_task(task), timeout=_CANCEL_TASK_TIMEOUT - ) - except TimeoutError: - _logger.warning( - "Timed out while awaiting for cancellation of '%s'", reference - ) - except Exception: # pylint:disable=broad-except - if reraise_errors: - raise + if tracked_task.result_field.result is None: + return None + + return string_to_object(tracked_task.result_field.result) async def _cancel_tracked_task( - self, task: asyncio.Task, task_id: TaskId, *, reraise_errors: bool + self, task: asyncio.Task, task_id: TaskId, with_task_context: TaskContext ) -> None: try: - await self._cancel_asyncio_task( - task, task_id, reraise_errors=reraise_errors - ) + await self._tasks_data.set_as_cancelled(task_id, with_task_context) + await cancel_wait_task(task) except Exception as e: # pylint:disable=broad-except - formatted_traceback = "".join(traceback.format_exception(e)) - raise TaskExceptionError( - task_id=task_id, exception=e, traceback=formatted_traceback - ) from e + _logger.info( + "Task %s cancellation failed with error: %s", + task_id, + e, + stack_info=True, + ) async def remove_task( self, task_id: TaskId, - with_task_context: TaskContext | None, + with_task_context: TaskContext, *, reraise_errors: bool = True, ) -> None: """cancels and removes task""" try: - tracked_task = self._get_tracked_task(task_id, with_task_context) + tracked_task = await self._get_tracked_task(task_id, with_task_context) except TaskNotFoundError: if reraise_errors: raise return - try: - await self._cancel_tracked_task( - tracked_task.task, task_id, reraise_errors=reraise_errors - ) - finally: - del self._tracked_tasks[task_id] + + if tracked_task.task_id in self._created_tasks: + task_to_cancel = self._created_tasks.pop(tracked_task.task_id, None) + if task_to_cancel is not None: + # canceling the task affects the worker that started it. 
+            # awaiting the cancelled task is a must since if the CancelledError
+            # was intercepted, those actions need to finish
+            await cancel_wait_task(task_to_cancel)
+
+        await self._tasks_data.delete_task_data(task_id)
+
+        # wait for task to be removed since it might not have been running
+        # in this process
+        async for attempt in AsyncRetrying(
+            wait=wait_exponential(max=2),
+            stop=stop_after_delay(_TASK_REMOVAL_MAX_WAIT),
+            retry=retry_if_exception_type(TryAgain),
+        ):
+            with attempt:
+                try:
+                    await self._get_tracked_task(task_id, with_task_context)
+                    raise TryAgain
+                except TaskNotFoundError:
+                    pass

    def _get_task_id(self, task_name: str, *, is_unique: bool) -> TaskId:
        unique_part = "unique" if is_unique else f"{uuid4()}"
-        return f"{self.namespace}.{task_name}.{unique_part}"
+        return f"{self.redis_namespace}.{task_name}.{unique_part}"
+
+    async def _update_progress(
+        self,
+        task_id: TaskId,
+        task_context: TaskContext,
+        task_progress: TaskProgress,
+    ) -> None:
+        # NOTE: avoids errors while updating progress, since the task could have been
+        # cancelled and its data removed
+        try:
+            tracked_data = await self._get_tracked_task(task_id, task_context)
+            tracked_data.task_progress = task_progress
+            await self._tasks_data.set_task_data(task_id=task_id, value=tracked_data)
+        except TaskNotFoundError:
+            _logger.debug(
+                "Task '%s' not found while updating progress %s",
+                task_id,
+                task_progress,
+            )

-    def start_task(
+    async def start_task(
        self,
        registered_task_name: RegisteredTaskName,
        *,
@@ -357,43 +469,48 @@ def start_task(
        task_id = self._get_task_id(task_name, is_unique=unique)

        # only one unique task can be running
-        if unique and task_id in self._tracked_tasks:
+        queried_task = await self._tasks_data.get_task_data(task_id)
+        if unique and queried_task is not None:
            raise TaskAlreadyRunningError(
-                task_name=task_name, managed_task=self._tracked_tasks[task_id]
+                task_name=task_name, managed_task=queried_task
            )

+        context_to_use = task_context or {}
        task_progress = TaskProgress.create(task_id=task_id)
+        # set update callback
+        task_progress.set_update_callback(
+            functools.partial(self._update_progress, task_id, context_to_use)
+        )

-        # bind the task with progress 0 and 1
-        async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
-            progress.update(message="starting", percent=0)
+        async def _task_with_progress(progress: TaskProgress, handler: TaskProtocol):
+            # bind the task with progress 0 and 1
+            await progress.update(message="starting", percent=0)
            try:
                return await handler(progress, **task_kwargs)
            finally:
-                progress.update(message="finished", percent=1)
+                await progress.update(message="finished", percent=1)

-        async_task = asyncio.create_task(
-            _progress_task(task_progress, task), name=task_name
+        self._created_tasks[task_id] = asyncio.create_task(
+            _task_with_progress(task_progress, task), name=task_name
        )

-        tracked_task = self._add_task(
-            task=async_task,
+        tracked_task = TaskData(
+            task_id=task_id,
            task_progress=task_progress,
-            task_context=task_context or {},
+            task_context=context_to_use,
            fire_and_forget=fire_and_forget,
-            task_id=task_id,
        )
-
+        await self._tasks_data.set_task_data(task_id, tracked_task)
        return tracked_task.task_id


__all__: tuple[str, ...]
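`remove_task` above cannot assume the asyncio task is running in the current process, so after deleting the Redis entry it polls until the data is really gone. The same tenacity pattern in isolation, with `fetch` as a hypothetical async getter that returns `None` once the key disappears:

```python
# Hedged sketch of the retry-until-gone pattern used by remove_task above.
# `fetch` is a hypothetical async getter, not part of the PR's API.
from collections.abc import Awaitable, Callable

from tenacity import (
    AsyncRetrying,
    TryAgain,
    retry_if_exception_type,
    stop_after_delay,
    wait_exponential,
)


async def wait_until_absent(
    fetch: Callable[[str], Awaitable[object | None]], key: str, max_wait: float = 10.0
) -> None:
    async for attempt in AsyncRetrying(
        wait=wait_exponential(max=2),
        stop=stop_after_delay(max_wait),
        retry=retry_if_exception_type(TryAgain),
        reraise=True,
    ):
        with attempt:
            if await fetch(key) is not None:
                raise TryAgain  # still there: retry until stop_after_delay expires
```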
= ( "TaskAlreadyRunningError", "TaskCancelledError", + "TaskData", "TaskId", "TaskProgress", "TaskProtocol", "TaskStatus", "TasksManager", - "TrackedTask", ) diff --git a/packages/service-library/src/servicelib/rabbitmq/_client_base.py b/packages/service-library/src/servicelib/rabbitmq/_client_base.py index 69720659e50..09cea1d4452 100644 --- a/packages/service-library/src/servicelib/rabbitmq/_client_base.py +++ b/packages/service-library/src/servicelib/rabbitmq/_client_base.py @@ -6,9 +6,10 @@ import aio_pika import aiormq -from servicelib.logging_utils import log_catch from settings_library.rabbit import RabbitSettings +from ..logging_utils import log_catch + _DEFAULT_RABBITMQ_SERVER_HEARTBEAT_S: Final[int] = 60 _logger = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/agent/containers.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/agent/containers.py index 2049f0a409f..0e64ff62506 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/agent/containers.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/agent/containers.py @@ -6,8 +6,9 @@ from models_library.projects_nodes_io import NodeID from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace from pydantic import NonNegativeInt, TypeAdapter -from servicelib.logging_utils import log_decorator -from servicelib.rabbitmq import RabbitMQRPCClient + +from ....logging_utils import log_decorator +from ....rabbitmq import RabbitMQRPCClient _logger = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/agent/volumes.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/agent/volumes.py index 41cf2ffd8b8..07f8f961750 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/agent/volumes.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/agent/volumes.py @@ -6,8 +6,9 @@ from models_library.projects_nodes_io import NodeID from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace from pydantic import NonNegativeInt, TypeAdapter -from servicelib.logging_utils import log_decorator -from servicelib.rabbitmq import RabbitMQRPCClient + +from ....logging_utils import log_decorator +from ....rabbitmq import RabbitMQRPCClient _logger = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/dynamic_scheduler/services.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/dynamic_scheduler/services.py index fb3276ae670..edf4a480c1f 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/dynamic_scheduler/services.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/dynamic_scheduler/services.py @@ -18,8 +18,9 @@ from models_library.services_types import ServicePortKey from models_library.users import UserID from pydantic import NonNegativeInt, TypeAdapter -from servicelib.logging_utils import log_decorator -from servicelib.rabbitmq import RabbitMQRPCClient + +from ....logging_utils import log_decorator +from ....rabbitmq import RabbitMQRPCClient _logger = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/auth/api_keys.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/auth/api_keys.py index 0358a0e3b6a..4530fc5c63a 100644 --- 
a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/auth/api_keys.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/auth/api_keys.py @@ -6,8 +6,9 @@ from models_library.rpc.webserver.auth.api_keys import ApiKeyCreate, ApiKeyGet from models_library.users import UserID from pydantic import TypeAdapter -from servicelib.logging_utils import log_decorator -from servicelib.rabbitmq import RabbitMQRPCClient + +from .....logging_utils import log_decorator +from .....rabbitmq import RabbitMQRPCClient _logger = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/licenses/licensed_items.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/licenses/licensed_items.py index acb367de27b..3df94f40d4d 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/licenses/licensed_items.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/licenses/licensed_items.py @@ -15,8 +15,9 @@ from models_library.users import UserID from models_library.wallets import WalletID from pydantic import TypeAdapter -from servicelib.logging_utils import log_decorator -from servicelib.rabbitmq import RabbitMQRPCClient + +from .....logging_utils import log_decorator +from .....rabbitmq import RabbitMQRPCClient _logger = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/projects.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/projects.py index 15f40d66011..1f01f453036 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/projects.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/projects.py @@ -16,8 +16,9 @@ ) from models_library.users import UserID from pydantic import TypeAdapter, validate_call -from servicelib.logging_utils import log_decorator -from servicelib.rabbitmq import RabbitMQRPCClient + +from ....logging_utils import log_decorator +from ....rabbitmq import RabbitMQRPCClient _logger = logging.getLogger(__name__) diff --git a/packages/service-library/src/servicelib/redis/_client.py b/packages/service-library/src/servicelib/redis/_client.py index e961a6a73e9..de407a74fe8 100644 --- a/packages/service-library/src/servicelib/redis/_client.py +++ b/packages/service-library/src/servicelib/redis/_client.py @@ -8,6 +8,7 @@ import redis.asyncio as aioredis import redis.exceptions +import tenacity from common_library.async_tools import cancel_wait_task from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry @@ -24,8 +25,18 @@ _logger = logging.getLogger(__name__) -# SEE https://github.com/ITISFoundation/osparc-simcore/pull/7077 -_HEALTHCHECK_TASK_TIMEOUT_S: Final[float] = 3.0 +_HEALTHCHECK_TIMEOUT_S: Final[float] = 3.0 + + +@tenacity.retry( + wait=tenacity.wait_fixed(2), + stop=tenacity.stop_after_delay(20), + before_sleep=tenacity.before_sleep_log(_logger, logging.INFO), + reraise=True, +) +async def wait_till_redis_is_responsive(client: aioredis.Redis) -> None: + if not await client.ping(): + raise tenacity.TryAgain @dataclass @@ -36,8 +47,8 @@ class RedisClientSDK: health_check_interval: datetime.timedelta = DEFAULT_HEALTH_CHECK_INTERVAL _client: aioredis.Redis = field(init=False) - _health_check_task: Task | None = None - _health_check_task_started_event: asyncio.Event | None = None + _task_health_check: Task | None = None + 
_started_event_task_health_check: asyncio.Event | None = None _is_healthy: bool = False @property @@ -59,21 +70,23 @@ def __post_init__(self) -> None: decode_responses=self.decode_responses, client_name=self.client_name, ) - # NOTE: connection is done here already self._is_healthy = False - self._health_check_task_started_event = asyncio.Event() + self._started_event_task_health_check = asyncio.Event() + async def setup(self) -> None: @periodic(interval=self.health_check_interval) async def _periodic_check_health() -> None: - assert self._health_check_task_started_event # nosec - self._health_check_task_started_event.set() + assert self._started_event_task_health_check # nosec + self._started_event_task_health_check.set() self._is_healthy = await self.ping() - self._health_check_task = asyncio.create_task( + self._task_health_check = asyncio.create_task( _periodic_check_health(), name=f"redis_service_health_check_{self.redis_dsn}__{uuid4()}", ) + await wait_till_redis_is_responsive(self._client) + _logger.info( "Connection to %s succeeded with %s", f"redis at {self.redis_dsn=}", @@ -84,12 +97,12 @@ async def shutdown(self) -> None: with log_context( _logger, level=logging.DEBUG, msg=f"Shutdown RedisClientSDK {self}" ): - if self._health_check_task: - assert self._health_check_task_started_event # nosec - # NOTE: wait for the health check task to have started once before we can cancel it - await self._health_check_task_started_event.wait() + if self._task_health_check: + assert self._started_event_task_health_check # nosec + await self._started_event_task_health_check.wait() + await cancel_wait_task( - self._health_check_task, max_delay=_HEALTHCHECK_TASK_TIMEOUT_S + self._task_health_check, max_delay=_HEALTHCHECK_TIMEOUT_S ) await self._client.aclose(close_connection_pool=True) @@ -97,8 +110,9 @@ async def shutdown(self) -> None: async def ping(self) -> bool: with log_catch(_logger, reraise=False): # NOTE: retry_* input parameters from aioredis.from_url do not apply for the ping call - await self._client.ping() + await asyncio.wait_for(self._client.ping(), timeout=_HEALTHCHECK_TIMEOUT_S) return True + return False @property diff --git a/packages/service-library/src/servicelib/redis/_clients_manager.py b/packages/service-library/src/servicelib/redis/_clients_manager.py index 60b93360b88..758977f8526 100644 --- a/packages/service-library/src/servicelib/redis/_clients_manager.py +++ b/packages/service-library/src/servicelib/redis/_clients_manager.py @@ -27,6 +27,7 @@ async def setup(self) -> None: health_check_interval=config.health_check_interval, client_name=f"{self.client_name}", ) + await self._client_sdks[config.database].setup() async def shutdown(self) -> None: await asyncio.gather( diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py b/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py index bd8d77da3ce..917cd335c65 100644 --- a/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py +++ b/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py @@ -36,7 +36,7 @@ async def _string_list_task( for index in range(num_strings): generated_strings.append(f"{index}") await asyncio.sleep(sleep_time) - progress.update(message="generated item", percent=index / num_strings) + await progress.update(message="generated item", percent=index / num_strings) if fail: msg = "We were asked to fail!!" 
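The `ping()` change in the `RedisClientSDK` hunk above bounds every health probe with `asyncio.wait_for`, so a stalled broker can no longer hang the periodic health-check task. A self-contained sketch of the same idea (assumes Python 3.11+, where asyncio timeouts raise the builtin `TimeoutError`):

```python
# Hedged sketch of the bounded health probe introduced above.
import asyncio

import redis.asyncio as aioredis
import redis.exceptions

_HEALTHCHECK_TIMEOUT_S = 3.0


async def safe_ping(client: aioredis.Redis) -> bool:
    # Never blocks longer than the healthcheck budget, whatever the broker does.
    try:
        await asyncio.wait_for(client.ping(), timeout=_HEALTHCHECK_TIMEOUT_S)
    except (TimeoutError, redis.exceptions.RedisError):
        return False
    return True
```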
raise RuntimeError(msg) diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py index b904d766d10..94eaecef7e3 100644 --- a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py +++ b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py @@ -23,19 +23,35 @@ from servicelib.aiohttp.rest_middlewares import append_rest_middlewares from servicelib.long_running_tasks.models import TaskGet, TaskId, TaskStatus from servicelib.long_running_tasks.task import TaskContext +from settings_library.redis import RedisSettings from tenacity.asyncio import AsyncRetrying from tenacity.retry import retry_if_exception_type from tenacity.stop import stop_after_delay from tenacity.wait import wait_fixed +pytest_simcore_core_services_selection = [ + "redis", +] + +pytest_simcore_ops_services_selection = [ + "redis-commander", +] + @pytest.fixture -def app(server_routes: web.RouteTableDef) -> web.Application: +def app( + server_routes: web.RouteTableDef, redis_service: RedisSettings +) -> web.Application: app = web.Application() app.add_routes(server_routes) # this adds enveloping and error middlewares append_rest_middlewares(app, api_version="") - long_running_tasks.server.setup(app, router_prefix="/futures") + long_running_tasks.server.setup( + app, + redis_settings=redis_service, + redis_namespace="test", + router_prefix="/futures", + ) return app diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py index b211cc3d1ca..0d626519734 100644 --- a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py +++ b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py @@ -15,16 +15,24 @@ long_running_task_request, ) from servicelib.aiohttp.rest_middlewares import append_rest_middlewares +from settings_library.redis import RedisSettings from yarl import URL @pytest.fixture -def app(server_routes: web.RouteTableDef) -> web.Application: +def app( + server_routes: web.RouteTableDef, use_in_memory_redis: RedisSettings +) -> web.Application: app = web.Application() app.add_routes(server_routes) # this adds enveloping and error middlewares append_rest_middlewares(app, api_version="") - long_running_tasks.server.setup(app, router_prefix="/futures") + long_running_tasks.server.setup( + app, + redis_settings=use_in_memory_redis, + redis_namespace="test", + router_prefix="/futures", + ) return app diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py index 8f7ff5efd23..e303d9362ad 100644 --- a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py +++ b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py @@ -27,7 +27,11 @@ from servicelib.aiohttp.typing_extension import Handler from servicelib.long_running_tasks.models import TaskGet, TaskId from servicelib.long_running_tasks.task import TaskContext +from settings_library.redis import RedisSettings +pytest_simcore_core_services_selection = [ + "redis", +] # WITH TASK CONTEXT # NOTE: as the long running task framework 
may be used in any number of services # in some cases there might be specific so-called task contexts. @@ -61,7 +65,9 @@ async def _test_task_context_decorator( @pytest.fixture def app_with_task_context( - server_routes: web.RouteTableDef, task_context_decorator + server_routes: web.RouteTableDef, + task_context_decorator, + redis_service: RedisSettings, ) -> web.Application: app = web.Application() app.add_routes(server_routes) @@ -69,6 +75,8 @@ def app_with_task_context( append_rest_middlewares(app, api_version="") long_running_tasks.server.setup( app, + redis_settings=redis_service, + redis_namespace="test", router_prefix="/futures_with_task_context", task_request_context_decorator=task_context_decorator, ) diff --git a/packages/service-library/tests/conftest.py b/packages/service-library/tests/conftest.py index d123e16f12e..845a8565d22 100644 --- a/packages/service-library/tests/conftest.py +++ b/packages/service-library/tests/conftest.py @@ -1,3 +1,4 @@ +# pylint: disable=contextmanager-generator-missing-cleanup # pylint: disable=redefined-outer-name # pylint: disable=unused-argument # pylint: disable=unused-import @@ -12,7 +13,6 @@ import pytest import servicelib from faker import Faker -from pytest_mock import MockerFixture from servicelib.redis import RedisClientSDK, RedisClientsManager, RedisManagerDBConfig from settings_library.redis import RedisDatabase, RedisSettings @@ -69,12 +69,10 @@ def fake_data_dict(faker: Faker) -> dict[str, Any]: return data -@pytest.fixture -async def get_redis_client_sdk( - mock_redis_socket_timeout: None, - mocker: MockerFixture, - redis_service: RedisSettings, -) -> AsyncIterable[ +@asynccontextmanager +async def _get_redis_client_sdk( + redis_settings: RedisSettings, +) -> AsyncIterator[ Callable[[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]] ]: @asynccontextmanager @@ -82,10 +80,11 @@ async def _( database: RedisDatabase, decode_response: bool = True, # noqa: FBT002 ) -> AsyncIterator[RedisClientSDK]: - redis_resources_dns = redis_service.build_redis_dsn(database) + redis_resources_dns = redis_settings.build_redis_dsn(database) client = RedisClientSDK( redis_resources_dns, decode_responses=decode_response, client_name="pytest" ) + await client.setup() assert client assert client.redis_dsn == redis_resources_dns assert client.client_name == "pytest" @@ -100,9 +99,29 @@ async def _cleanup_redis_data(clients_manager: RedisClientsManager) -> None: async with RedisClientsManager( {RedisManagerDBConfig(database=db) for db in RedisDatabase}, - redis_service, + redis_settings, client_name="pytest", ) as clients_manager: await _cleanup_redis_data(clients_manager) yield _ await _cleanup_redis_data(clients_manager) + + +@pytest.fixture +async def get_redis_client_sdk( + mock_redis_socket_timeout: None, use_in_memory_redis: RedisSettings +) -> AsyncIterable[ + Callable[[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]] +]: + async with _get_redis_client_sdk(use_in_memory_redis) as client: + yield client + + +@pytest.fixture +async def get_in_process_redis_client_sdk( + mock_redis_socket_timeout: None, redis_service: RedisSettings +) -> AsyncIterable[ + Callable[[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]] +]: + async with _get_redis_client_sdk(redis_service) as client: + yield client diff --git a/packages/service-library/tests/deferred_tasks/conftest.py b/packages/service-library/tests/deferred_tasks/conftest.py index 00881e61471..e5d8849e7e1 100644 --- a/packages/service-library/tests/deferred_tasks/conftest.py +++ 
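Several test suites in this diff switch from a dockerized Redis to `fakeredis`, which emulates the server in-process behind the regular `redis.asyncio` client API. A minimal, hedged sketch of what that replacement provides (key names are illustrative only):

```python
# Hedged sketch: fakeredis provides an asyncio Redis stand-in that needs no
# server, which is what lets these tests drop the "redis" docker service.
import asyncio

from fakeredis import FakeAsyncRedis


async def demo() -> None:
    client = FakeAsyncRedis(decode_responses=True)
    await client.set("test:lrt:some_task", '{"is_done": false}')
    assert await client.get("test:lrt:some_task") == '{"is_done": false}'
    await client.aclose()  # redis-py 5 API; older clients use close()


asyncio.run(demo())
```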
b/packages/service-library/tests/deferred_tasks/conftest.py @@ -8,11 +8,11 @@ @pytest.fixture async def redis_client_sdk_deferred_tasks( - get_redis_client_sdk: Callable[ + get_in_process_redis_client_sdk: Callable[ [RedisDatabase, bool], AbstractAsyncContextManager[RedisClientSDK] - ] + ], ) -> AsyncIterator[RedisClientSDK]: - async with get_redis_client_sdk( + async with get_in_process_redis_client_sdk( RedisDatabase.DEFERRED_TASKS, decode_response=False ) as client: yield client diff --git a/packages/service-library/tests/deferred_tasks/example_app.py b/packages/service-library/tests/deferred_tasks/example_app.py index 9adb654e896..991aa2efe8e 100644 --- a/packages/service-library/tests/deferred_tasks/example_app.py +++ b/packages/service-library/tests/deferred_tasks/example_app.py @@ -95,6 +95,7 @@ def __init__( ) async def setup(self) -> None: + await self._redis_client.setup() await self._manager.setup() diff --git a/packages/service-library/tests/deferred_tasks/test__base_deferred_handler.py b/packages/service-library/tests/deferred_tasks/test__base_deferred_handler.py index cc19133b6b2..6a3a872c861 100644 --- a/packages/service-library/tests/deferred_tasks/test__base_deferred_handler.py +++ b/packages/service-library/tests/deferred_tasks/test__base_deferred_handler.py @@ -34,7 +34,6 @@ pytest_simcore_core_services_selection = [ "rabbit", - "redis", ] @@ -43,7 +42,8 @@ class MockKeys(StrAutoEnum): GET_TIMEOUT = auto() START_DEFERRED = auto() ON_DEFERRED_CREATED = auto() - RUN_DEFERRED = auto() + RUN_DEFERRED_BEFORE_HANDLER = auto() + RUN_DEFERRED_AFTER_HANDLER = auto() ON_DEFERRED_RESULT = auto() ON_FINISHED_WITH_ERROR = auto() @@ -57,6 +57,7 @@ async def redis_client_sdk( decode_responses=False, client_name="pytest", ) + await sdk.setup() yield sdk await sdk.shutdown() @@ -122,8 +123,9 @@ async def on_created( @classmethod async def run(cls, context: DeferredContext) -> Any: + mocks[MockKeys.RUN_DEFERRED_BEFORE_HANDLER](context) result = await run(context) - mocks[MockKeys.RUN_DEFERRED](context) + mocks[MockKeys.RUN_DEFERRED_AFTER_HANDLER](context) return result @classmethod @@ -229,8 +231,8 @@ async def _run_ok(_: DeferredContext) -> Any: await _assert_mock_call(mocks, key=MockKeys.ON_DEFERRED_CREATED, count=1) assert TaskUID(mocks[MockKeys.ON_DEFERRED_CREATED].call_args_list[0].args[0]) - await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED, count=1) - mocks[MockKeys.RUN_DEFERRED].assert_called_once_with(context) + await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED_AFTER_HANDLER, count=1) + mocks[MockKeys.RUN_DEFERRED_AFTER_HANDLER].assert_called_once_with(context) await _assert_mock_call(mocks, key=MockKeys.ON_DEFERRED_RESULT, count=1) mocks[MockKeys.ON_DEFERRED_RESULT].assert_called_once_with(run_return, context) @@ -282,7 +284,7 @@ async def _run_raises(_: DeferredContext) -> None: count=retry_count, ) - await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED, count=0) + await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED_AFTER_HANDLER, count=0) await _assert_mock_call(mocks, key=MockKeys.ON_DEFERRED_RESULT, count=0) await _assert_log_message( @@ -319,6 +321,7 @@ async def _run_to_cancel(_: DeferredContext) -> None: await _assert_mock_call(mocks, key=MockKeys.ON_DEFERRED_CREATED, count=1) task_uid = TaskUID(mocks[MockKeys.ON_DEFERRED_CREATED].call_args_list[0].args[0]) + await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED_BEFORE_HANDLER, count=1) await mocked_deferred_handler.cancel(task_uid) await _assert_mock_call(mocks, 
key=MockKeys.ON_FINISHED_WITH_ERROR, count=0) @@ -330,7 +333,7 @@ async def _run_to_cancel(_: DeferredContext) -> None: == 0 ) - await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED, count=0) + await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED_AFTER_HANDLER, count=0) await _assert_mock_call(mocks, key=MockKeys.ON_DEFERRED_RESULT, count=0) await _assert_log_message( @@ -450,7 +453,7 @@ async def _run_that_times_out(_: DeferredContext) -> None: for entry in mocks[MockKeys.ON_FINISHED_WITH_ERROR].call_args_list: assert "builtins.TimeoutError" in entry.args[0].error - await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED, count=0) + await _assert_mock_call(mocks, key=MockKeys.RUN_DEFERRED_AFTER_HANDLER, count=0) await _assert_mock_call(mocks, key=MockKeys.ON_DEFERRED_RESULT, count=0) diff --git a/packages/service-library/tests/deferred_tasks/test_deferred_tasks.py b/packages/service-library/tests/deferred_tasks/test_deferred_tasks.py index 7d11d257153..7e60c71cb30 100644 --- a/packages/service-library/tests/deferred_tasks/test_deferred_tasks.py +++ b/packages/service-library/tests/deferred_tasks/test_deferred_tasks.py @@ -333,8 +333,7 @@ async def rabbit_client( class ClientWithPingProtocol(Protocol): - async def ping(self) -> bool: - ... + async def ping(self) -> bool: ... class ServiceManager: diff --git a/packages/service-library/tests/fastapi/long_running_tasks/conftest.py b/packages/service-library/tests/fastapi/long_running_tasks/conftest.py index d43a7e445c1..0cab1161c09 100644 --- a/packages/service-library/tests/fastapi/long_running_tasks/conftest.py +++ b/packages/service-library/tests/fastapi/long_running_tasks/conftest.py @@ -9,13 +9,19 @@ from fastapi import FastAPI from httpx import ASGITransport, AsyncClient from servicelib.fastapi import long_running_tasks +from settings_library.redis import RedisSettings @pytest.fixture -async def bg_task_app(router_prefix: str) -> FastAPI: +async def bg_task_app(router_prefix: str, redis_service: RedisSettings) -> FastAPI: app = FastAPI() - long_running_tasks.server.setup(app, router_prefix=router_prefix) + long_running_tasks.server.setup( + app, + redis_settings=redis_service, + redis_namespace="test", + router_prefix=router_prefix, + ) return app diff --git a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py index 5ebbdb744e0..0f10b7a165f 100644 --- a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py +++ b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py @@ -33,6 +33,7 @@ TaskStatus, ) from servicelib.long_running_tasks.task import TaskContext, TaskRegistry +from settings_library.redis import RedisSettings from tenacity.asyncio import AsyncRetrying from tenacity.retry import retry_if_exception_type from tenacity.stop import stop_after_delay @@ -52,7 +53,7 @@ async def _string_list_task( for index in range(num_strings): generated_strings.append(f"{index}") await asyncio.sleep(sleep_time) - progress.update(message="generated item", percent=index / num_strings) + await progress.update(message="generated item", percent=index / num_strings) if fail: msg = "We were asked to fail!!" 
raise RuntimeError(msg)
@@ -91,11 +92,13 @@ async def create_string_list_task(


@pytest.fixture
-async def app(server_routes: APIRouter) -> AsyncIterator[FastAPI]:
+async def app(
+    server_routes: APIRouter, use_in_memory_redis: RedisSettings
+) -> AsyncIterator[FastAPI]:
    # overrides fastapi/conftest.py:app
    app = FastAPI(title="test app")
    app.include_router(server_routes)
-    setup_server(app)
+    setup_server(app, redis_settings=use_in_memory_redis, redis_namespace="test")
    setup_client(app)
    async with LifespanManager(app):
        yield app
diff --git a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py
index 85b59eb1a35..179c967088f 100644
--- a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py
+++ b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py
@@ -28,6 +28,7 @@
    TaskProgress,
)
from servicelib.long_running_tasks.task import TaskRegistry
+from settings_library.redis import RedisSettings

TASK_SLEEP_INTERVAL: Final[PositiveFloat] = 0.1

@@ -89,13 +90,18 @@ async def create_task_which_fails(

@pytest.fixture
async def bg_task_app(
-    user_routes: APIRouter, router_prefix: str
+    user_routes: APIRouter, router_prefix: str, use_in_memory_redis: RedisSettings
) -> AsyncIterable[FastAPI]:
    app = FastAPI()
    app.include_router(user_routes)

-    setup_server(app, router_prefix=router_prefix)
+    setup_server(
+        app,
+        router_prefix=router_prefix,
+        redis_settings=use_in_memory_redis,
+        redis_namespace="test",
+    )
    setup_client(app, router_prefix=router_prefix)

    async with LifespanManager(app):
diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_serialization.py
b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_serialization.py new file mode 100644 index 00000000000..f0d0d14f165 --- /dev/null +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_serialization.py @@ -0,0 +1,47 @@ +from typing import Any + +import pytest +from aiohttp.web import HTTPException, HTTPInternalServerError +from servicelib.aiohttp.long_running_tasks._server import AiohttpHTTPExceptionSerializer +from servicelib.long_running_tasks._redis_serialization import ( + object_to_string, + register_custom_serialization, + string_to_object, +) + +register_custom_serialization(HTTPException, AiohttpHTTPExceptionSerializer) + + +class PositionalArguments: + def __init__(self, arg1, arg2, *args): + self.arg1 = arg1 + self.arg2 = arg2 + self.args = args + + +class MixedArguments: + def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None): + self.arg1 = arg1 + self.arg2 = arg2 + self.kwarg1 = kwarg1 + self.kwarg2 = kwarg2 + + +@pytest.mark.parametrize( + "obj", + [ + HTTPInternalServerError(reason="Uh-oh!", text="Failure!"), + PositionalArguments("arg1", "arg2", "arg3", "arg4"), + MixedArguments("arg1", "arg2", kwarg1="kwarg1", kwarg2="kwarg2"), + "a_string", + 1, + ], +) +def test_serialization(obj: Any): + str_data = object_to_string(obj) + + reconstructed_obj = string_to_object(str_data) + + assert type(reconstructed_obj) is type(obj) + if hasattr(obj, "__dict__"): + assert reconstructed_obj.__dict__ == obj.__dict__ diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py new file mode 100644 index 00000000000..bd4586be648 --- /dev/null +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py @@ -0,0 +1,105 @@ +# pylint:disable=redefined-outer-name + +from collections.abc import AsyncIterable, Callable +from contextlib import AbstractAsyncContextManager + +import pytest +from pydantic import TypeAdapter +from servicelib.long_running_tasks._store.base import BaseStore +from servicelib.long_running_tasks._store.redis import RedisStore +from servicelib.long_running_tasks.models import TaskData +from servicelib.redis._client import RedisClientSDK +from settings_library.redis import RedisDatabase, RedisSettings + + +@pytest.fixture +def task_data() -> TaskData: + return TypeAdapter(TaskData).validate_python( + TaskData.model_json_schema()["examples"][0] + ) + + +@pytest.fixture +async def store( + use_in_memory_redis: RedisSettings, + get_redis_client_sdk: Callable[ + [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] + ], +) -> AsyncIterable[BaseStore]: + store = RedisStore(redis_settings=use_in_memory_redis, namespace="test") + + await store.setup() + yield store + await store.shutdown() + + # triggers cleanup of all redis data + async with get_redis_client_sdk(RedisDatabase.LONG_RUNNING_TASKS): + pass + + +async def test_workflow(store: BaseStore, task_data: TaskData) -> None: + # task data + assert await store.list_tasks_data() == [] + assert await store.get_task_data("missing") is None + + await store.set_task_data(task_data.task_id, task_data) + + assert await store.list_tasks_data() == [task_data] + + await store.delete_task_data(task_data.task_id) + + assert await store.list_tasks_data() == [] + + # cancelled tasks + assert await store.get_cancelled() == {} + + await store.set_as_cancelled(task_data.task_id, task_data.task_context) + + assert await 
store.get_cancelled() == {task_data.task_id: task_data.task_context} + + +@pytest.fixture +async def redis_stores( + use_in_memory_redis: RedisSettings, + get_redis_client_sdk: Callable[ + [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] + ], +) -> AsyncIterable[list[RedisStore]]: + stores: list[RedisStore] = [ + RedisStore(redis_settings=use_in_memory_redis, namespace=f"test-{i}") + for i in range(5) + ] + for store in stores: + await store.setup() + + yield stores + + for store in stores: + await store.shutdown() + + # triggers cleanup of all redis data + async with get_redis_client_sdk(RedisDatabase.LONG_RUNNING_TASKS): + pass + + +async def test_workflow_multiple_redis_stores_with_different_namespaces( + redis_stores: list[RedisStore], task_data: TaskData +): + + for store in redis_stores: + assert await store.list_tasks_data() == [] + assert await store.get_cancelled() == {} + + for store in redis_stores: + await store.set_task_data(task_data.task_id, task_data) + await store.set_as_cancelled(task_data.task_id, None) + + for store in redis_stores: + assert await store.list_tasks_data() == [task_data] + assert await store.get_cancelled() == {task_data.task_id: None} + + for store in redis_stores: + await store.delete_task_data(task_data.task_id) + + for store in redis_stores: + assert await store.list_tasks_data() == [] diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py index 566be002808..902061cbbf2 100644 --- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py @@ -6,22 +6,28 @@ import asyncio import urllib.parse -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import suppress from datetime import datetime, timedelta from typing import Any, Final import pytest from faker import Faker +from models_library.api_schemas_long_running_tasks.base import ProgressMessage from servicelib.long_running_tasks import lrt_api from servicelib.long_running_tasks.errors import ( TaskAlreadyRunningError, - TaskCancelledError, TaskNotCompletedError, TaskNotFoundError, TaskNotRegisteredError, ) -from servicelib.long_running_tasks.models import TaskProgress, TaskStatus -from servicelib.long_running_tasks.task import TaskRegistry, TasksManager +from servicelib.long_running_tasks.models import TaskContext, TaskProgress, TaskStatus +from servicelib.long_running_tasks.task import ( + RedisNamespace, + TaskRegistry, + TasksManager, +) +from settings_library.redis import RedisSettings from tenacity import TryAgain from tenacity.asyncio import AsyncRetrying from tenacity.retry import retry_if_exception_type @@ -44,7 +50,7 @@ async def a_background_task( """sleeps and raises an error or returns 42""" for i in range(total_sleep): await asyncio.sleep(1) - progress.update(percent=(i + 1) / total_sleep) + await progress.update(percent=(i + 1) / total_sleep) if raise_when_finished: msg = "raised this error as instructed" raise RuntimeError(msg) @@ -71,262 +77,381 @@ async def failing_background_task(progress: TaskProgress): @pytest.fixture -async def tasks_manager() -> AsyncIterator[TasksManager]: - tasks_manager = TasksManager( - stale_task_check_interval=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S), - stale_task_detect_timeout=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S), - ) - await 
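The multi-store test above relies on every `RedisStore` deriving its keys from its own namespace, so several stores can share one Redis database without touching each other's data. A toy sketch of that invariant (`build_key` is a hypothetical helper, not the store's actual one):

```python
# Hedged sketch of why different namespaces do not collide: assuming every
# Redis key is derived from the store's namespace, two stores on the same
# database read and write disjoint key sets.
def build_key(namespace: str, task_id: str) -> str:
    return f"{namespace}:{task_id}"


assert build_key("test-0", "task_a") != build_key("test-1", "task_a")
assert build_key("test-0", "task_a") == build_key("test-0", "task_a")
```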
tasks_manager.setup()
-    yield tasks_manager
-    await tasks_manager.teardown()
+def empty_context() -> TaskContext:
+    return {}
+
+
+@pytest.fixture
+async def get_tasks_manager(
+    faker: Faker,
+) -> AsyncIterator[
+    Callable[[RedisSettings, RedisNamespace | None], Awaitable[TasksManager]]
+]:
+    managers: list[TasksManager] = []
+
+    async def _(
+        redis_settings: RedisSettings, namespace: RedisNamespace | None
+    ) -> TasksManager:
+        tasks_manager = TasksManager(
+            stale_task_check_interval=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
+            stale_task_detect_timeout=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
+            redis_settings=redis_settings,
+            redis_namespace=namespace or f"test{faker.uuid4()}",
+        )
+        await tasks_manager.setup()
+        managers.append(tasks_manager)
+        return tasks_manager
+
+    yield _
+
+    for manager in managers:
+        with suppress(TimeoutError):  # avoids tests hanging on teardown
+            await asyncio.wait_for(manager.teardown(), timeout=10)
+
+
+@pytest.fixture
+async def tasks_manager(
+    use_in_memory_redis: RedisSettings,
+    get_tasks_manager: Callable[
+        [RedisSettings, RedisNamespace | None], Awaitable[TasksManager]
+    ],
+) -> TasksManager:
+    return await get_tasks_manager(use_in_memory_redis, None)


@pytest.mark.parametrize("check_task_presence_before", [True, False])
async def test_task_is_auto_removed(
-    tasks_manager: TasksManager, check_task_presence_before: bool
+    tasks_manager: TasksManager,
+    check_task_presence_before: bool,
+    empty_context: TaskContext,
):
    task_id = await lrt_api.start_task(
        tasks_manager,
        a_background_task.__name__,
        raise_when_finished=False,
        total_sleep=10 * TEST_CHECK_STALE_INTERVAL_S,
+        task_context=empty_context,
    )

    if check_task_presence_before:
        # immediately after starting the task is still there
-        task_status = tasks_manager.get_task_status(task_id, with_task_context=None)
+        task_status = await tasks_manager.get_task_status(
+            task_id, with_task_context=empty_context
+        )
        assert task_status

    # wait for task to be automatically removed
    # meaning no calls via the manager methods are received
    async for attempt in AsyncRetrying(**_RETRY_PARAMS):
        with attempt:
-            if task_id in tasks_manager._tracked_tasks:  # noqa: SLF001
+            if (
+                await tasks_manager._tasks_data.get_task_data(task_id)  # noqa: SLF001
+                is not None
+            ):
                msg = "wait till no element is found any longer"
                raise TryAgain(msg)

    with pytest.raises(TaskNotFoundError):
-        tasks_manager.get_task_status(task_id, with_task_context=None)
+        await tasks_manager.get_task_status(task_id, with_task_context=empty_context)
    with pytest.raises(TaskNotFoundError):
-        tasks_manager.get_task_result(task_id, with_task_context=None)
+        await tasks_manager.get_task_result(task_id, with_task_context=empty_context)


-async def test_checked_task_is_not_auto_removed(tasks_manager: TasksManager):
+async def test_checked_task_is_not_auto_removed(
+    tasks_manager: TasksManager, empty_context: TaskContext
+):
    task_id = await lrt_api.start_task(
        tasks_manager,
        a_background_task.__name__,
        raise_when_finished=False,
        total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S,
+        task_context=empty_context,
    )

    async for attempt in AsyncRetrying(**_RETRY_PARAMS):
        with attempt:
-            status = tasks_manager.get_task_status(task_id, with_task_context=None)
+            status = await tasks_manager.get_task_status(
+                task_id, with_task_context=empty_context
+            )
            assert status.done, f"task {task_id} not complete"
-    result = tasks_manager.get_task_result(task_id, with_task_context=None)
+    result = await tasks_manager.get_task_result(
+        task_id, with_task_context=empty_context
+
) assert result -async def test_fire_and_forget_task_is_not_auto_removed(tasks_manager: TasksManager): +async def test_fire_and_forget_task_is_not_auto_removed( + tasks_manager: TasksManager, empty_context: TaskContext +): task_id = await lrt_api.start_task( tasks_manager, a_background_task.__name__, raise_when_finished=False, total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S, fire_and_forget=True, + task_context=empty_context, ) await asyncio.sleep(3 * TEST_CHECK_STALE_INTERVAL_S) # the task shall still be present even if we did not check the status before - status = tasks_manager.get_task_status(task_id, with_task_context=None) + status = await tasks_manager.get_task_status( + task_id, with_task_context=empty_context + ) assert not status.done, "task was removed although it is fire and forget" # the task shall finish - await asyncio.sleep(3 * TEST_CHECK_STALE_INTERVAL_S) + await asyncio.sleep(4 * TEST_CHECK_STALE_INTERVAL_S) # get the result - task_result = tasks_manager.get_task_result(task_id, with_task_context=None) + task_result = await tasks_manager.get_task_result( + task_id, with_task_context=empty_context + ) assert task_result == 42 -async def test_get_result_of_unfinished_task_raises(tasks_manager: TasksManager): +async def test_get_result_of_unfinished_task_raises( + tasks_manager: TasksManager, empty_context: TaskContext +): task_id = await lrt_api.start_task( tasks_manager, a_background_task.__name__, raise_when_finished=False, total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S, + task_context=empty_context, ) with pytest.raises(TaskNotCompletedError): - tasks_manager.get_task_result(task_id, with_task_context=None) + await tasks_manager.get_task_result(task_id, with_task_context=empty_context) -async def test_unique_task_already_running(tasks_manager: TasksManager): +async def test_unique_task_already_running( + tasks_manager: TasksManager, empty_context: TaskContext +): async def unique_task(progress: TaskProgress): _ = progress await asyncio.sleep(1) TaskRegistry.register(unique_task) - await lrt_api.start_task(tasks_manager, unique_task.__name__, unique=True) + await lrt_api.start_task( + tasks_manager, unique_task.__name__, unique=True, task_context=empty_context + ) # ensure unique running task regardless of how many times it gets started with pytest.raises(TaskAlreadyRunningError) as exec_info: - await lrt_api.start_task(tasks_manager, unique_task.__name__, unique=True) + await lrt_api.start_task( + tasks_manager, unique_task.__name__, unique=True, task_context=empty_context + ) assert "must be unique, found: " in f"{exec_info.value}" TaskRegistry.unregister(unique_task) -async def test_start_multiple_not_unique_tasks(tasks_manager: TasksManager): +async def test_start_multiple_not_unique_tasks( + tasks_manager: TasksManager, empty_context: TaskContext +): async def not_unique_task(progress: TaskProgress): await asyncio.sleep(1) TaskRegistry.register(not_unique_task) for _ in range(5): - await lrt_api.start_task(tasks_manager, not_unique_task.__name__) + await lrt_api.start_task( + tasks_manager, not_unique_task.__name__, task_context=empty_context + ) TaskRegistry.unregister(not_unique_task) @pytest.mark.parametrize("is_unique", [True, False]) -def test_get_task_id(tasks_manager: TasksManager, faker: Faker, is_unique: bool): +async def test_get_task_id(tasks_manager: TasksManager, faker: Faker, is_unique: bool): obj1 = tasks_manager._get_task_id(faker.word(), is_unique=is_unique) # noqa: SLF001 obj2 = tasks_manager._get_task_id(faker.word(), is_unique=is_unique) # noqa: SLF001 
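`test_get_task_id` above exercises the id scheme from the `_get_task_id` hunk earlier in this diff: unique tasks deliberately collide on a fixed suffix so a second start can be detected and rejected, while non-unique tasks always get a fresh uuid. Restated as a standalone function for clarity:

```python
# The id scheme from the _get_task_id hunk earlier in this diff: a unique
# task always maps to the same id, a non-unique task never does.
from uuid import uuid4


def get_task_id(redis_namespace: str, task_name: str, *, is_unique: bool) -> str:
    unique_part = "unique" if is_unique else f"{uuid4()}"
    return f"{redis_namespace}.{task_name}.{unique_part}"


assert get_task_id("ns", "jobs", is_unique=True) == get_task_id("ns", "jobs", is_unique=True)
assert get_task_id("ns", "jobs", is_unique=False) != get_task_id("ns", "jobs", is_unique=False)
```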
assert obj1 != obj2 -async def test_get_status(tasks_manager: TasksManager): +async def test_get_status(tasks_manager: TasksManager, empty_context: TaskContext): task_id = await lrt_api.start_task( tasks_manager, a_background_task.__name__, raise_when_finished=False, total_sleep=10, + task_context=empty_context, + ) + task_status = await tasks_manager.get_task_status( + task_id, with_task_context=empty_context ) - task_status = tasks_manager.get_task_status(task_id, with_task_context=None) assert isinstance(task_status, TaskStatus) - assert task_status.task_progress.message == "" + assert isinstance(task_status.task_progress.message, ProgressMessage) assert task_status.task_progress.percent == 0.0 assert task_status.done is False assert isinstance(task_status.started, datetime) -async def test_get_status_missing(tasks_manager: TasksManager): +async def test_get_status_missing( + tasks_manager: TasksManager, empty_context: TaskContext +): with pytest.raises(TaskNotFoundError) as exec_info: - tasks_manager.get_task_status("missing_task_id", with_task_context=None) + await tasks_manager.get_task_status( + "missing_task_id", with_task_context=empty_context + ) assert f"{exec_info.value}" == "No task with missing_task_id found" -async def test_get_result(tasks_manager: TasksManager): - task_id = await lrt_api.start_task(tasks_manager, fast_background_task.__name__) - await asyncio.sleep(0.1) - result = tasks_manager.get_task_result(task_id, with_task_context=None) +async def test_get_result(tasks_manager: TasksManager, empty_context: TaskContext): + task_id = await lrt_api.start_task( + tasks_manager, fast_background_task.__name__, task_context=empty_context + ) + + async for attempt in AsyncRetrying(**_RETRY_PARAMS): + with attempt: + status = await tasks_manager.get_task_status( + task_id, with_task_context=empty_context + ) + assert status.done is True + + result = await tasks_manager.get_task_result( + task_id, with_task_context=empty_context + ) assert result == 42 -async def test_get_result_missing(tasks_manager: TasksManager): +async def test_get_result_missing( + tasks_manager: TasksManager, empty_context: TaskContext +): with pytest.raises(TaskNotFoundError) as exec_info: - tasks_manager.get_task_result("missing_task_id", with_task_context=None) + await tasks_manager.get_task_result( + "missing_task_id", with_task_context=empty_context + ) assert f"{exec_info.value}" == "No task with missing_task_id found" -async def test_get_result_finished_with_error(tasks_manager: TasksManager): - task_id = await lrt_api.start_task(tasks_manager, failing_background_task.__name__) +async def test_get_result_finished_with_error( + tasks_manager: TasksManager, empty_context: TaskContext +): + task_id = await lrt_api.start_task( + tasks_manager, failing_background_task.__name__, task_context=empty_context + ) # wait for result async for attempt in AsyncRetrying(**_RETRY_PARAMS): with attempt: - assert tasks_manager.get_task_status(task_id, with_task_context=None).done + assert ( + await tasks_manager.get_task_status( + task_id, with_task_context=empty_context + ) + ).done with pytest.raises(RuntimeError, match="failing asap"): - tasks_manager.get_task_result(task_id, with_task_context=None) + await tasks_manager.get_task_result(task_id, with_task_context=empty_context) -async def test_get_result_task_was_cancelled_multiple_times( - tasks_manager: TasksManager, +async def test_cancel_task_from_different_manager( + use_in_memory_redis: RedisSettings, + get_tasks_manager: Callable[ + [RedisSettings, 
RedisNamespace | None], Awaitable[TasksManager]
+    ],
+    empty_context: TaskContext,
+):
+    manager_1 = await get_tasks_manager(use_in_memory_redis, "test-namespace")
+    manager_2 = await get_tasks_manager(use_in_memory_redis, "test-namespace")
+    manager_3 = await get_tasks_manager(use_in_memory_redis, "test-namespace")
+
    task_id = await lrt_api.start_task(
-        tasks_manager,
+        manager_1,
        a_background_task.__name__,
        raise_when_finished=False,
-        total_sleep=10,
+        total_sleep=1,
+        task_context=empty_context,
    )

-    for _ in range(5):
-        await tasks_manager.cancel_task(task_id, with_task_context=None)
-    with pytest.raises(
-        TaskCancelledError, match=f"Task {task_id} was cancelled before completing"
-    ):
-        tasks_manager.get_task_result(task_id, with_task_context=None)
+    # wait for task to complete
+    for manager in (manager_1, manager_2, manager_3):
+        status = await manager.get_task_status(task_id, empty_context)
+        assert status.done is False
+    async for attempt in AsyncRetrying(**_RETRY_PARAMS):
+        with attempt:
+            for manager in (manager_1, manager_2, manager_3):
+                status = await manager.get_task_status(task_id, empty_context)
+                assert status.done is True
+
+    # check all provide the same result
+    for manager in (manager_1, manager_2, manager_3):
+        task_result = await manager.get_task_result(task_id, empty_context)
+        assert task_result == 42

-async def test_remove_task(tasks_manager: TasksManager):
+
+async def test_remove_task(tasks_manager: TasksManager, empty_context: TaskContext):
    task_id = await lrt_api.start_task(
        tasks_manager,
        a_background_task.__name__,
        raise_when_finished=False,
        total_sleep=10,
+        task_context=empty_context,
    )
-    tasks_manager.get_task_status(task_id, with_task_context=None)
-    await tasks_manager.remove_task(task_id, with_task_context=None)
+    await tasks_manager.get_task_status(task_id, with_task_context=empty_context)
+    await tasks_manager.remove_task(task_id, with_task_context=empty_context)
    with pytest.raises(TaskNotFoundError):
-        tasks_manager.get_task_status(task_id, with_task_context=None)
+        await tasks_manager.get_task_status(task_id, with_task_context=empty_context)
    with pytest.raises(TaskNotFoundError):
-        tasks_manager.get_task_result(task_id, with_task_context=None)
+        await tasks_manager.get_task_result(task_id, with_task_context=empty_context)


-async def test_remove_task_with_task_context(tasks_manager: TasksManager):
-    TASK_CONTEXT = {"some_context": "some_value"}
+async def test_remove_task_with_task_context(
+    tasks_manager: TasksManager, empty_context: TaskContext
+):
    task_id = await lrt_api.start_task(
        tasks_manager,
        a_background_task.__name__,
        raise_when_finished=False,
        total_sleep=10,
-        task_context=TASK_CONTEXT,
+        task_context=empty_context,
    )
    # getting status fails if wrong task context given
    with pytest.raises(TaskNotFoundError):
-        tasks_manager.get_task_status(
+        await tasks_manager.get_task_status(
            task_id, with_task_context={"wrong_task_context": 12}
        )
-    tasks_manager.get_task_status(task_id, with_task_context=TASK_CONTEXT)
+    await tasks_manager.get_task_status(task_id, with_task_context=empty_context)
    # removing task fails if wrong task context given
    with pytest.raises(TaskNotFoundError):
        await tasks_manager.remove_task(
            task_id, with_task_context={"wrong_task_context": 12}
        )
-    await tasks_manager.remove_task(task_id, with_task_context=TASK_CONTEXT)
+    await tasks_manager.remove_task(task_id, with_task_context=empty_context)


-async def test_remove_unknown_task(tasks_manager: TasksManager):
+async def test_remove_unknown_task(
+    tasks_manager: TasksManager,
empty_context: TaskContext +): with pytest.raises(TaskNotFoundError): - await tasks_manager.remove_task("invalid_id", with_task_context=None) + await tasks_manager.remove_task("invalid_id", with_task_context=empty_context) await tasks_manager.remove_task( - "invalid_id", with_task_context=None, reraise_errors=False + "invalid_id", with_task_context=empty_context, reraise_errors=False ) -async def test_cancel_task_with_task_context(tasks_manager: TasksManager): - TASK_CONTEXT = {"some_context": "some_value"} +async def test__cancelled_tasks_worker_equivalent_of_cancellation_from_a_different_process( + tasks_manager: TasksManager, empty_context: TaskContext +): task_id = await lrt_api.start_task( tasks_manager, a_background_task.__name__, raise_when_finished=False, total_sleep=10, - task_context=TASK_CONTEXT, + task_context=empty_context, ) - # getting status fails if wrong task context given - with pytest.raises(TaskNotFoundError): - tasks_manager.get_task_status( - task_id, with_task_context={"wrong_task_context": 12} - ) - # getting status fails if wrong task context given - with pytest.raises(TaskNotFoundError): - await tasks_manager.cancel_task( - task_id, with_task_context={"wrong_task_context": 12} - ) - await tasks_manager.cancel_task(task_id, with_task_context=TASK_CONTEXT) + await tasks_manager._tasks_data.set_as_cancelled( # noqa: SLF001 + task_id, with_task_context=empty_context + ) + + async for attempt in AsyncRetrying(**_RETRY_PARAMS): + with attempt: # noqa: SIM117 + with pytest.raises(TaskNotFoundError): + assert ( + await tasks_manager.get_task_status(task_id, empty_context) is None + ) -async def test_list_tasks(tasks_manager: TasksManager): - assert tasks_manager.list_tasks(with_task_context=None) == [] +async def test_list_tasks(tasks_manager: TasksManager, empty_context: TaskContext): + assert await tasks_manager.list_tasks(with_task_context=empty_context) == [] # start a bunch of tasks NUM_TASKS = 10 task_ids = [] @@ -337,22 +462,29 @@ async def test_list_tasks(tasks_manager: TasksManager): a_background_task.__name__, raise_when_finished=False, total_sleep=10, + task_context=empty_context, ) ) - assert len(tasks_manager.list_tasks(with_task_context=None)) == NUM_TASKS + assert ( + len(await tasks_manager.list_tasks(with_task_context=empty_context)) + == NUM_TASKS + ) for task_index, task_id in enumerate(task_ids): - await tasks_manager.remove_task(task_id, with_task_context=None) - assert len(tasks_manager.list_tasks(with_task_context=None)) == NUM_TASKS - ( - task_index + 1 - ) + await tasks_manager.remove_task(task_id, with_task_context=empty_context) + assert len( + await tasks_manager.list_tasks(with_task_context=empty_context) + ) == NUM_TASKS - (task_index + 1) -async def test_list_tasks_filtering(tasks_manager: TasksManager): +async def test_list_tasks_filtering( + tasks_manager: TasksManager, empty_context: TaskContext +): await lrt_api.start_task( tasks_manager, a_background_task.__name__, raise_when_finished=False, total_sleep=10, + task_context=empty_context, ) await lrt_api.start_task( tasks_manager, @@ -368,11 +500,11 @@ async def test_list_tasks_filtering(tasks_manager: TasksManager): total_sleep=10, task_context={"user_id": 213, "product": "osparc"}, ) - assert len(tasks_manager.list_tasks(with_task_context=None)) == 3 - assert len(tasks_manager.list_tasks(with_task_context={"user_id": 213})) == 1 + assert len(await tasks_manager.list_tasks(with_task_context=empty_context)) == 3 + assert len(await tasks_manager.list_tasks(with_task_context={"user_id": 
213})) == 1 assert ( len( - tasks_manager.list_tasks( + await tasks_manager.list_tasks( with_task_context={"user_id": 213, "product": "osparc"} ) ) @@ -380,7 +512,7 @@ async def test_list_tasks_filtering(tasks_manager: TasksManager): ) assert ( len( - tasks_manager.list_tasks( + await tasks_manager.list_tasks( with_task_context={"user_id": 120, "product": "osparc"} ) ) diff --git a/packages/service-library/tests/redis/conftest.py b/packages/service-library/tests/redis/conftest.py index ae6d04c2085..c975dc1f4ad 100644 --- a/packages/service-library/tests/redis/conftest.py +++ b/packages/service-library/tests/redis/conftest.py @@ -12,11 +12,11 @@ @pytest.fixture async def redis_client_sdk( - get_redis_client_sdk: Callable[ + get_in_process_redis_client_sdk: Callable[ [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] ], ) -> AsyncIterator[RedisClientSDK]: - async with get_redis_client_sdk(RedisDatabase.RESOURCES) as client: + async with get_in_process_redis_client_sdk(RedisDatabase.RESOURCES) as client: yield client diff --git a/packages/service-library/tests/redis/test_client.py b/packages/service-library/tests/redis/test_client.py index 210c857bb9b..91ff29e5f38 100644 --- a/packages/service-library/tests/redis/test_client.py +++ b/packages/service-library/tests/redis/test_client.py @@ -110,7 +110,7 @@ async def test_redis_client_sdk_setup_shutdown( # setup redis_resources_dns = redis_service.build_redis_dsn(RedisDatabase.RESOURCES) client = RedisClientSDK(redis_resources_dns, client_name="pytest") - assert client + await client.setup() assert client.redis_dsn == redis_resources_dns # ensure health check task sets the health to True diff --git a/packages/service-library/tests/test_background_task_utils.py b/packages/service-library/tests/test_background_task_utils.py index 92d0337eca7..7307a8b7b89 100644 --- a/packages/service-library/tests/test_background_task_utils.py +++ b/packages/service-library/tests/test_background_task_utils.py @@ -24,13 +24,6 @@ wait_fixed, ) -pytest_simcore_core_services_selection = [ - "redis", -] -pytest_simcore_ops_services_selection = [ - "redis-commander", -] - @pytest.fixture async def redis_client_sdk( diff --git a/packages/settings-library/src/settings_library/redis.py b/packages/settings-library/src/settings_library/redis.py index 63e64fce449..28d6b6c66bd 100644 --- a/packages/settings-library/src/settings_library/redis.py +++ b/packages/settings-library/src/settings_library/redis.py @@ -15,7 +15,7 @@ class RedisDatabase(IntEnum): SCHEDULED_MAINTENANCE = 3 USER_NOTIFICATIONS = 4 ANNOUNCEMENTS = 5 - DISTRIBUTED_IDENTIFIERS = 6 + LONG_RUNNING_TASKS = 6 DEFERRED_TASKS = 7 DYNAMIC_SERVICES = 8 CELERY_TASKS = 9 diff --git a/services/autoscaling/src/simcore_service_autoscaling/modules/redis.py b/services/autoscaling/src/simcore_service_autoscaling/modules/redis.py index c0cf7a15e07..4aa9cea509c 100644 --- a/services/autoscaling/src/simcore_service_autoscaling/modules/redis.py +++ b/services/autoscaling/src/simcore_service_autoscaling/modules/redis.py @@ -18,6 +18,7 @@ async def on_startup() -> None: app.state.redis_client_sdk = RedisClientSDK( redis_locks_dsn, client_name=APP_NAME ) + await app.state.redis_client_sdk.setup() async def on_shutdown() -> None: redis_client_sdk: None | RedisClientSDK = app.state.redis_client_sdk diff --git a/services/clusters-keeper/src/simcore_service_clusters_keeper/modules/redis.py b/services/clusters-keeper/src/simcore_service_clusters_keeper/modules/redis.py index 8e2d5b71e33..595d41a4a55 100644 --- 
a/services/clusters-keeper/src/simcore_service_clusters_keeper/modules/redis.py +++ b/services/clusters-keeper/src/simcore_service_clusters_keeper/modules/redis.py @@ -19,6 +19,7 @@ async def on_startup() -> None: app.state.redis_client_sdk = RedisClientSDK( redis_locks_dsn, client_name=APP_NAME ) + await app.state.redis_client_sdk.setup() async def on_shutdown() -> None: redis_client_sdk: None | RedisClientSDK = app.state.redis_client_sdk diff --git a/services/director-v2/requirements/_test.in b/services/director-v2/requirements/_test.in index 2fb831189ba..34b5327e253 100644 --- a/services/director-v2/requirements/_test.in +++ b/services/director-v2/requirements/_test.in @@ -17,6 +17,7 @@ async-asgi-testclient # replacement for fastapi.testclient.TestClient [see b) be dask[distributed,diagnostics] docker Faker +fakeredis[lua] flaky pytest pytest-asyncio diff --git a/services/director-v2/requirements/_test.txt b/services/director-v2/requirements/_test.txt index 706aa216225..619015cc94b 100644 --- a/services/director-v2/requirements/_test.txt +++ b/services/director-v2/requirements/_test.txt @@ -99,6 +99,8 @@ execnet==2.1.1 # via pytest-xdist faker==37.3.0 # via -r requirements/_test.in +fakeredis==2.30.3 + # via -r requirements/_test.in flaky==3.8.1 # via -r requirements/_test.in frozenlist==1.6.0 @@ -159,6 +161,8 @@ locket==1.0.0 # -c requirements/_base.txt # distributed # partd +lupa==2.5 + # via fakeredis mako==1.3.10 # via # -c requirements/../../../requirements/constraints.txt @@ -271,6 +275,11 @@ pyyaml==6.0.2 # bokeh # dask # distributed +redis==6.1.0 + # via + # -c requirements/../../../requirements/constraints.txt + # -c requirements/_base.txt + # fakeredis requests==2.32.4 # via # -c requirements/_base.txt @@ -293,6 +302,7 @@ sortedcontainers==2.4.0 # via # -c requirements/_base.txt # distributed + # fakeredis sqlalchemy==1.4.54 # via # -c requirements/../../../requirements/constraints.txt diff --git a/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py b/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py index bbec37d8d45..2c84fd0bb36 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py @@ -106,7 +106,7 @@ async def _task_remove_service_containers( async def _progress_callback( message: ProgressMessage, percent: ProgressPercent | None, _: TaskId ) -> None: - progress.update(message=message, percent=percent) + await progress.update(message=message, percent=percent) await dynamic_sidecars_scheduler.remove_service_containers( node_uuid=node_uuid, progress_callback=_progress_callback @@ -171,7 +171,7 @@ async def _task_save_service_state( async def _progress_callback( message: ProgressMessage, percent: ProgressPercent | None, _: TaskId ) -> None: - progress.update(message=message, percent=percent) + await progress.update(message=message, percent=percent) await dynamic_sidecars_scheduler.save_service_state( node_uuid=node_uuid, progress_callback=_progress_callback @@ -218,7 +218,7 @@ async def _task_push_service_outputs( async def _progress_callback( message: ProgressMessage, percent: ProgressPercent | None, _: TaskId ) -> None: - progress.update(message=message, percent=percent) + await progress.update(message=message, percent=percent) await dynamic_sidecars_scheduler.push_service_outputs( node_uuid=node_uuid, progress_callback=_progress_callback diff --git 
a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/docker_service_specs/sidecar.py b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/docker_service_specs/sidecar.py index d7d013208cb..dd05734f237 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/docker_service_specs/sidecar.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/docker_service_specs/sidecar.py @@ -91,6 +91,7 @@ def _get_environment_variables( telemetry_enabled: bool, ) -> dict[str, str]: rabbit_settings = app_settings.DIRECTOR_V2_RABBITMQ + redis_settings = app_settings.REDIS r_clone_settings = ( app_settings.DYNAMIC_SERVICES.DYNAMIC_SIDECAR.DYNAMIC_SIDECAR_R_CLONE_SETTINGS ) @@ -163,6 +164,9 @@ def _get_environment_variables( "RABBIT_PORT": f"{rabbit_settings.RABBIT_PORT}", "RABBIT_USER": f"{rabbit_settings.RABBIT_USER}", "RABBIT_SECURE": f"{rabbit_settings.RABBIT_SECURE}", + "REDIS_SETTINGS": json_dumps( + model_dump_with_secrets(redis_settings, show_secrets=True) + ), "DY_DEPLOYMENT_REGISTRY_SETTINGS": ( json_dumps( model_dump_with_secrets( diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/module_setup.py b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/module_setup.py index d1fd90644af..253f9be601d 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/module_setup.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/module_setup.py @@ -1,12 +1,22 @@ from fastapi import FastAPI from servicelib.fastapi import long_running_tasks +from servicelib.long_running_tasks.task import RedisNamespace +from ...core.settings import AppSettings from . 
import api_client, scheduler +_LONG_RUNNING_TASKS_NAMESPACE: RedisNamespace = "director-v2" + def setup(app: FastAPI) -> None: + settings: AppSettings = app.state.settings + long_running_tasks.client.setup(app) - long_running_tasks.server.setup(app) + long_running_tasks.server.setup( + app, + redis_settings=settings.REDIS, + redis_namespace=_LONG_RUNNING_TASKS_NAMESPACE, + ) async def on_startup() -> None: await api_client.setup(app) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/scheduler/_core/_events_utils.py b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/scheduler/_core/_events_utils.py index 2caa9f7a237..e5083d5895c 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/scheduler/_core/_events_utils.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/scheduler/_core/_events_utils.py @@ -217,7 +217,7 @@ async def service_remove_sidecar_proxy_docker_networks_and_volumes( if set_were_state_and_outputs_saved is not None: scheduler_data.dynamic_sidecar.were_state_and_outputs_saved = True - task_progress.update( + await task_progress.update( message="removing dynamic sidecar stack", percent=ProgressPercent(0.1) ) await remove_dynamic_sidecar_stack( @@ -232,7 +232,7 @@ async def service_remove_sidecar_proxy_docker_networks_and_volumes( node_id=scheduler_data.node_uuid, ) - task_progress.update(message="removing network", percent=ProgressPercent(0.2)) + await task_progress.update(message="removing network", percent=ProgressPercent(0.2)) await remove_dynamic_sidecar_network(scheduler_data.dynamic_sidecar_network_name) if scheduler_data.dynamic_sidecar.were_state_and_outputs_saved: @@ -243,7 +243,7 @@ async def service_remove_sidecar_proxy_docker_networks_and_volumes( ) else: # Remove all dy-sidecar associated volumes from node - task_progress.update( + await task_progress.update( message="removing volumes", percent=ProgressPercent(0.3) ) with log_context(_logger, logging.DEBUG, f"removing volumes '{node_uuid}'"): @@ -265,7 +265,7 @@ async def service_remove_sidecar_proxy_docker_networks_and_volumes( scheduler_data.service_name, ) - task_progress.update( + await task_progress.update( message="removing project networks", percent=ProgressPercent(0.8) ) used_projects_networks = await get_projects_networks_containers( @@ -284,7 +284,7 @@ async def service_remove_sidecar_proxy_docker_networks_and_volumes( await app.state.dynamic_sidecar_scheduler.scheduler.remove_service_from_observation( scheduler_data.node_uuid ) - task_progress.update( + await task_progress.update( message="finished removing resources", percent=ProgressPercent(1) ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/redis.py b/services/director-v2/src/simcore_service_director_v2/modules/redis.py index 5928cc78e97..9e02e403ab6 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/redis.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/redis.py @@ -14,11 +14,7 @@ async def on_startup() -> None: app.state.redis_clients_manager = redis_clients_manager = RedisClientsManager( databases_configs={ - RedisManagerDBConfig(database=db) - for db in ( - RedisDatabase.LOCKS, - RedisDatabase.DISTRIBUTED_IDENTIFIERS, - ) + RedisManagerDBConfig(database=db) for db in (RedisDatabase.LOCKS,) }, settings=settings.REDIS, client_name=APP_NAME, diff --git a/services/director-v2/src/simcore_service_director_v2/utils/base_distributed_identifier.py 
b/services/director-v2/src/simcore_service_director_v2/utils/base_distributed_identifier.py deleted file mode 100644 index 25d5dca72f3..00000000000 --- a/services/director-v2/src/simcore_service_director_v2/utils/base_distributed_identifier.py +++ /dev/null @@ -1,286 +0,0 @@ -import logging -from abc import ABC, abstractmethod -from asyncio import Task -from datetime import timedelta -from typing import Final, Generic, TypeVar - -from common_library.async_tools import cancel_wait_task -from pydantic import NonNegativeInt -from servicelib.background_task import create_periodic_task -from servicelib.logging_utils import log_catch, log_context -from servicelib.redis import RedisClientSDK -from servicelib.utils import logged_gather -from settings_library.redis import RedisDatabase - -_logger = logging.getLogger(__name__) - -_REDIS_MAX_CONCURRENCY: Final[NonNegativeInt] = 10 -_DEFAULT_CLEANUP_INTERVAL: Final[timedelta] = timedelta(minutes=1) - -Identifier = TypeVar("Identifier") -ResourceObject = TypeVar("ResourceObject") -CleanupContext = TypeVar("CleanupContext") - - -class BaseDistributedIdentifierManager( - ABC, Generic[Identifier, ResourceObject, CleanupContext] -): - """Used to implement managers for resources that require book keeping - in a distributed system. - - NOTE: that ``Identifier`` and ``CleanupContext`` are serialized and deserialized - to and from Redis. - - Generics: - Identifier -- a user defined object: used to uniquely identify the resource - ResourceObject -- a user defined object: referring to an existing resource - CleanupContext -- a user defined object: contains all necessary - arguments used for removal and cleanup. - """ - - def __init__( - self, - redis_client_sdk: RedisClientSDK, - *, - cleanup_interval: timedelta = _DEFAULT_CLEANUP_INTERVAL, - ) -> None: - """ - Arguments: - redis_client_sdk -- client connecting to Redis - - Keyword Arguments: - cleanup_interval -- interval at which cleanup for unused - resources runs (default: {_DEFAULT_CLEANUP_INTERVAL}) - """ - - if not redis_client_sdk.redis_dsn.endswith( - f"{RedisDatabase.DISTRIBUTED_IDENTIFIERS}" - ): - msg = ( - f"Redis endpoint {redis_client_sdk.redis_dsn} contains the wrong database." 
- f"Expected {RedisDatabase.DISTRIBUTED_IDENTIFIERS}" - ) - raise TypeError(msg) - - self._redis_client_sdk = redis_client_sdk - self.cleanup_interval = cleanup_interval - - self._cleanup_task: Task | None = None - - async def setup(self) -> None: - self._cleanup_task = create_periodic_task( - self._cleanup_unused_identifiers, - interval=self.cleanup_interval, - task_name="cleanup_unused_identifiers_task", - ) - - async def shutdown(self) -> None: - if self._cleanup_task: - await cancel_wait_task(self._cleanup_task, max_delay=5) - - @classmethod - def class_path(cls) -> str: - return f"{cls.__module__}.{cls.__name__}" - - @classmethod - def _redis_key_prefix(cls) -> str: - return f"{cls.class_path()}:" - - @classmethod - def _to_redis_key(cls, identifier: Identifier) -> str: - return f"{cls._redis_key_prefix()}{cls._serialize_identifier(identifier)}" - - @classmethod - def _from_redis_key(cls, redis_key: str) -> Identifier: - sad = redis_key.removeprefix(cls._redis_key_prefix()) - return cls._deserialize_identifier(sad) - - async def _get_identifier_context( - self, identifier: Identifier - ) -> CleanupContext | None: - raw: str | None = await self._redis_client_sdk.redis.get( - self._to_redis_key(identifier) - ) - return self._deserialize_cleanup_context(raw) if raw else None - - async def _get_tracked(self) -> dict[Identifier, CleanupContext]: - identifiers: list[Identifier] = [ - self._from_redis_key(redis_key) - for redis_key in await self._redis_client_sdk.redis.keys( - f"{self._redis_key_prefix()}*" - ) - ] - - cleanup_contexts: list[CleanupContext | None] = await logged_gather( - *(self._get_identifier_context(identifier) for identifier in identifiers), - max_concurrency=_REDIS_MAX_CONCURRENCY, - ) - - return { - identifier: cleanup_context - for identifier, cleanup_context in zip( - identifiers, cleanup_contexts, strict=True - ) - # NOTE: cleanup_context will be None if the key was removed before - # recovering all the cleanup_contexts - if cleanup_context is not None - } - - async def _cleanup_unused_identifiers(self) -> None: - # removes no longer used identifiers - tracked_data: dict[Identifier, CleanupContext] = await self._get_tracked() - _logger.info("Will remove unused %s", list(tracked_data.keys())) - - for identifier, cleanup_context in tracked_data.items(): - if await self.is_used(identifier, cleanup_context): - continue - - await self.remove(identifier) - - async def create( - self, *, cleanup_context: CleanupContext, **extra_kwargs - ) -> tuple[Identifier, ResourceObject]: - """Used for creating the resources - - Arguments: - cleanup_context -- user defined CleanupContext object - **extra_kwargs -- can be overloaded by the user - - Returns: - tuple[identifier for the resource, resource object] - """ - identifier, result = await self._create(**extra_kwargs) - await self._redis_client_sdk.redis.set( - self._to_redis_key(identifier), - self._serialize_cleanup_context(cleanup_context), - ) - return identifier, result - - async def remove(self, identifier: Identifier, *, reraise: bool = False) -> None: - """Attempts to remove the resource, if an error occurs it is logged. 
- - Arguments: - identifier -- user chosen identifier for the resource - reraise -- when True raises any exception raised by ``destroy`` (default: {False}) - """ - - cleanup_context = await self._get_identifier_context(identifier) - if cleanup_context is None: - _logger.warning( - "Something went wrong, did not find any context for %s", identifier - ) - return - - with ( - log_context( - _logger, logging.DEBUG, f"{self.__class__}: removing {identifier}" - ), - log_catch(_logger, reraise=reraise), - ): - await self._destroy(identifier, cleanup_context) - - await self._redis_client_sdk.redis.delete(self._to_redis_key(identifier)) - - @classmethod - @abstractmethod - def _deserialize_identifier(cls, raw: str) -> Identifier: - """User provided deserialization for the identifier - - Arguments: - raw -- stream to be deserialized - - Returns: - an identifier object - """ - - @classmethod - @abstractmethod - def _serialize_identifier(cls, identifier: Identifier) -> str: - """User provided serialization for the identifier - - Arguments: - cleanup_context -- user defined identifier object - - Returns: - object encoded as string - """ - - @classmethod - @abstractmethod - def _deserialize_cleanup_context(cls, raw: str) -> CleanupContext: - """User provided deserialization for the context - - Arguments: - raw -- stream to be deserialized - - Returns: - an object of the type chosen by the user - """ - - @classmethod - @abstractmethod - def _serialize_cleanup_context(cls, cleanup_context: CleanupContext) -> str: - """User provided serialization for the context - - Arguments: - cleanup_context -- user defined cleanup context object - - Returns: - object encoded as string - """ - - @abstractmethod - async def is_used( - self, identifier: Identifier, cleanup_context: CleanupContext - ) -> bool: - """Check if the resource associated to the ``identifier`` is - still being used. - # NOTE: a resource can be created but not in use. - - Arguments: - identifier -- user chosen identifier for the resource - cleanup_context -- user defined CleanupContext object - - Returns: - True if ``identifier`` is still being used - """ - - @abstractmethod - async def _create(self, **extra_kwargs) -> tuple[Identifier, ResourceObject]: - """Used INTERNALLY for creating the resources. - # NOTE: should not be used directly, use the public - version ``create`` instead. - - Arguments: - **extra_kwargs -- can be overloaded by the user - - Returns: - tuple[identifier for the resource, resource object] - """ - - @abstractmethod - async def get( - self, identifier: Identifier, **extra_kwargs - ) -> ResourceObject | None: - """If exists, returns the resource. - - Arguments: - identifier -- user chosen identifier for the resource - **extra_kwargs -- can be overloaded by the user - - Returns: - None if the resource does not exit - """ - - @abstractmethod - async def _destroy( - self, identifier: Identifier, cleanup_context: CleanupContext - ) -> None: - """Used to destroy an existing resource - # NOTE: should not be used directly, use the public - version ``remove`` instead. 
- - Arguments: - identifier -- user chosen identifier for the resource - cleanup_context -- user defined CleanupContext object - """ diff --git a/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py b/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py index 37f2b7b4965..aa0eea7c7d7 100644 --- a/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py +++ b/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py @@ -15,6 +15,7 @@ from models_library.service_settings_labels import SimcoreServiceLabels from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.typing_env import EnvVarsDict +from settings_library.redis import RedisSettings from simcore_service_director_v2.models.dynamic_services_scheduler import SchedulerData from simcore_service_director_v2.modules.dynamic_sidecar.errors import ( DynamicSidecarNotFoundError, @@ -27,6 +28,7 @@ @pytest.fixture def mock_env( + use_in_memory_redis: RedisSettings, mock_exclusive: None, disable_rabbitmq: None, disable_postgres: None, diff --git a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_docker_service_specs_sidecar.py b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_docker_service_specs_sidecar.py index 27cdb831914..01af42b5a6c 100644 --- a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_docker_service_specs_sidecar.py +++ b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_docker_service_specs_sidecar.py @@ -55,6 +55,7 @@ "RABBIT_PORT", "RABBIT_SECURE", "RABBIT_USER", + "REDIS_SETTINGS", "S3_ACCESS_KEY", "S3_BUCKET_NAME", "S3_ENDPOINT", diff --git a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler.py b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler.py index 77c1e033ef6..07deb1aeb8e 100644 --- a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler.py +++ b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler.py @@ -23,6 +23,7 @@ from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.typing_env import EnvVarsDict from respx.router import MockRouter +from settings_library.redis import RedisSettings from simcore_service_director_v2.models.dynamic_services_scheduler import ( DockerContainerInspect, DynamicSidecarStatus, @@ -124,6 +125,7 @@ async def _assert_get_dynamic_services_mocked( @pytest.fixture def mock_env( + use_in_memory_redis: RedisSettings, mock_exclusive: None, disable_postgres: None, disable_rabbitmq: None, diff --git a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler_task.py b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler_task.py index fd328bd66aa..e8ed258bbea 100644 --- a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler_task.py +++ b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler_task.py @@ -20,6 +20,7 @@ from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.typing_env import EnvVarsDict from respx.router import MockRouter +from settings_library.redis import RedisSettings from simcore_service_director_v2.models.dynamic_services_scheduler import SchedulerData from simcore_service_director_v2.modules.dynamic_sidecar.api_client._public import ( SidecarsClient, @@ -43,6 +44,7 @@ @pytest.fixture def mock_env( + use_in_memory_redis: RedisSettings, disable_postgres: None, disable_rabbitmq: None, mock_env: EnvVarsDict, diff --git a/services/director-v2/tests/unit/test_modules_notifier.py 
b/services/director-v2/tests/unit/test_modules_notifier.py index 357edc68af8..f30091676c5 100644 --- a/services/director-v2/tests/unit/test_modules_notifier.py +++ b/services/director-v2/tests/unit/test_modules_notifier.py @@ -23,6 +23,7 @@ from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict from servicelib.utils import logged_gather from settings_library.rabbit import RabbitSettings +from settings_library.redis import RedisSettings from simcore_service_director_v2.core.settings import AppSettings from simcore_service_director_v2.modules.notifier import ( publish_shutdown_no_more_credits, @@ -34,6 +35,7 @@ pytest_simcore_core_services_selection = [ "rabbit", + "redis", ] @@ -127,6 +129,7 @@ async def _assert_call_count(mock: AsyncMock, *, call_count: int) -> None: async def test_notifier_publish_message( + redis_service: RedisSettings, socketio_server_events: dict[str, AsyncMock], initialized_app: FastAPI, user_id: UserID, diff --git a/services/director-v2/tests/unit/test_utils_distributed_identifier.py b/services/director-v2/tests/unit/test_utils_distributed_identifier.py deleted file mode 100644 index c7ad46b74a9..00000000000 --- a/services/director-v2/tests/unit/test_utils_distributed_identifier.py +++ /dev/null @@ -1,359 +0,0 @@ -# pylint:disable=protected-access -# pylint:disable=redefined-outer-name - -import asyncio -import string -from collections.abc import AsyncIterable, AsyncIterator -from dataclasses import dataclass -from secrets import choice -from typing import Final -from uuid import UUID, uuid4 - -import pytest -from pydantic import BaseModel, NonNegativeInt -from pytest_mock import MockerFixture -from servicelib.redis import RedisClientSDK -from servicelib.utils import logged_gather -from settings_library.redis import RedisDatabase, RedisSettings -from simcore_service_director_v2.utils.base_distributed_identifier import ( - BaseDistributedIdentifierManager, -) - -pytest_simcore_core_services_selection = [ - "redis", -] - -pytest_simcore_ops_services_selection = [ - # "redis-commander", -] - -# if this goes too high, max open file limit is reached -_MAX_REDIS_CONCURRENCY: Final[NonNegativeInt] = 1000 - - -class UserDefinedID: - # define a custom type of ID for the API - # by choice it is hard to serialize/deserialize - - def __init__(self, uuid: UUID | None = None) -> None: - self._id = uuid if uuid else uuid4() - - def __eq__(self, other: "UserDefinedID") -> bool: - return self._id == other._id - - # only necessary for nice looking IDs in the logs - def __repr__(self) -> str: - return f"" - - # only necessary for RandomTextAPI - def __hash__(self): - return hash(str(self._id)) - - -class RandomTextEntry(BaseModel): - text: str - - @classmethod - def create(cls, length: int) -> "RandomTextEntry": - letters_and_digits = string.ascii_letters + string.digits - text = "".join(choice(letters_and_digits) for _ in range(length)) - return cls(text=text) - - -class RandomTextAPI: - # Emulates an external API - # used to create resources - - def __init__(self) -> None: - self._created: dict[UserDefinedID, RandomTextEntry] = {} - - def create(self, length: int) -> tuple[UserDefinedID, RandomTextEntry]: - identifier = UserDefinedID(uuid4()) - self._created[identifier] = RandomTextEntry.create(length) - return identifier, self._created[identifier] - - def delete(self, identifier: UserDefinedID) -> None: - del self._created[identifier] - - def get(self, identifier: UserDefinedID) -> RandomTextEntry | None: - return self._created.get(identifier, None) - - 
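Note: with this distributed-identifier test module deleted, the Redis-backed unit tests in this diff isolate themselves through the use_in_memory_redis fixture rather than a running Redis service. A minimal sketch of the opt-in pattern, assuming the fixture is exposed via pytest-simcore and that mock_env is the module's environment fixture (as in the director-v2 tests above); the SC_BOOT_MODE line is a hypothetical placeholder:

    import pytest
    from settings_library.redis import RedisSettings


    @pytest.fixture
    def mock_env(
        # requesting this first substitutes an in-memory Redis before any
        # dependent fixture initializes the application under test
        use_in_memory_redis: RedisSettings,
        monkeypatch: pytest.MonkeyPatch,
    ) -> RedisSettings:
        # hypothetical: extra environment variables for the service under test
        monkeypatch.setenv("SC_BOOT_MODE", "local-development")
        return use_in_memory_redis
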
-@dataclass -class ComponentUsingRandomText: - # Emulates another component in the system - # using the created resources - - _in_use: bool = True - - def is_used(self, an_id: UserDefinedID) -> bool: - _ = an_id - return self._in_use - - def toggle_usage(self, in_use: bool) -> None: - self._in_use = in_use - - -class AnEmptyTextCleanupContext(BaseModel): - # nothing is required during cleanup, so the context - # is an empty object. - # A ``pydantic.BaseModel`` is used for convenience - # this could have inherited from ``object`` - ... - - -class RandomTextResourcesManager( - BaseDistributedIdentifierManager[ - UserDefinedID, RandomTextEntry, AnEmptyTextCleanupContext - ] -): - # Implements a resource manager for handling the lifecycle of - # resources created by a service. - # It also comes in with automatic cleanup in case the service owing - # the resources failed to removed them in the past. - - def __init__( - self, - redis_client_sdk: RedisClientSDK, - component_using_random_text: ComponentUsingRandomText, - ) -> None: - # THESE two systems would normally come stored in the `app` context - self.api = RandomTextAPI() - self.component_using_random_text = component_using_random_text - - super().__init__(redis_client_sdk) - - @classmethod - def _deserialize_identifier(cls, raw: str) -> UserDefinedID: - return UserDefinedID(UUID(raw)) - - @classmethod - def _serialize_identifier(cls, identifier: UserDefinedID) -> str: - return f"{identifier._id}" # noqa: SLF001 - - @classmethod - def _deserialize_cleanup_context( - cls, raw: str | bytes - ) -> AnEmptyTextCleanupContext: - return AnEmptyTextCleanupContext.model_validate_json(raw) - - @classmethod - def _serialize_cleanup_context( - cls, cleanup_context: AnEmptyTextCleanupContext - ) -> str: - return cleanup_context.model_dump_json() - - async def is_used( - self, identifier: UserDefinedID, cleanup_context: AnEmptyTextCleanupContext - ) -> bool: - _ = cleanup_context - return self.component_using_random_text.is_used(identifier) - - # NOTE: it is intended for the user to overwrite the **kwargs with custom names - # to provide a cleaner interface, tooling will complain slightly - async def _create( # pylint:disable=arguments-differ # type:ignore [override] - self, length: int - ) -> tuple[UserDefinedID, RandomTextEntry]: - return self.api.create(length) - - async def get(self, identifier: UserDefinedID, **_) -> RandomTextEntry | None: - return self.api.get(identifier) - - async def _destroy( - self, identifier: UserDefinedID, _: AnEmptyTextCleanupContext - ) -> None: - self.api.delete(identifier) - - -@pytest.fixture -async def redis_client_sdk( - redis_service: RedisSettings, -) -> AsyncIterator[RedisClientSDK]: - redis_resources_dns = redis_service.build_redis_dsn( - RedisDatabase.DISTRIBUTED_IDENTIFIERS - ) - - client = RedisClientSDK(redis_resources_dns, client_name="pytest") - assert client - assert client.redis_dsn == redis_resources_dns - # cleanup, previous run's leftovers - await client.redis.flushall() - - yield client - # cleanup, properly close the clients - await client.redis.flushall() - await client.shutdown() - - -@pytest.fixture -def component_using_random_text() -> ComponentUsingRandomText: - return ComponentUsingRandomText() - - -@pytest.fixture -async def manager_with_no_cleanup_task( - redis_client_sdk: RedisClientSDK, - component_using_random_text: ComponentUsingRandomText, -) -> RandomTextResourcesManager: - return RandomTextResourcesManager(redis_client_sdk, component_using_random_text) - - -@pytest.fixture -async def 
manager( - manager_with_no_cleanup_task: RandomTextResourcesManager, -) -> AsyncIterable[RandomTextResourcesManager]: - await manager_with_no_cleanup_task.setup() - yield manager_with_no_cleanup_task - await manager_with_no_cleanup_task.shutdown() - - -async def test_resource_is_missing(manager: RandomTextResourcesManager): - missing_identifier = UserDefinedID() - assert await manager.get(missing_identifier) is None - - -@pytest.mark.parametrize("delete_before_removal", [True, False]) -async def test_full_workflow( - manager: RandomTextResourcesManager, delete_before_removal: bool -): - # creation - identifier, _ = await manager.create( - cleanup_context=AnEmptyTextCleanupContext(), length=1 - ) - assert await manager.get(identifier) is not None - - # optional removal - if delete_before_removal: - await manager.remove(identifier) - - is_still_present = not delete_before_removal - assert (await manager.get(identifier) is not None) is is_still_present - - # safe remove the resource - await manager.remove(identifier) - - # resource no longer exists - assert await manager.get(identifier) is None - - -@pytest.mark.parametrize("reraise", [True, False]) -async def test_remove_raises_error( - mocker: MockerFixture, - manager: RandomTextResourcesManager, - caplog: pytest.LogCaptureFixture, - reraise: bool, -): - caplog.clear() - - error_message = "mock error during resource destroy" - mocker.patch.object(manager, "_destroy", side_effect=RuntimeError(error_message)) - - # after creation object is present - identifier, _ = await manager.create( - cleanup_context=AnEmptyTextCleanupContext(), length=1 - ) - assert await manager.get(identifier) is not None - - if reraise: - with pytest.raises(RuntimeError): - await manager.remove(identifier, reraise=reraise) - else: - await manager.remove(identifier, reraise=reraise) - # check logs in case of error - assert "Unhandled exception:" in caplog.text - assert error_message in caplog.text - - -async def _create_resources( - manager: RandomTextResourcesManager, count: int -) -> list[UserDefinedID]: - creation_results: list[tuple[UserDefinedID, RandomTextEntry]] = await logged_gather( - *[ - manager.create(cleanup_context=AnEmptyTextCleanupContext(), length=1) - for _ in range(count) - ], - max_concurrency=_MAX_REDIS_CONCURRENCY, - ) - return [x[0] for x in creation_results] - - -async def _assert_all_resources( - manager: RandomTextResourcesManager, - identifiers: list[UserDefinedID], - *, - exist: bool, -) -> None: - get_results: list[RandomTextEntry | None] = await logged_gather( - *[manager.get(identifier) for identifier in identifiers], - max_concurrency=_MAX_REDIS_CONCURRENCY, - ) - if exist: - assert all(x is not None for x in get_results) - else: - assert all(x is None for x in get_results) - - -@pytest.mark.parametrize("count", [1000]) -async def test_parallel_create_remove(manager: RandomTextResourcesManager, count: int): - # create resources - identifiers: list[UserDefinedID] = await _create_resources(manager, count) - await _assert_all_resources(manager, identifiers, exist=True) - - # safe remove the resources, they do not exist any longer - await asyncio.gather(*[manager.remove(identifier) for identifier in identifiers]) - await _assert_all_resources(manager, identifiers, exist=False) - - -async def test_background_removal_of_unused_resources( - manager_with_no_cleanup_task: RandomTextResourcesManager, - component_using_random_text: ComponentUsingRandomText, -): - # create resources - identifiers: list[UserDefinedID] = await _create_resources( - 
manager_with_no_cleanup_task, 10_000 - ) - await _assert_all_resources(manager_with_no_cleanup_task, identifiers, exist=True) - - # call cleanup, all resources still exist - await manager_with_no_cleanup_task._cleanup_unused_identifiers() # noqa: SLF001 - await _assert_all_resources(manager_with_no_cleanup_task, identifiers, exist=True) - - # make resources unused in external system - component_using_random_text.toggle_usage(in_use=False) - await manager_with_no_cleanup_task._cleanup_unused_identifiers() # noqa: SLF001 - await _assert_all_resources(manager_with_no_cleanup_task, identifiers, exist=False) - - -async def test_no_redis_key_overlap_when_inheriting( - redis_client_sdk: RedisClientSDK, - component_using_random_text: ComponentUsingRandomText, -): - class ChildRandomTextResourcesManager(RandomTextResourcesManager): - ... - - parent_manager = RandomTextResourcesManager( - redis_client_sdk, component_using_random_text - ) - child_manager = ChildRandomTextResourcesManager( - redis_client_sdk, component_using_random_text - ) - - # create an entry in the child and one in the parent - - parent_identifier, _ = await parent_manager.create( - cleanup_context=AnEmptyTextCleanupContext(), length=1 - ) - child_identifier, _ = await child_manager.create( - cleanup_context=AnEmptyTextCleanupContext(), length=1 - ) - assert parent_identifier != child_identifier - - keys = await redis_client_sdk.redis.keys("*") - assert len(keys) == 2 - - # check keys contain the correct prefixes - key_prefixes: set[str] = {k.split(":")[0] for k in keys} - assert key_prefixes == { - RandomTextResourcesManager.class_path(), - ChildRandomTextResourcesManager.class_path(), - } diff --git a/services/director-v2/tests/unit/with_dbs/test_api_route_dynamic_services.py b/services/director-v2/tests/unit/with_dbs/test_api_route_dynamic_services.py index fd1d43e25aa..ba858c5940c 100644 --- a/services/director-v2/tests/unit/with_dbs/test_api_route_dynamic_services.py +++ b/services/director-v2/tests/unit/with_dbs/test_api_route_dynamic_services.py @@ -40,6 +40,7 @@ X_DYNAMIC_SIDECAR_REQUEST_SCHEME, X_SIMCORE_USER_AGENT, ) +from settings_library.redis import RedisSettings from simcore_service_director_v2.models.dynamic_services_scheduler import SchedulerData from simcore_service_director_v2.modules.dynamic_sidecar.errors import ( DynamicSidecarNotFoundError, @@ -52,6 +53,7 @@ pytest_simcore_core_services_selection = [ "postgres", + "redis", ] pytest_simcore_ops_services_selection = [ "adminer", @@ -99,6 +101,7 @@ def mock_env( mock_exclusive: None, disable_postgres: None, disable_rabbitmq: None, + redis_service: RedisSettings, monkeypatch: pytest.MonkeyPatch, faker: Faker, ) -> None: diff --git a/services/director-v2/tests/unit/with_dbs/test_cli.py b/services/director-v2/tests/unit/with_dbs/test_cli.py index 61f12ceaf15..cce3c67968d 100644 --- a/services/director-v2/tests/unit/with_dbs/test_cli.py +++ b/services/director-v2/tests/unit/with_dbs/test_cli.py @@ -23,6 +23,7 @@ from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.long_running_tasks.models import ProgressCallback +from settings_library.redis import RedisSettings from simcore_service_director_v2.cli import DEFAULT_NODE_SAVE_ATTEMPTS, main from simcore_service_director_v2.cli._close_and_save_service import ( ThinDV2LocalhostClient, @@ -31,6 +32,7 @@ pytest_simcore_core_services_selection = [ "postgres", + "redis", ] pytest_simcore_ops_services_selection = [ "adminer", @@ -41,6 +43,7 @@ def 
minimal_configuration( mock_env: EnvVarsDict, postgres_host_config: dict[str, str], + redis_service: RedisSettings, monkeypatch: pytest.MonkeyPatch, faker: Faker, with_product: dict[str, Any], diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_dynamic_sidecar_docker_service_specs.py b/services/director-v2/tests/unit/with_dbs/test_modules_dynamic_sidecar_docker_service_specs.py index 4618a9a9ba0..0bd3ccc48a0 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_dynamic_sidecar_docker_service_specs.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_dynamic_sidecar_docker_service_specs.py @@ -82,6 +82,7 @@ def mock_env( "RABBIT_PORT": "5672", "RABBIT_USER": "admin", "RABBIT_SECURE": "false", + "REDIS_SETTINGS": '{"REDIS_SECURE":false,"REDIS_HOST":"redis","REDIS_PORT":6789,"REDIS_USER":null,"REDIS_PASSWORD":null}', "REGISTRY_AUTH": "false", "REGISTRY_PW": "test", "REGISTRY_SSL": "false", @@ -280,6 +281,7 @@ def expected_dynamic_sidecar_spec( "RABBIT_PORT": "5672", "RABBIT_USER": "admin", "RABBIT_SECURE": "False", + "REDIS_SETTINGS": '{"REDIS_SECURE":false,"REDIS_HOST":"redis","REDIS_PORT":6789,"REDIS_USER":null,"REDIS_PASSWORD":null}', "R_CLONE_OPTION_BUFFER_SIZE": "16M", "R_CLONE_OPTION_RETRIES": "3", "R_CLONE_OPTION_TRANSFERS": "5", diff --git a/services/docker-compose-ops.yml b/services/docker-compose-ops.yml index dd15dd1ecbd..4b4e94ff453 100644 --- a/services/docker-compose-ops.yml +++ b/services/docker-compose-ops.yml @@ -92,7 +92,7 @@ services: scheduled_maintenance:${REDIS_HOST}:${REDIS_PORT}:3:${REDIS_PASSWORD}, user_notifications:${REDIS_HOST}:${REDIS_PORT}:4:${REDIS_PASSWORD}, announcements:${REDIS_HOST}:${REDIS_PORT}:5:${REDIS_PASSWORD}, - distributed_identifiers:${REDIS_HOST}:${REDIS_PORT}:6:${REDIS_PASSWORD}, + long_running_tasks:${REDIS_HOST}:${REDIS_PORT}:6:${REDIS_PASSWORD}, deferred_tasks:${REDIS_HOST}:${REDIS_PORT}:7:${REDIS_PASSWORD}, dynamic_services:${REDIS_HOST}:${REDIS_PORT}:8:${REDIS_PASSWORD}, celery_tasks:${REDIS_HOST}:${REDIS_PORT}:9:${REDIS_PASSWORD}, diff --git a/services/docker-compose.yml b/services/docker-compose.yml index b5546e697c8..8ed41a82657 100644 --- a/services/docker-compose.yml +++ b/services/docker-compose.yml @@ -1465,6 +1465,7 @@ services: networks: - default - autoscaling_subnet + - interactive_services_subnet volumes: - redis-data:/data healthcheck: diff --git a/services/dynamic-scheduler/requirements/_test.in b/services/dynamic-scheduler/requirements/_test.in index 840e5093b13..9a5d5c342dd 100644 --- a/services/dynamic-scheduler/requirements/_test.in +++ b/services/dynamic-scheduler/requirements/_test.in @@ -15,6 +15,7 @@ asgi_lifespan coverage docker faker +fakeredis[lua] hypercorn playwright pytest diff --git a/services/dynamic-scheduler/requirements/_test.txt b/services/dynamic-scheduler/requirements/_test.txt index 23d1dc27439..5142eac3843 100644 --- a/services/dynamic-scheduler/requirements/_test.txt +++ b/services/dynamic-scheduler/requirements/_test.txt @@ -23,6 +23,8 @@ docker==7.1.0 # via -r requirements/_test.in faker==36.2.2 # via -r requirements/_test.in +fakeredis==2.30.3 + # via -r requirements/_test.in greenlet==3.1.1 # via # -c requirements/_base.txt @@ -67,6 +69,8 @@ idna==3.10 # requests iniconfig==2.0.0 # via pytest +lupa==2.5 + # via fakeredis mypy==1.16.1 # via sqlalchemy mypy-extensions==1.1.0 @@ -118,6 +122,11 @@ python-dotenv==1.0.1 # via # -c requirements/_base.txt # -r requirements/_test.in +redis==5.2.1 + # via + # -c requirements/../../../requirements/constraints.txt + # -c 
requirements/_base.txt + # fakeredis requests==2.32.4 # via # -c requirements/_base.txt @@ -129,6 +138,8 @@ sniffio==1.3.1 # -c requirements/_base.txt # anyio # asgi-lifespan +sortedcontainers==2.4.0 + # via fakeredis sqlalchemy==1.4.54 # via # -c requirements/../../../requirements/constraints.txt diff --git a/services/dynamic-sidecar/requirements/_test.in b/services/dynamic-sidecar/requirements/_test.in index 35f081991de..7203b0a9320 100644 --- a/services/dynamic-sidecar/requirements/_test.in +++ b/services/dynamic-sidecar/requirements/_test.in @@ -8,6 +8,7 @@ asgi_lifespan async-asgi-testclient # replacement for fastapi.testclient.TestClient [see b) below] docker faker +fakeredis[lua] flaky pytest pytest-asyncio diff --git a/services/dynamic-sidecar/requirements/_test.txt b/services/dynamic-sidecar/requirements/_test.txt index 96b8735688b..764d36fb7bf 100644 --- a/services/dynamic-sidecar/requirements/_test.txt +++ b/services/dynamic-sidecar/requirements/_test.txt @@ -51,6 +51,8 @@ docker==7.1.0 # via -r requirements/_test.in faker==36.2.2 # via -r requirements/_test.in +fakeredis==2.30.3 + # via -r requirements/_test.in flaky==3.8.1 # via -r requirements/_test.in frozenlist==1.5.0 @@ -74,6 +76,8 @@ jmespath==1.0.1 # aiobotocore # boto3 # botocore +lupa==2.5 + # via fakeredis multidict==6.1.0 # via # -c requirements/_base.txt @@ -125,6 +129,11 @@ python-dotenv==1.0.1 # via # -c requirements/_base.txt # -r requirements/_test.in +redis==5.2.1 + # via + # -c requirements/../../../requirements/constraints.txt + # -c requirements/_base.txt + # fakeredis requests==2.32.4 # via # -c requirements/_base.txt @@ -140,6 +149,8 @@ sniffio==1.3.1 # via # -c requirements/_base.txt # asgi-lifespan +sortedcontainers==2.4.0 + # via fakeredis sqlalchemy==1.4.54 # via # -c requirements/../../../requirements/constraints.txt diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/cli.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/cli.py index 8b7b9086ef1..378ef7f12ed 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/cli.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/cli.py @@ -11,6 +11,7 @@ from ._meta import PROJECT_NAME from .core.application import create_base_app +from .core.rabbitmq import setup_rabbitmq from .core.settings import ApplicationSettings from .modules.long_running_tasks import task_ports_outputs_push, task_save_state from .modules.mounted_fs import MountedVolumes, setup_mounted_fs @@ -39,6 +40,7 @@ async def _initialized_app() -> AsyncIterator[FastAPI]: app = create_base_app() # setup MountedVolumes + setup_rabbitmq(app) setup_mounted_fs(app) setup_outputs(app) diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/application.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/application.py index dad8f18dd59..46cf61b0ebd 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/application.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/application.py @@ -146,7 +146,11 @@ def create_base_app() -> FastAPI: override_fastapi_openapi_method(app) app.state.settings = app_settings - long_running_tasks.server.setup(app) + long_running_tasks.server.setup( + app, + redis_settings=app_settings.REDIS_SETTINGS, + redis_namespace=f"dy_sidecar-{app_settings.DY_SIDECAR_RUN_ID}", + ) app.include_router(get_main_router(app)) diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/settings.py 
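Aside: both director-v2 and the dynamic sidecar now hand the long-running-tasks server an explicit Redis configuration plus a dedicated namespace. A minimal sketch of the wiring; the helper name and the namespace string are illustrative assumptions, not values from this diff:

    from fastapi import FastAPI
    from servicelib.fastapi import long_running_tasks
    from settings_library.redis import RedisSettings


    def setup_long_running_tasks(app: FastAPI, redis_settings: RedisSettings) -> None:
        long_running_tasks.server.setup(
            app,
            redis_settings=redis_settings,
            # one namespace per service (or per service run) keeps task state
            # from colliding inside the shared LONG_RUNNING_TASKS database
            redis_namespace="example-service",
        )

The sidecar obtains its redis_settings from the REDIS_SETTINGS environment variable that director-v2 now injects into the service spec, serialized with model_dump_with_secrets as shown earlier in this diff.
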
b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/settings.py index 4187f08b02c..ffb972d5fcc 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/settings.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/settings.py @@ -29,6 +29,7 @@ from settings_library.postgres import PostgresSettings from settings_library.r_clone import RCloneSettings from settings_library.rabbit import RabbitSettings +from settings_library.redis import RedisSettings from settings_library.resource_usage_tracker import ( DEFAULT_RESOURCE_USAGE_HEARTBEAT_INTERVAL, ) @@ -187,6 +188,9 @@ class ApplicationSettings(BaseApplicationSettings, MixinLoggingSettings): RABBIT_SETTINGS: RabbitSettings = Field( json_schema_extra={"auto_default_from_env": True} ) + REDIS_SETTINGS: RedisSettings = Field( + json_schema_extra={"auto_default_from_env": True} + ) DY_DEPLOYMENT_REGISTRY_SETTINGS: RegistrySettings = Field() DY_DOCKER_HUB_REGISTRY_SETTINGS: RegistrySettings | None = Field(default=None) diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py index f14f90be1ef..2e326a11258 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py @@ -149,11 +149,11 @@ async def task_pull_user_servcices_docker_images( ) -> None: assert shared_store.compose_spec # nosec - progress.update(message="started pulling user services", percent=0) + await progress.update(message="started pulling user services", percent=0) await docker_compose_pull(app, shared_store.compose_spec) - progress.update(message="finished pulling user services", percent=1) + await progress.update(message="finished pulling user services", percent=1) async def task_create_service_containers( @@ -164,7 +164,7 @@ async def task_create_service_containers( app: FastAPI, application_health: ApplicationHealth, ) -> list[str]: - progress.update(message="validating service spec", percent=0) + await progress.update(message="validating service spec", percent=0) assert shared_store.compose_spec # nosec @@ -185,19 +185,19 @@ async def task_create_service_containers( await progress_bar.update() # removes previous pending containers - progress.update(message="cleanup previous used resources") + await progress.update(message="cleanup previous used resources") result = await docker_compose_rm(shared_store.compose_spec, settings) _raise_for_errors(result, "rm") await progress_bar.update() - progress.update(message="creating and starting containers", percent=0.90) + await progress.update(message="creating and starting containers", percent=0.90) await post_sidecar_log_message( app, "starting service containers", log_level=logging.INFO ) await _retry_docker_compose_create(shared_store.compose_spec, settings) await progress_bar.update() - progress.update(message="ensure containers are started", percent=0.95) + await progress.update(message="ensure containers are started", percent=0.95) compose_start_result = await _retry_docker_compose_start( shared_store.compose_spec, settings ) @@ -280,7 +280,7 @@ async def _send_resource_tracking_stop(platform_status: SimcorePlatformStatus): await send_service_stopped(app, simcore_platform_status) try: - progress.update(message="running docker-compose-down", percent=0.1) + await progress.update(message="running 
docker-compose-down", percent=0.1) await run_before_shutdown_actions( shared_store, settings.DY_SIDECAR_CALLBACKS_MAPPING.before_shutdown @@ -293,11 +293,11 @@ async def _send_resource_tracking_stop(platform_status: SimcorePlatformStatus): result = await _retry_docker_compose_down(shared_store.compose_spec, settings) _raise_for_errors(result, "down") - progress.update(message="stopping logs", percent=0.9) + await progress.update(message="stopping logs", percent=0.9) for container_name in shared_store.container_names: await stop_log_fetching(app, container_name) - progress.update(message="removing pending resources", percent=0.95) + await progress.update(message="removing pending resources", percent=0.95) result = await docker_compose_rm(shared_store.compose_spec, settings) _raise_for_errors(result, "rm") except Exception: @@ -314,7 +314,7 @@ async def _send_resource_tracking_stop(platform_status: SimcorePlatformStatus): async with shared_store: shared_store.compose_spec = None shared_store.container_names = [] - progress.update(message="done", percent=0.99) + await progress.update(message="done", percent=0.99) def _get_satate_folders_size(paths: list[Path]) -> int: @@ -377,7 +377,7 @@ async def task_restore_state( # NOTE: this implies that the legacy format will always be decompressed # until it is not removed. - progress.update(message="Downloading state", percent=0.05) + await progress.update(message="Downloading state", percent=0.05) state_paths = list(mounted_volumes.disk_state_paths_iter()) await post_sidecar_log_message( app, @@ -407,7 +407,7 @@ async def task_restore_state( await post_sidecar_log_message( app, "Finished state downloading", log_level=logging.INFO ) - progress.update(message="state restored", percent=0.99) + await progress.update(message="state restored", percent=0.99) return _get_satate_folders_size(state_paths) @@ -447,7 +447,7 @@ async def task_save_state( If a legacy archive is detected, it will be removed after saving the new format. 
""" - progress.update(message="starting state save", percent=0.0) + await progress.update(message="starting state save", percent=0.0) state_paths = list(mounted_volumes.disk_state_paths_iter()) async with ProgressBarData( num_steps=len(state_paths), @@ -473,7 +473,7 @@ async def task_save_state( ) await post_sidecar_log_message(app, "Finished state saving", log_level=logging.INFO) - progress.update(message="finished state saving", percent=0.99) + await progress.update(message="finished state saving", percent=0.99) return _get_satate_folders_size(state_paths) @@ -491,12 +491,12 @@ async def task_ports_inputs_pull( _logger.info("Received request to pull inputs but was ignored") return 0 - progress.update(message="starting inputs pulling", percent=0.0) + await progress.update(message="starting inputs pulling", percent=0.0) port_keys = [] if port_keys is None else port_keys await post_sidecar_log_message( app, f"Pulling inputs for {port_keys}", log_level=logging.INFO ) - progress.update(message="pulling inputs", percent=0.1) + await progress.update(message="pulling inputs", percent=0.1) async with ProgressBarData( num_steps=1, progress_report_cb=functools.partial( @@ -527,7 +527,7 @@ async def task_ports_inputs_pull( await post_sidecar_log_message( app, "Finished pulling inputs", log_level=logging.INFO ) - progress.update(message="finished inputs pulling", percent=0.99) + await progress.update(message="finished inputs pulling", percent=0.99) return int(transferred_bytes) @@ -537,7 +537,7 @@ async def task_ports_outputs_pull( mounted_volumes: MountedVolumes, app: FastAPI, ) -> int: - progress.update(message="starting outputs pulling", percent=0.0) + await progress.update(message="starting outputs pulling", percent=0.0) port_keys = [] if port_keys is None else port_keys await post_sidecar_log_message( app, f"Pulling output for {port_keys}", log_level=logging.INFO @@ -564,14 +564,14 @@ async def task_ports_outputs_pull( await post_sidecar_log_message( app, "Finished pulling outputs", log_level=logging.INFO ) - progress.update(message="finished outputs pulling", percent=0.99) + await progress.update(message="finished outputs pulling", percent=0.99) return int(transferred_bytes) async def task_ports_outputs_push( progress: TaskProgress, outputs_manager: OutputsManager, app: FastAPI ) -> None: - progress.update(message="starting outputs pushing", percent=0.0) + await progress.update(message="starting outputs pushing", percent=0.0) await post_sidecar_log_message( app, f"waiting for outputs {outputs_manager.outputs_context.file_type_port_keys} to be pushed", @@ -583,7 +583,7 @@ async def task_ports_outputs_push( await post_sidecar_log_message( app, "finished outputs pushing", log_level=logging.INFO ) - progress.update(message="finished outputs pushing", percent=0.99) + await progress.update(message="finished outputs pushing", percent=0.99) async def task_containers_restart( @@ -598,7 +598,7 @@ async def task_containers_restart( # or some other state, the service will get shutdown, to prevent this # blocking status while containers are being restarted. 
async with app.state.container_restart_lock: - progress.update(message="starting containers restart", percent=0.0) + await progress.update(message="starting containers restart", percent=0.0) if shared_store.compose_spec is None: msg = "No spec for docker compose command was found" raise RuntimeError(msg) @@ -606,23 +606,23 @@ async def task_containers_restart( for container_name in shared_store.container_names: await stop_log_fetching(app, container_name) - progress.update(message="stopped log fetching", percent=0.1) + await progress.update(message="stopped log fetching", percent=0.1) result = await docker_compose_restart(shared_store.compose_spec, settings) _raise_for_errors(result, "restart") - progress.update(message="containers restarted", percent=0.8) + await progress.update(message="containers restarted", percent=0.8) for container_name in shared_store.container_names: await start_log_fetching(app, container_name) - progress.update(message="started log fetching", percent=0.9) + await progress.update(message="started log fetching", percent=0.9) await post_sidecar_log_message( app, "Service was restarted please reload the UI", log_level=logging.INFO ) await post_event_reload_iframe(app) - progress.update(message="started log fetching", percent=0.99) + await progress.update(message="started log fetching", percent=0.99) for task in ( diff --git a/services/dynamic-sidecar/tests/conftest.py b/services/dynamic-sidecar/tests/conftest.py index c85d29105a9..1bf41e23c83 100644 --- a/services/dynamic-sidecar/tests/conftest.py +++ b/services/dynamic-sidecar/tests/conftest.py @@ -28,6 +28,7 @@ setenvs_from_dict, setenvs_from_envfile, ) +from settings_library.redis import RedisSettings from simcore_service_dynamic_sidecar.core.reserved_space import ( remove_reserved_disk_space, ) @@ -168,6 +169,7 @@ def mock_rabbit_check(mocker: MockerFixture) -> None: @pytest.fixture def base_mock_envs( + use_in_memory_redis: RedisSettings, dy_volumes: Path, shared_store_dir: Path, compose_namespace: str, @@ -209,6 +211,7 @@ def base_mock_envs( @pytest.fixture def mock_environment( + use_in_memory_redis: RedisSettings, mock_storage_check: None, mock_postgres_check: None, mock_rabbit_check: None, diff --git a/services/dynamic-sidecar/tests/integration/test_modules_long_running_tasks.py b/services/dynamic-sidecar/tests/integration/test_modules_long_running_tasks.py index 221f7f0a10f..0fa9d821ac3 100644 --- a/services/dynamic-sidecar/tests/integration/test_modules_long_running_tasks.py +++ b/services/dynamic-sidecar/tests/integration/test_modules_long_running_tasks.py @@ -31,6 +31,8 @@ from pytest_simcore.helpers.storage import replace_storage_endpoint from servicelib.long_running_tasks.models import TaskProgress from servicelib.utils import logged_gather +from settings_library.rabbit import RabbitSettings +from settings_library.redis import RedisSettings from settings_library.s3 import S3Settings from simcore_postgres_database.models.projects import projects from simcore_sdk.node_ports_common.constants import SIMCORE_LOCATION @@ -89,7 +91,8 @@ def project_id(user_id: int, postgres_db: sa.engine.Engine) -> Iterable[ProjectI def mock_environment( mock_storage_check: None, mock_rabbit_check: None, - rabbit_service, + redis_service: RedisSettings, + rabbit_service: RabbitSettings, postgres_host_config: PostgresTestConfig, storage_endpoint: URL, minio_s3_settings_envs: EnvVarsDict, diff --git a/services/dynamic-sidecar/tests/integration/test_modules_user_services_preferences.py 
diff --git a/services/dynamic-sidecar/tests/integration/test_modules_user_services_preferences.py b/services/dynamic-sidecar/tests/integration/test_modules_user_services_preferences.py
index 9be0bbdebbf..73eb9781bea 100644
--- a/services/dynamic-sidecar/tests/integration/test_modules_user_services_preferences.py
+++ b/services/dynamic-sidecar/tests/integration/test_modules_user_services_preferences.py
@@ -17,6 +17,7 @@
 from pydantic import TypeAdapter
 from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict
 from pytest_simcore.helpers.postgres_tools import PostgresTestConfig
+from settings_library.redis import RedisSettings
 from simcore_service_dynamic_sidecar.core.application import create_app
 from simcore_service_dynamic_sidecar.modules.user_services_preferences import (
     load_user_services_preferences,
@@ -32,6 +33,7 @@
 pytest_simcore_core_services_selection = [
     "migration",
     "postgres",
+    "redis",
 ]

 pytest_simcore_ops_services_selection = [
@@ -62,9 +64,10 @@ def product_name() -> ProductName:

 @pytest.fixture
-def mock_environment(  # pylint:disable=too-many-arguments
+def mock_environment(  # pylint:disable=too-many-arguments,too-many-positional-arguments
     mock_rabbit_check: None,
     mock_storage_check: None,
+    redis_service: RedisSettings,
     postgres_host_config: PostgresTestConfig,
     monkeypatch: pytest.MonkeyPatch,
     base_mock_envs: EnvVarsDict,
diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py b/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py
index 932a381fac6..75180ce4a00 100644
--- a/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py
+++ b/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py
@@ -2,6 +2,7 @@
 # pylint: disable=unused-argument
 # pylint: disable=no-member

+import asyncio
 import json
 from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator
 from contextlib import asynccontextmanager, contextmanager
@@ -45,6 +46,12 @@
 from simcore_service_dynamic_sidecar.modules.inputs import enable_inputs_pulling
 from simcore_service_dynamic_sidecar.modules.outputs._context import OutputsContext
 from simcore_service_dynamic_sidecar.modules.outputs._manager import OutputsManager
+from tenacity import (
+    AsyncRetrying,
+    retry_if_exception_type,
+    stop_after_delay,
+    wait_fixed,
+)

 FAST_STATUS_POLL: Final[float] = 0.1
 CREATE_SERVICE_CONTAINERS_TIMEOUT: Final[float] = 60
@@ -384,6 +391,20 @@ async def _debug_progress(
     print(f"{task_id} {percent} {message}")

+async def _assert_progress_finished(
+    last_progress_message: tuple[ProgressMessage, ProgressPercent] | None,
+) -> None:
+    async for attempt in AsyncRetrying(
+        stop=stop_after_delay(10),
+        wait=wait_fixed(0.1),
+        retry=retry_if_exception_type(AssertionError),
+        reraise=True,
+    ):
+        with attempt:
+            await asyncio.sleep(0)  # yield control to the event loop
+            assert last_progress_message == ("finished", 1.0)
+

 async def test_create_containers_task(
     httpx_async_client: AsyncClient,
     client: Client,
@@ -392,10 +413,13 @@ async def test_create_containers_task(
     mock_metrics_params: CreateServiceMetricsAdditionalParams,
     shared_store: SharedStore,
 ) -> None:
-    last_progress_message: tuple[str, float] | None = None
+    last_progress_message: tuple[ProgressMessage, ProgressPercent] | None = None

-    async def create_progress(message: str, percent: float, _: TaskId) -> None:
+    async def create_progress(
+        message: ProgressMessage, percent: ProgressPercent | None, _: TaskId
+    ) -> None:
         nonlocal last_progress_message
+        assert percent is not None
         last_progress_message = (message, percent)
         print(message, percent)
@@ -410,7 +434,7 @@ async def create_progress(message: str, percent: float, _: TaskId) -> None:
     ) as result:
         assert shared_store.container_names == result

-    assert last_progress_message == ("finished", 1.0)
+    await _assert_progress_finished(last_progress_message)


 async def test_pull_user_servcices_docker_images(
@@ -442,7 +466,7 @@ async def create_progress(
     ) as result:
         assert shared_store.container_names == result

-    assert last_progress_message == ("finished", 1.0)
+    await _assert_progress_finished(last_progress_message)

     async with periodic_task_result(
         client=client,
@@ -454,7 +478,7 @@ async def create_progress(
         progress_callback=_debug_progress,
     ) as result:
         assert result is None
-    assert last_progress_message == ("finished", 1.0)
+    await _assert_progress_finished(last_progress_message)


 async def test_create_containers_task_invalid_yaml_spec(
diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py b/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py
index 62755395c99..145fd791fd3 100644
--- a/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py
+++ b/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py
@@ -82,7 +82,9 @@ def backend_url() -> AnyHttpUrl:

 @pytest.fixture
 def mock_environment(
-    monkeypatch: pytest.MonkeyPatch, mock_rabbitmq_envs: EnvVarsDict
+    mock_postgres_check: None,
+    monkeypatch: pytest.MonkeyPatch,
+    mock_rabbitmq_envs: EnvVarsDict,
 ) -> EnvVarsDict:
     setenvs_from_dict(
         monkeypatch,
diff --git a/services/dynamic-sidecar/tests/unit/test_cli.py b/services/dynamic-sidecar/tests/unit/test_cli.py
index a06fc698efc..855c21497d5 100644
--- a/services/dynamic-sidecar/tests/unit/test_cli.py
+++ b/services/dynamic-sidecar/tests/unit/test_cli.py
@@ -1,20 +1,41 @@
 # pylint: disable=unused-argument
 # pylint: disable=redefined-outer-name
-
+import json
 import os
 import traceback
+from pprint import pprint

 import pytest
 from click.testing import Result
+from common_library.serialization import model_dump_with_secrets
 from pytest_mock.plugin import MockerFixture
 from pytest_simcore.helpers.typing_env import EnvVarsDict
+from settings_library.rabbit import RabbitSettings
+from settings_library.redis import RedisSettings
 from simcore_service_dynamic_sidecar.cli import main
 from typer.testing import CliRunner

+pytest_simcore_core_services_selection = [
+    "redis",
+    "rabbit",
+]
+

 @pytest.fixture
-def cli_runner(mock_environment: EnvVarsDict) -> CliRunner:
-    return CliRunner()
+def cli_runner(
+    rabbit_service: RabbitSettings,
+    redis_service: RedisSettings,
+    mock_environment: EnvVarsDict,
+) -> CliRunner:
+    mock_environment["REDIS_SETTINGS"] = json.dumps(
+        model_dump_with_secrets(redis_service, show_secrets=True)
+    )
+    mock_environment["RABBIT_SETTINGS"] = json.dumps(
+        model_dump_with_secrets(rabbit_service, show_secrets=True)
+    )
+
+    pprint(mock_environment)
+    return CliRunner(env=mock_environment)


 @pytest.fixture
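Editor's note: the cli_runner fixture above now feeds the live service
settings to Typer through environment variables. A minimal sketch of the same
pattern outside pytest, assuming a pydantic settings object with SecretStr
fields; the "settings" subcommand and the REDIS_SETTINGS variable name are
illustrative only:

import json
from common_library.serialization import model_dump_with_secrets
from typer.testing import CliRunner

def run_cli_with_settings(app, redis_settings) -> None:
    env = {
        # secrets must be revealed, otherwise the CLI would receive "**********"
        "REDIS_SETTINGS": json.dumps(
            model_dump_with_secrets(redis_settings, show_secrets=True)
        ),
    }
    result = CliRunner(env=env).invoke(app, ["settings"])  # hypothetical command
    assert result.exit_code == 0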
diff --git a/services/efs-guardian/src/simcore_service_efs_guardian/services/modules/redis.py b/services/efs-guardian/src/simcore_service_efs_guardian/services/modules/redis.py
index 78d1462378a..74cf65b320e 100644
--- a/services/efs-guardian/src/simcore_service_efs_guardian/services/modules/redis.py
+++ b/services/efs-guardian/src/simcore_service_efs_guardian/services/modules/redis.py
@@ -18,6 +18,7 @@ async def on_startup() -> None:
         app.state.redis_lock_client_sdk = RedisClientSDK(
             redis_locks_dsn, client_name=APP_NAME
         )
+        await app.state.redis_lock_client_sdk.setup()

     async def on_shutdown() -> None:
         redis_lock_client_sdk: None | RedisClientSDK = app.state.redis_lock_client_sdk
diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/redis.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/redis.py
index e2790b2a4e9..84e0df512e5 100644
--- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/redis.py
+++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/redis.py
@@ -24,6 +24,7 @@ async def on_startup() -> None:
         app.state.redis_client_sdk = RedisClientSDK(
             redis_locks_dsn, client_name=APP_NAME
         )
+        await app.state.redis_client_sdk.setup()

     async def on_shutdown() -> None:
         with log_context(
diff --git a/services/storage/src/simcore_service_storage/core/settings.py b/services/storage/src/simcore_service_storage/core/settings.py
index 4d246a89eeb..a3725ac4857 100644
--- a/services/storage/src/simcore_service_storage/core/settings.py
+++ b/services/storage/src/simcore_service_storage/core/settings.py
@@ -36,7 +36,7 @@ class ApplicationSettings(BaseApplicationSettings, MixinLoggingSettings):
     ]

     STORAGE_REDIS: Annotated[
-        RedisSettings | None, Field(json_schema_extra={"auto_default_from_env": True})
+        RedisSettings, Field(json_schema_extra={"auto_default_from_env": True})
     ]

     STORAGE_S3: Annotated[
diff --git a/services/storage/src/simcore_service_storage/modules/long_running_tasks.py b/services/storage/src/simcore_service_storage/modules/long_running_tasks.py
index 229c1bd3fef..834ec2dcbb7 100644
--- a/services/storage/src/simcore_service_storage/modules/long_running_tasks.py
+++ b/services/storage/src/simcore_service_storage/modules/long_running_tasks.py
@@ -1,11 +1,19 @@
+from typing import Final
+
 from fastapi import FastAPI
 from servicelib.fastapi.long_running_tasks._server import setup
+from servicelib.long_running_tasks.task import RedisNamespace

 from .._meta import API_VTAG
+from ..core.settings import get_application_settings
+
+_LONG_RUNNING_TASKS_NAMESPACE: Final[RedisNamespace] = "storage"


 def setup_rest_api_long_running_tasks_for_uploads(app: FastAPI) -> None:
     setup(
         app,
         router_prefix=f"/{API_VTAG}/futures",
+        redis_settings=get_application_settings(app).STORAGE_REDIS,
+        redis_namespace=_LONG_RUNNING_TASKS_NAMESPACE,
     )
diff --git a/services/storage/src/simcore_service_storage/modules/redis.py b/services/storage/src/simcore_service_storage/modules/redis.py
index 6b2c15476ec..8b28eb58264 100644
--- a/services/storage/src/simcore_service_storage/modules/redis.py
+++ b/services/storage/src/simcore_service_storage/modules/redis.py
@@ -15,11 +15,11 @@ def setup(app: FastAPI) -> None:
     async def on_startup() -> None:
         app.state.redis_client_sdk = None
         redis_settings = get_application_settings(app).STORAGE_REDIS
-        assert redis_settings  # nosec
         redis_locks_dsn = redis_settings.build_redis_dsn(RedisDatabase.LOCKS)
         app.state.redis_client_sdk = RedisClientSDK(
             redis_locks_dsn, client_name=APP_NAME
         )
+        await app.state.redis_client_sdk.setup()

     async def on_shutdown() -> None:
         redis_client_sdk = app.state.redis_client_sdk
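Editor's note: all three services above now follow the same two-step contract:
construct RedisClientSDK, then explicitly await its setup() before first use.
A minimal sketch of that lifecycle on a bare FastAPI app; the servicelib.redis
import path and the shutdown() counterpart are assumptions here, since the
hunks truncate the on_shutdown bodies:

from fastapi import FastAPI
from servicelib.redis import RedisClientSDK  # assumed import path

def setup_redis(app: FastAPI, dsn: str) -> None:
    async def on_startup() -> None:
        app.state.redis_client_sdk = RedisClientSDK(dsn, client_name="my-service")
        # construction no longer connects; setup() must be awaited explicitly
        await app.state.redis_client_sdk.setup()

    async def on_shutdown() -> None:
        if app.state.redis_client_sdk is not None:
            await app.state.redis_client_sdk.shutdown()  # assumed counterpart

    app.add_event_handler("startup", on_startup)
    app.add_event_handler("shutdown", on_shutdown)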
diff --git a/services/web/server/requirements/_test.in b/services/web/server/requirements/_test.in
index d8afabb9146..a05cd9ce38c 100644
--- a/services/web/server/requirements/_test.in
+++ b/services/web/server/requirements/_test.in
@@ -15,8 +15,9 @@
 click
 coverage
 docker
 Faker
-fastapi[standard]
+fakeredis[lua]
 fastapi-pagination
+fastapi[standard]
 flaky
 hypothesis
 jsonref
diff --git a/services/web/server/requirements/_test.txt b/services/web/server/requirements/_test.txt
index 8f7e0133255..5af91daf778 100644
--- a/services/web/server/requirements/_test.txt
+++ b/services/web/server/requirements/_test.txt
@@ -78,6 +78,8 @@ faker==19.6.1
   # via
   #   -c requirements/_base.txt
   #   -r requirements/_test.in
+fakeredis==2.30.3
+  # via -r requirements/_test.in
 fastapi==0.115.6
   # via -r requirements/_test.in
 fastapi-cli==0.0.5
@@ -139,6 +141,8 @@ jsonschema==3.2.0
   #   -r requirements/_test.in
   #   openapi-schema-validator
   #   openapi-spec-validator
+lupa==2.5
+  # via fakeredis
 mako==1.3.10
   # via
   #   -c requirements/../../../../requirements/constraints.txt
@@ -265,6 +269,7 @@ redis==5.2.1
   #   -c requirements/../../../../requirements/constraints.txt
   #   -c requirements/_base.txt
   #   -r requirements/_test.in
+  #   fakeredis
 referencing==0.8.11
   # via
   #   -c requirements/../../../../requirements/constraints.txt
@@ -298,7 +303,9 @@ sniffio==1.3.1
   #   -c requirements/_base.txt
   #   anyio
 sortedcontainers==2.4.0
-  # via hypothesis
+  # via
+  #   fakeredis
+  #   hypothesis
 sqlalchemy==1.4.47
   # via
   #   -c requirements/../../../../requirements/constraints.txt
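Editor's note: fakeredis is pulled in with the [lua] extra (hence the new lupa
pin) so that server-side Lua scripting keeps working against the in-memory
fake. A small illustrative check, not taken from the test suite:

import asyncio
from fakeredis import FakeAsyncRedis

async def main() -> None:
    redis = FakeAsyncRedis()
    # EVAL needs the lua extra (lupa); plain fakeredis would reject the script
    await redis.eval("return redis.call('SET', KEYS[1], ARGV[1])", 1, "key", "42")
    assert await redis.get("key") == b"42"

asyncio.run(main())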
diff --git a/services/web/server/src/simcore_service_webserver/long_running_tasks.py b/services/web/server/src/simcore_service_webserver/long_running_tasks.py
index 411a7eaae4f..a97c82a5852 100644
--- a/services/web/server/src/simcore_service_webserver/long_running_tasks.py
+++ b/services/web/server/src/simcore_service_webserver/long_running_tasks.py
@@ -1,17 +1,26 @@
+import logging
 from functools import wraps
+from typing import Final

 from aiohttp import web
 from models_library.utils.fastapi_encoders import jsonable_encoder
+from servicelib.aiohttp.application_setup import ensure_single_setup
 from servicelib.aiohttp.long_running_tasks._constants import (
     RQT_LONG_RUNNING_TASKS_CONTEXT_KEY,
 )
 from servicelib.aiohttp.long_running_tasks.server import setup
 from servicelib.aiohttp.typing_extension import Handler
+from servicelib.long_running_tasks.task import RedisNamespace

+from . import redis
 from ._meta import API_VTAG
 from .login.decorators import login_required
 from .models import AuthenticatedRequestContext

+_logger = logging.getLogger(__name__)
+
+_LONG_RUNNING_TASKS_NAMESPACE: Final[RedisNamespace] = "webserver-legacy"
+

 def webserver_request_context_decorator(handler: Handler):
     @wraps(handler)
@@ -26,9 +35,12 @@ async def _test_task_context_decorator(
     return _test_task_context_decorator


+@ensure_single_setup(__name__, logger=_logger)
 def setup_long_running_tasks(app: web.Application) -> None:
     setup(
         app,
+        redis_settings=redis.get_plugin_settings(app),
+        redis_namespace=_LONG_RUNNING_TASKS_NAMESPACE,
         router_prefix=f"/{API_VTAG}/tasks-legacy",
         handler_check_decorator=login_required,
         task_request_context_decorator=webserver_request_context_decorator,
diff --git a/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py b/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py
index 917d678be42..7761e17ad13 100644
--- a/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py
+++ b/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py
@@ -180,7 +180,7 @@ async def _copy() -> None:
             user_id=user_id,
             product_name=product_name,
         ):
-            task_progress.update(
+            await task_progress.update(
                 message=(
                     async_job_composed_result.status.progress.message.description
                     if async_job_composed_result.status.progress.message
@@ -296,7 +296,7 @@ async def create_project(  # pylint: disable=too-many-arguments,too-many-branche
     copy_file_coro = None
     project_nodes = None
     try:
-        progress.update(message="creating new study...")
+        await progress.update(message="creating new study...")

         workspace_id = None
         folder_id = None
@@ -384,7 +384,7 @@ async def create_project(  # pylint: disable=too-many-arguments,too-many-branche
             parent_project_uuid=parent_project_uuid,
             parent_node_id=parent_node_id,
         )
-        progress.update()
+        await progress.update()

         # 3.2 move project to proper folder
         if folder_id:
@@ -414,7 +414,7 @@ async def create_project(  # pylint: disable=too-many-arguments,too-many-branche
         await dynamic_scheduler_service.update_projects_networks(
             request.app, project_id=ProjectID(new_project["uuid"])
         )
-        progress.update()
+        await progress.update()

         # This is a new project and every new graph needs to be reflected in the pipeline tables
         await director_v2_service.create_or_update_pipeline(
@@ -435,7 +435,7 @@ async def create_project(  # pylint: disable=too-many-arguments,too-many-branche
             is_template=as_template,
             app=request.app,
         )
-        progress.update()
+        await progress.update()

         # Adds permalink
         await update_or_pop_permalink_in_project(request, new_project)
diff --git a/services/web/server/src/simcore_service_webserver/tasks/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_rest.py
index bc4e8d08f4f..375165e7c77 100644
--- a/services/web/server/src/simcore_service_webserver/tasks/_rest.py
+++ b/services/web/server/src/simcore_service_webserver/tasks/_rest.py
@@ -58,7 +58,7 @@
 @webserver_request_context_decorator
 async def get_async_jobs(request: web.Request) -> web.Response:
     inprocess_long_running_manager = get_long_running_manager(request.app)
-    inprocess_tracked_tasks = lrt_api.list_tasks(
+    inprocess_tracked_tasks = await lrt_api.list_tasks(
         inprocess_long_running_manager.tasks_manager,
         inprocess_long_running_manager.get_task_context(request),
     )
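Editor's note: because lrt_api.list_tasks is now a coroutine, any caller left
un-awaited silently receives a coroutine object instead of a list. A hedged,
self-contained illustration of the failure mode this hunk prevents (the local
list_tasks stands in for the real lrt_api function):

import asyncio

async def list_tasks() -> list:  # stand-in for lrt_api.list_tasks
    return []

async def handler() -> None:
    not_awaited = list_tasks()  # bug: a coroutine object, not a list
    assert asyncio.iscoroutine(not_awaited)
    not_awaited.close()  # avoid the "never awaited" warning
    tasks = await list_tasks()  # fixed: awaited, as in the hunk above
    assert tasks == []

asyncio.run(handler())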
diff --git a/services/web/server/tests/unit/with_dbs/01/test_long_running_tasks.py b/services/web/server/tests/unit/with_dbs/01/test_long_running_tasks.py
index 0b5f601d8f0..4a1f654e378 100644
--- a/services/web/server/tests/unit/with_dbs/01/test_long_running_tasks.py
+++ b/services/web/server/tests/unit/with_dbs/01/test_long_running_tasks.py
@@ -77,7 +77,7 @@ async def test_listing_tasks_with_list_inprocess_tasks_error(
     assert client.app

     class _DummyTaskManager:
-        def list_tasks(self, *args, **kwargs):
+        async def list_tasks(self, *args, **kwargs):
             raise Exception  # pylint: disable=broad-exception-raised # noqa: TRY002

     mock = Mock()
diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_cancellations.py b/services/web/server/tests/unit/with_dbs/02/test_projects_cancellations.py
index c94efff772a..fdd9ec0549b 100644
--- a/services/web/server/tests/unit/with_dbs/02/test_projects_cancellations.py
+++ b/services/web/server/tests/unit/with_dbs/02/test_projects_cancellations.py
@@ -27,6 +27,7 @@
 from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import (
     AsyncJobComposedResult,
 )
+from settings_library.redis import RedisSettings
 from simcore_postgres_database.models.users import UserRole
 from simcore_service_webserver._meta import api_version_prefix
 from simcore_service_webserver.application_settings import get_application_settings
@@ -40,7 +41,9 @@

 @pytest.fixture
 def app_environment(
-    app_environment: EnvVarsDict, monkeypatch: pytest.MonkeyPatch
+    use_in_memory_redis: RedisSettings,
+    app_environment: EnvVarsDict,
+    monkeypatch: pytest.MonkeyPatch,
 ) -> EnvVarsDict:
     envs_plugins = setenvs_from_dict(
         monkeypatch,
@@ -91,6 +94,7 @@ def _standard_user_role_response() -> (

 @pytest.mark.parametrize(*_standard_user_role_response())
 async def test_copying_large_project_and_aborting_correctly_removes_new_project(
+    mock_dynamic_scheduler: None,
     client: TestClient,
     logged_user: dict[str, Any],
     primary_group: dict[str, str],
@@ -142,6 +146,7 @@ async def test_copying_large_project_and_aborting_correctly_removes_new_project(

 @pytest.mark.parametrize(*_standard_user_role_response())
 async def test_copying_large_project_and_retrieving_copy_task(
+    mock_dynamic_scheduler: None,
     client: TestClient,
     logged_user: dict[str, Any],
     primary_group: dict[str, str],
@@ -296,6 +301,7 @@ async def test_copying_too_large_project_returns_422(
 ):
     assert client.app
     app_settings = get_application_settings(client.app)
+    assert app_settings.WEBSERVER_PROJECTS
     large_project_total_size = (
         app_settings.WEBSERVER_PROJECTS.PROJECTS_MAX_COPY_SIZE_BYTES + 1
     )
diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_crud_handlers__list.py b/services/web/server/tests/unit/with_dbs/02/test_projects_crud_handlers__list.py
index a65cd0ebe30..5596fdfdf1e 100644
--- a/services/web/server/tests/unit/with_dbs/02/test_projects_crud_handlers__list.py
+++ b/services/web/server/tests/unit/with_dbs/02/test_projects_crud_handlers__list.py
@@ -18,6 +18,7 @@
     standard_role_response,
 )
 from servicelib.aiohttp import status
+from settings_library.redis import RedisSettings
 from simcore_service_webserver._meta import api_version_prefix
 from simcore_service_webserver.db.models import UserRole
 from simcore_service_webserver.projects.models import ProjectDict
@@ -151,6 +152,7 @@ async def test_list_projects_with_invalid_pagination_parameters(

 @pytest.mark.parametrize("limit", [7, 20, 43])
 @pytest.mark.parametrize(*standard_user_role())
 async def test_list_projects_with_pagination(
+    use_in_memory_redis: RedisSettings,
     mock_dynamic_scheduler: None,
     client: TestClient,
     logged_user: dict[str, Any],
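Editor's note: the placement of use_in_memory_redis in these signatures is
deliberate. pytest sets up a fixture's dependencies before the fixture itself,
so declaring it as a parameter of app_environment guarantees redis is faked
before the app settings are assembled. A self-contained sketch of the
mechanism, with illustrative fixture names (fake_backend, app_env):

import pytest

_EVENTS: list[str] = []

@pytest.fixture
def fake_backend() -> None:
    _EVENTS.append("fake_backend")  # e.g. patch the redis client here

@pytest.fixture
def app_env(fake_backend: None) -> dict:
    _EVENTS.append("app_env")  # built only after the backend is faked
    return {"FAKE": "1"}

def test_fixture_order(app_env: dict) -> None:
    assert _EVENTS == ["fake_backend", "app_env"]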
diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_nodes_handler.py b/services/web/server/tests/unit/with_dbs/02/test_projects_nodes_handler.py
index 4db7e6f8a22..ef3fd0c98bd 100644
--- a/services/web/server/tests/unit/with_dbs/02/test_projects_nodes_handler.py
+++ b/services/web/server/tests/unit/with_dbs/02/test_projects_nodes_handler.py
@@ -6,6 +6,7 @@
 import asyncio
 import re
 from collections.abc import Awaitable, Callable
+from contextlib import suppress
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from http import HTTPStatus
@@ -46,12 +47,21 @@
 )
 from servicelib.aiohttp import status
 from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
+from settings_library.redis import RedisSettings
 from simcore_postgres_database.models.projects import projects as projects_db_model
 from simcore_service_webserver.db.models import UserRole
 from simcore_service_webserver.projects._controller.nodes_rest import (
     _ProjectNodePreview,
 )
 from simcore_service_webserver.projects.models import ProjectDict
+from tenacity import (
+    AsyncRetrying,
+    RetryError,
+    retry_if_exception_type,
+    retry_unless_exception_type,
+    stop_after_delay,
+    wait_fixed,
+)


 @pytest.mark.parametrize(
@@ -916,6 +926,7 @@ async def test_start_node_raises_if_called_with_wrong_data(

 @pytest.mark.parametrize(*standard_role_response(), ids=str)
 async def test_stop_node(
+    use_in_memory_redis: RedisSettings,
     client: TestClient,
     user_project_with_num_dynamic_services: Callable[[int], Awaitable[ProjectDict]],
     user_role: UserRole,
@@ -935,18 +946,34 @@ async def test_stop_node(
         project_id=project["uuid"], node_id=choice(all_service_uuids)  # noqa: S311
     )
     response = await client.post(f"{url}")
-    data, error = await assert_status(
+    _, error = await assert_status(
         response,
         status.HTTP_202_ACCEPTED if user_role == UserRole.GUEST else expected.accepted,
     )
+
     if error is None:
-        mocked_dynamic_services_interface[
-            "dynamic_scheduler.api.stop_dynamic_service"
-        ].assert_called_once()
+        async for attempt in AsyncRetrying(
+            wait=wait_fixed(0.1),
+            stop=stop_after_delay(5),
+            retry=retry_if_exception_type(AssertionError),
+            reraise=True,
+        ):
+            with attempt:
+                mocked_dynamic_services_interface[
+                    "dynamic_scheduler.api.stop_dynamic_service"
+                ].assert_called_once()
     else:
-        mocked_dynamic_services_interface[
-            "dynamic_scheduler.api.stop_dynamic_service"
-        ].assert_not_called()
+        with suppress(RetryError):
+            async for attempt in AsyncRetrying(
+                wait=wait_fixed(0.1),
+                stop=stop_after_delay(5),
+                retry=retry_unless_exception_type(AssertionError),
+                reraise=True,
+            ):
+                with attempt:
+                    mocked_dynamic_services_interface[
+                        "dynamic_scheduler.api.stop_dynamic_service"
+                    ].assert_not_called()


 @pytest.fixture
diff --git a/services/web/server/tests/unit/with_dbs/docker-compose-devel.yml b/services/web/server/tests/unit/with_dbs/docker-compose-devel.yml
index dd89755d90d..c8fcc39cb1b 100644
--- a/services/web/server/tests/unit/with_dbs/docker-compose-devel.yml
+++ b/services/web/server/tests/unit/with_dbs/docker-compose-devel.yml
@@ -82,7 +82,7 @@ services:
         scheduled_maintenance:redis:6379:3:${TEST_REDIS_PASSWORD},
         user_notifications:redis:6379:4:${TEST_REDIS_PASSWORD},
         announcements:redis:6379:5:${TEST_REDIS_PASSWORD},
-        distributed_identifiers:redis:6379:6:${TEST_REDIS_PASSWORD},
+        long_running_tasks:redis:6379:6:${TEST_REDIS_PASSWORD},
         deferred_tasks:redis:6379:7:${TEST_REDIS_PASSWORD},
         dynamic_services:${REDIS_HOST}:${REDIS_PORT}:8:${TEST_REDIS_PASSWORD},
        celery_tasks:${REDIS_HOST}:${REDIS_PORT}:9:${TEST_REDIS_PASSWORD},
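Editor's note: the reworked test_stop_node above polls its mock instead of
asserting immediately, presumably because the stop call is now dispatched
asynchronously. The same polling idiom in isolation, with an illustrative
Mock fired from a timer:

import asyncio
from unittest.mock import Mock
from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_delay, wait_fixed

async def main() -> None:
    mock = Mock()
    # simulate a call that arrives later on the event loop
    asyncio.get_running_loop().call_later(0.3, mock)
    async for attempt in AsyncRetrying(
        wait=wait_fixed(0.1),
        stop=stop_after_delay(5),
        retry=retry_if_exception_type(AssertionError),
        reraise=True,
    ):
        with attempt:
            mock.assert_called_once()  # retried until it stops raising

asyncio.run(main())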