diff --git a/.env-devel b/.env-devel index 834bc8de0af2..974079c611c4 100644 --- a/.env-devel +++ b/.env-devel @@ -147,6 +147,7 @@ LICENSES_ITIS_VIP_API_URL=https://replace-with-itis-api/{category} LICENSES_ITIS_VIP_CATEGORIES='{"HumanWholeBody": "Humans", "HumanBodyRegion": "Humans (Region)", "AnimalWholeBody": "Animal"}' LICENSES_SPEAG_PHANTOMS_API_URL=https://replace-with-speag-api/{category} LICENSES_SPEAG_PHANTOMS_CATEGORIES='{"ComputationalPhantom": "Phantom of the Opera"}' +LONG_RUNNING_TASKS_NAMESPACE_SUFFIX=development # Can use 'docker run -it itisfoundation/invitations:latest simcore-service-invitations generate-dotenv --auto-password' INVITATIONS_DEFAULT_PRODUCT=osparc diff --git a/.github/workflows/ci-testing-deploy.yml b/.github/workflows/ci-testing-deploy.yml index b130c5254ba4..b225a782d11f 100644 --- a/.github/workflows/ci-testing-deploy.yml +++ b/.github/workflows/ci-testing-deploy.yml @@ -1252,7 +1252,7 @@ jobs: unit-test-service-library: needs: changes if: ${{ needs.changes.outputs.service-library == 'true' || github.event_name == 'push' || github.event.inputs.force_all_builds == 'true' }} - timeout-minutes: 18 # if this timeout gets too small, then split the tests + timeout-minutes: 20 # if this timeout gets too small, then split the tests name: "[unit] service-library" runs-on: ${{ matrix.os }} strategy: diff --git a/api/specs/web-server/_long_running_tasks.py b/api/specs/web-server/_long_running_tasks.py index 1c6e033b867d..77b979b2e3e8 100644 --- a/api/specs/web-server/_long_running_tasks.py +++ b/api/specs/web-server/_long_running_tasks.py @@ -32,44 +32,40 @@ @router.get( "/tasks", response_model=Envelope[list[TaskGet]], - name="list_tasks", - description="Lists all long running tasks", responses=_export_data_responses, ) -def get_async_jobs(): ... +def get_async_jobs(): + """Lists all long running tasks""" @router.get( "/tasks/{task_id}", response_model=Envelope[TaskStatus], - name="get_task_status", - description="Retrieves the status of a task", responses=_export_data_responses, ) def get_async_job_status( _path_params: Annotated[_PathParam, Depends()], -): ... +): + """Retrieves the status of a task""" @router.delete( "/tasks/{task_id}", - name="cancel_and_delete_task", - description="Cancels and deletes a task", responses=_export_data_responses, status_code=status.HTTP_204_NO_CONTENT, ) def cancel_async_job( _path_params: Annotated[_PathParam, Depends()], -): ... +): + """Cancels and removes a task""" @router.get( "/tasks/{task_id}/result", response_model=Any, - name="get_task_result", - description="Retrieves the result of a task", responses=_export_data_responses, ) def get_async_job_result( _path_params: Annotated[_PathParam, Depends()], -): ... +): + """Retrieves the result of a task""" diff --git a/api/specs/web-server/_long_running_tasks_legacy.py b/api/specs/web-server/_long_running_tasks_legacy.py index d5fc487301ac..59bc8881b0c3 100644 --- a/api/specs/web-server/_long_running_tasks_legacy.py +++ b/api/specs/web-server/_long_running_tasks_legacy.py @@ -42,11 +42,11 @@ async def get_task_status( @router.delete( "/{task_id}", - name="cancel_and_delete_task", - description="Cancels and deletes a task", + name="remove_task", + description="Cancels and removes a task", status_code=status.HTTP_204_NO_CONTENT, ) -async def cancel_and_delete_task( +async def remove_task( _path_params: Annotated[_PathParam, Depends()], ): ... 
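The spec changes above drop the explicit `name=`/`description=` kwargs and rely on handler docstrings instead. A minimal sketch (not part of this PR; route and docstring are illustrative) of the FastAPI behavior this relies on — when `description` is not passed, the handler's docstring becomes the OpenAPI operation description:

```python
from fastapi import FastAPI

app = FastAPI()


@app.get("/tasks")
def list_tasks():
    """Lists all long running tasks"""  # picked up as the OpenAPI description


spec = app.openapi()
assert spec["paths"]["/tasks"]["get"]["description"] == "Lists all long running tasks"
```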
diff --git a/packages/pytest-simcore/src/pytest_simcore/long_running_tasks.py b/packages/pytest-simcore/src/pytest_simcore/long_running_tasks.py new file mode 100644 index 000000000000..e3911dc62f5a --- /dev/null +++ b/packages/pytest-simcore/src/pytest_simcore/long_running_tasks.py @@ -0,0 +1,14 @@ +from datetime import timedelta + +import pytest +from pytest_mock import MockerFixture + + +@pytest.fixture +async def fast_long_running_tasks_cancellation( + mocker: MockerFixture, +) -> None: + mocker.patch( + "servicelib.long_running_tasks.task._CANCEL_TASKS_CHECK_INTERVAL", + new=timedelta(seconds=1), + ) diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py index f03945126bd0..b2023c980681 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py @@ -1,41 +1,12 @@ -import datetime - from aiohttp import web -from settings_library.redis import RedisSettings from ...long_running_tasks.base_long_running_manager import BaseLongRunningManager from ...long_running_tasks.models import TaskContext -from ...long_running_tasks.task import RedisNamespace, TasksManager from ._constants import APP_LONG_RUNNING_MANAGER_KEY from ._request import get_task_context class AiohttpLongRunningManager(BaseLongRunningManager): - def __init__( - self, - app: web.Application, - stale_task_check_interval: datetime.timedelta, - stale_task_detect_timeout: datetime.timedelta, - redis_settings: RedisSettings, - redis_namespace: RedisNamespace, - ): - self._app = app - self._tasks_manager = TasksManager( - stale_task_check_interval=stale_task_check_interval, - stale_task_detect_timeout=stale_task_detect_timeout, - redis_settings=redis_settings, - redis_namespace=redis_namespace, - ) - - @property - def tasks_manager(self) -> TasksManager: - return self._tasks_manager - - async def setup(self) -> None: - await self._tasks_manager.setup() - - async def teardown(self) -> None: - await self._tasks_manager.teardown() @staticmethod def get_task_context(request: web.Request) -> TaskContext: diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py index cb735779901a..9e8ac646c330 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py @@ -1,12 +1,16 @@ -from typing import Any +from typing import Annotated, Any from aiohttp import web -from pydantic import BaseModel +from models_library.rest_base import RequestParameters +from pydantic import BaseModel, Field from ...aiohttp import status from ...long_running_tasks import lrt_api from ...long_running_tasks.models import TaskGet, TaskId -from ..requests_validation import parse_request_path_parameters_as +from ..requests_validation import ( + parse_request_path_parameters_as, + parse_request_query_parameters_as, +) from ..rest_responses import create_data_response from ._manager import get_long_running_manager @@ -26,10 +30,11 @@ async def list_tasks(request: web.Request) -> web.Response: task_id=t.task_id, status_href=f"{request.app.router['get_task_status'].url_for(task_id=t.task_id)}", result_href=f"{request.app.router['get_task_result'].url_for(task_id=t.task_id)}", - 
abort_href=f"{request.app.router['cancel_and_delete_task'].url_for(task_id=t.task_id)}", + abort_href=f"{request.app.router['remove_task'].url_for(task_id=t.task_id)}", ) for t in await lrt_api.list_tasks( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, long_running_manager.get_task_context(request), ) ] @@ -42,7 +47,8 @@ async def get_task_status(request: web.Request) -> web.Response: long_running_manager = get_long_running_manager(request.app) task_status = await lrt_api.get_task_status( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, long_running_manager.get_task_context(request), path_params.task_id, ) @@ -56,20 +62,36 @@ async def get_task_result(request: web.Request) -> web.Response | Any: # NOTE: this might raise an exception that will be catched by the _error_handlers return await lrt_api.get_task_result( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, long_running_manager.get_task_context(request), path_params.task_id, ) -@routes.delete("/{task_id}", name="cancel_and_delete_task") -async def cancel_and_delete_task(request: web.Request) -> web.Response: +class _RemoveTaskQueryParams(RequestParameters): + wait_for_removal: Annotated[ + bool, + Field( + description=( + "when True waits for the task to be removed " + "completly instead of returning immediately" + ) + ), + ] = True + + +@routes.delete("/{task_id}", name="remove_task") +async def remove_task(request: web.Request) -> web.Response: path_params = parse_request_path_parameters_as(_PathParam, request) + query_params = parse_request_query_parameters_as(_RemoveTaskQueryParams, request) long_running_manager = get_long_running_manager(request.app) await lrt_api.remove_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, long_running_manager.get_task_context(request), path_params.task_id, + wait_for_removal=query_params.wait_for_removal, ) return web.json_response(status=status.HTTP_204_NO_CONTENT) diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py index 68147fede027..b5ae54cb07f9 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py @@ -8,11 +8,12 @@ from aiohttp.web import HTTPException from common_library.json_serialization import json_dumps from pydantic import AnyHttpUrl, TypeAdapter +from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from ...aiohttp import status from ...long_running_tasks import lrt_api -from ...long_running_tasks._redis_serialization import ( +from ...long_running_tasks._serialization import ( BaseObjectSerializer, register_custom_serialization, ) @@ -20,8 +21,12 @@ DEFAULT_STALE_TASK_CHECK_INTERVAL, DEFAULT_STALE_TASK_DETECT_TIMEOUT, ) -from ...long_running_tasks.models import TaskContext, TaskGet -from ...long_running_tasks.task import RedisNamespace, RegisteredTaskName +from ...long_running_tasks.models import ( + LRTNamespace, + RegisteredTaskName, + TaskContext, + TaskGet, +) from ..typing_extension import Handler from . 
import _routes from ._constants import ( @@ -63,7 +68,8 @@ async def start_long_running_task( task_id = None try: task_id = await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, registerd_task_name, fire_and_forget=fire_and_forget, task_context=task_context, @@ -81,7 +87,7 @@ f"http://{ip_addr}:{port}{request_.app.router['get_task_result'].url_for(task_id=task_id)}" # NOSONAR ) abort_url = TypeAdapter(AnyHttpUrl).validate_python( - f"http://{ip_addr}:{port}{request_.app.router['cancel_and_delete_task'].url_for(task_id=task_id)}" # NOSONAR + f"http://{ip_addr}:{port}{request_.app.router['remove_task'].url_for(task_id=task_id)}" # NOSONAR ) task_get = TaskGet( task_id=task_id, @@ -98,7 +104,11 @@ # remove the task, the client was disconnected if task_id: await lrt_api.remove_task( - long_running_manager.tasks_manager, task_context, task_id + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + task_context, + task_id, + wait_for_removal=True, ) raise @@ -143,20 +153,23 @@ def setup( *, router_prefix: str, redis_settings: RedisSettings, - redis_namespace: RedisNamespace, - handler_check_decorator: Callable = _no_ops_decorator, - task_request_context_decorator: Callable = _no_task_context_decorator, + rabbit_settings: RabbitSettings, + lrt_namespace: LRTNamespace, stale_task_check_interval: datetime.timedelta = DEFAULT_STALE_TASK_CHECK_INTERVAL, stale_task_detect_timeout: datetime.timedelta = DEFAULT_STALE_TASK_DETECT_TIMEOUT, + handler_check_decorator: Callable = _no_ops_decorator, + task_request_context_decorator: Callable = _no_task_context_decorator, ) -> None: """ - `router_prefix` APIs are mounted on `/...`, this will change them to be mounted as `{router_prefix}/...` - - `stale_task_check_interval_s` interval at which the + - `redis_settings` settings for Redis connection + - `rabbit_settings` settings for RabbitMQ connection + - `lrt_namespace` namespace for the long-running tasks + - `stale_task_check_interval` interval at which the TaskManager checks for tasks which are no longer being actively monitored by a client - - `stale_task_detect_timeout_s` interval after which a - task is considered stale + - `stale_task_detect_timeout` interval after which a task is considered stale """ async def on_cleanup_ctx(app: web.Application) -> AsyncGenerator[None, None]: @@ -168,11 +181,11 @@ async def on_cleanup_ctx(app: web.Application) -> AsyncGenerator[None, None]: # add components to state app[APP_LONG_RUNNING_MANAGER_KEY] = long_running_manager = ( AiohttpLongRunningManager( - app=app, stale_task_check_interval=stale_task_check_interval, stale_task_detect_timeout=stale_task_detect_timeout, redis_settings=redis_settings, - redis_namespace=redis_namespace, + rabbit_settings=rabbit_settings, + lrt_namespace=lrt_namespace, ) ) diff --git a/packages/service-library/src/servicelib/aiohttp/profiler_middleware.py b/packages/service-library/src/servicelib/aiohttp/profiler_middleware.py index 4256820b4b30..07d3c7127297 100644 --- a/packages/service-library/src/servicelib/aiohttp/profiler_middleware.py +++ b/packages/service-library/src/servicelib/aiohttp/profiler_middleware.py @@ -1,9 +1,6 @@ from aiohttp.web import HTTPInternalServerError, Request, StreamResponse, middleware -from servicelib.mimetype_constants import ( - MIMETYPE_APPLICATION_JSON, - MIMETYPE_APPLICATION_ND_JSON, -) +from ..mimetype_constants import
MIMETYPE_APPLICATION_JSON, MIMETYPE_APPLICATION_ND_JSON from ..utils_profiling_middleware import _is_profiling, _profiler, append_profile diff --git a/packages/service-library/src/servicelib/fastapi/client_session.py b/packages/service-library/src/servicelib/fastapi/client_session.py index c77a6ea339ea..f9c126272eec 100644 --- a/packages/service-library/src/servicelib/fastapi/client_session.py +++ b/packages/service-library/src/servicelib/fastapi/client_session.py @@ -2,9 +2,10 @@ import httpx from fastapi import FastAPI -from servicelib.fastapi.tracing import setup_httpx_client_tracing from settings_library.tracing import TracingSettings +from .tracing import setup_httpx_client_tracing + def setup_client_session( app: FastAPI, diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_client.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_client.py index 5a17014e209f..8b7d9b78207b 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_client.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_client.py @@ -171,7 +171,7 @@ async def get_task_result( return result.json() @retry_on_http_errors - async def cancel_and_delete_task( + async def remove_task( self, task_id: TaskId, *, timeout: PositiveFloat | None = None # noqa: ASYNC109 ) -> None: timeout = timeout or self._client_configuration.default_timeout diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_context_manager.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_context_manager.py index 6a7ff58814d4..2e618710ca07 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_context_manager.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_context_manager.py @@ -121,7 +121,7 @@ async def _wait_for_task_result() -> Any: yield result except TimeoutError as e: - await client.cancel_and_delete_task(task_id) + await client.remove_task(task_id) raise TaskClientTimeoutError( task_id=task_id, timeout=task_timeout, diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py index 6f37eb40825d..8f04f828705e 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py @@ -1,35 +1,11 @@ -import datetime - -from fastapi import FastAPI -from settings_library.redis import RedisSettings +from fastapi import Request from ...long_running_tasks.base_long_running_manager import BaseLongRunningManager -from ...long_running_tasks.task import RedisNamespace, TasksManager +from ...long_running_tasks.models import TaskContext class FastAPILongRunningManager(BaseLongRunningManager): - def __init__( - self, - app: FastAPI, - stale_task_check_interval: datetime.timedelta, - stale_task_detect_timeout: datetime.timedelta, - redis_settings: RedisSettings, - redis_namespace: RedisNamespace, - ): - self._app = app - self._tasks_manager = TasksManager( - stale_task_check_interval=stale_task_check_interval, - stale_task_detect_timeout=stale_task_detect_timeout, - redis_settings=redis_settings, - redis_namespace=redis_namespace, - ) - - @property - def tasks_manager(self) -> TasksManager: - return self._tasks_manager - - async def setup(self) -> None: - await self._tasks_manager.setup() - - async def teardown(self) -> None: - await 
self._tasks_manager.teardown() + @staticmethod + def get_task_context(request: Request) -> TaskContext: + _ = request + return {} diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py index d3adbb1956db..bf347ba0d0aa 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py @@ -1,6 +1,6 @@ from typing import Annotated, Any -from fastapi import APIRouter, Depends, Request, status +from fastapi import APIRouter, Depends, Query, Request, status from ...long_running_tasks import lrt_api from ...long_running_tasks.models import TaskGet, TaskId, TaskResult, TaskStatus @@ -25,12 +25,12 @@ async def list_tasks( task_id=t.task_id, status_href=str(request.url_for("get_task_status", task_id=t.task_id)), result_href=str(request.url_for("get_task_result", task_id=t.task_id)), - abort_href=str( - request.url_for("cancel_and_delete_task", task_id=t.task_id) - ), + abort_href=str(request.url_for("remove_task", task_id=t.task_id)), ) for t in await lrt_api.list_tasks( - long_running_manager.tasks_manager, task_context={} + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + long_running_manager.get_task_context(request), ) ] @@ -45,14 +45,17 @@ async def list_tasks( @cancel_on_disconnect async def get_task_status( request: Request, - task_id: TaskId, long_running_manager: Annotated[ FastAPILongRunningManager, Depends(get_long_running_manager) ], + task_id: TaskId, ) -> TaskStatus: assert request # nosec return await lrt_api.get_task_status( - long_running_manager.tasks_manager, task_context={}, task_id=task_id + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + long_running_manager.get_task_context(request), + task_id=task_id, ) @@ -68,20 +71,23 @@ async def get_task_status( @cancel_on_disconnect async def get_task_result( request: Request, - task_id: TaskId, long_running_manager: Annotated[ FastAPILongRunningManager, Depends(get_long_running_manager) ], + task_id: TaskId, ) -> TaskResult | Any: assert request # nosec return await lrt_api.get_task_result( - long_running_manager.tasks_manager, task_context={}, task_id=task_id + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + long_running_manager.get_task_context(request), + task_id=task_id, ) @router.delete( "/{task_id}", - summary="Cancel and deletes a task", + summary="Cancels and removes a task", response_model=None, status_code=status.HTTP_204_NO_CONTENT, responses={ @@ -89,14 +95,28 @@ }, ) @cancel_on_disconnect -async def cancel_and_delete_task( +async def remove_task( request: Request, - task_id: TaskId, long_running_manager: Annotated[ FastAPILongRunningManager, Depends(get_long_running_manager) ], + task_id: TaskId, + *, + wait_for_removal: Annotated[ + bool, + Query( + description=( + "when True waits for the task to be removed " + "completely instead of returning immediately" + ), + ), + ] = True, ) -> None: assert request # nosec await lrt_api.remove_task( - long_running_manager.tasks_manager, task_context={}, task_id=task_id + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + long_running_manager.get_task_context(request), + task_id=task_id, + wait_for_removal=wait_for_removal, ) diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_server.py
b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_server.py index f00e6c8f5215..9cf4c526acee 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_server.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_server.py @@ -1,6 +1,7 @@ import datetime from fastapi import APIRouter, FastAPI +from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from ...long_running_tasks.constants import ( @@ -8,7 +9,7 @@ DEFAULT_STALE_TASK_DETECT_TIMEOUT, ) from ...long_running_tasks.errors import BaseLongRunningError -from ...long_running_tasks.task import RedisNamespace +from ...long_running_tasks.models import LRTNamespace from ._error_handlers import base_long_running_error_handler from ._manager import FastAPILongRunningManager from ._routes import router @@ -19,18 +20,21 @@ def setup( *, router_prefix: str = "", redis_settings: RedisSettings, - redis_namespace: RedisNamespace, + rabbit_settings: RabbitSettings, + lrt_namespace: LRTNamespace, stale_task_check_interval: datetime.timedelta = DEFAULT_STALE_TASK_CHECK_INTERVAL, stale_task_detect_timeout: datetime.timedelta = DEFAULT_STALE_TASK_DETECT_TIMEOUT, ) -> None: """ - - `router_prefix` APIs are mounted on `/task/...`, this - will change them to be mounted as `{router_prefix}/task/...` - - `stale_task_check_interval_s` interval at which the + - `router_prefix` APIs are mounted on `/...`, this + will change them to be mounted as `{router_prefix}/...` + - `redis_settings` settings for Redis connection + - `rabbit_settings` settings for RabbitMQ connection + - `lrt_namespace` namespace for the long-running tasks + - `stale_task_check_interval` interval at which the TaskManager checks for tasks which are no longer being actively monitored by a client - - `stale_task_detect_timeout_s` interval after which a - task is considered stale + - `stale_task_detect_timeout` interval after which a task is considered stale """ async def on_startup() -> None: @@ -42,19 +46,21 @@ async def on_startup() -> None: # add components to state app.state.long_running_manager = long_running_manager = ( FastAPILongRunningManager( - app=app, stale_task_check_interval=stale_task_check_interval, stale_task_detect_timeout=stale_task_detect_timeout, redis_settings=redis_settings, - redis_namespace=redis_namespace, + rabbit_settings=rabbit_settings, + lrt_namespace=lrt_namespace, ) ) await long_running_manager.setup() async def on_shutdown() -> None: if app.state.long_running_manager: - task_manager: FastAPILongRunningManager = app.state.long_running_manager - await task_manager.teardown() + long_running_manager: FastAPILongRunningManager = ( + app.state.long_running_manager + ) + await long_running_manager.teardown() app.add_event_handler("startup", on_startup) app.add_event_handler("shutdown", on_shutdown) diff --git a/packages/service-library/src/servicelib/long_running_tasks/_rabbit_namespace.py b/packages/service-library/src/servicelib/long_running_tasks/_rabbit_namespace.py new file mode 100644 index 000000000000..7ace2e53a3dd --- /dev/null +++ b/packages/service-library/src/servicelib/long_running_tasks/_rabbit_namespace.py @@ -0,0 +1,8 @@ +from models_library.rabbitmq_basic_types import RPCNamespace +from pydantic import TypeAdapter + +from .models import LRTNamespace + + +def get_rabbit_namespace(namespace: LRTNamespace) -> RPCNamespace: + return TypeAdapter(RPCNamespace).validate_python(f"lrt-{namespace}") diff --git
a/packages/service-library/src/servicelib/long_running_tasks/_redis_store.py b/packages/service-library/src/servicelib/long_running_tasks/_redis_store.py new file mode 100644 index 000000000000..acf70bb87e48 --- /dev/null +++ b/packages/service-library/src/servicelib/long_running_tasks/_redis_store.py @@ -0,0 +1,138 @@ +from typing import Any, Final + +import redis.asyncio as aioredis +from common_library.json_serialization import json_dumps, json_loads +from pydantic import TypeAdapter +from settings_library.redis import RedisDatabase, RedisSettings + +from ..redis._client import RedisClientSDK +from ..redis._utils import handle_redis_returns_union_types +from ..utils import limited_gather +from .models import LRTNamespace, TaskContext, TaskData, TaskId + +_STORE_TYPE_TASK_DATA: Final[str] = "TD" +_STORE_TYPE_CANCELLED_TASKS: Final[str] = "CT" +_LIST_CONCURRENCY: Final[int] = 2 + + +def _to_redis_hash_mapping(data: dict[str, Any]) -> dict[str, str]: + return {k: json_dumps(v) for k, v in data.items()} + + +def _load_from_redis_hash(data: dict[str, str]) -> dict[str, Any]: + return {k: json_loads(v) for k, v in data.items()} + + +class RedisStore: + def __init__(self, redis_settings: RedisSettings, namespace: LRTNamespace): + self.redis_settings = redis_settings + self.namespace: LRTNamespace = namespace.upper() + + self._client: RedisClientSDK | None = None + + async def setup(self) -> None: + self._client = RedisClientSDK( + self.redis_settings.build_redis_dsn(RedisDatabase.LONG_RUNNING_TASKS), + client_name=f"long_running_tasks_store_{self.namespace}", + ) + await self._client.setup() + + async def shutdown(self) -> None: + if self._client: + await self._client.shutdown() + + @property + def _redis(self) -> aioredis.Redis: + assert self._client # nosec + return self._client.redis + + def _get_redis_key_task_data_match(self) -> str: + return f"{self.namespace}:{_STORE_TYPE_TASK_DATA}*" + + def _get_redis_task_data_key(self, task_id: TaskId) -> str: + return f"{self.namespace}:{_STORE_TYPE_TASK_DATA}:{task_id}" + + def _get_key_to_remove(self) -> str: + return f"{self.namespace}:{_STORE_TYPE_CANCELLED_TASKS}" + + # TaskData + + async def get_task_data(self, task_id: TaskId) -> TaskData | None: + result: dict[str, Any] = await handle_redis_returns_union_types( + self._redis.hgetall( + self._get_redis_task_data_key(task_id), + ) + ) + return ( + TypeAdapter(TaskData).validate_python(_load_from_redis_hash(result)) + if result and len(result) + else None + ) + + async def add_task_data(self, task_id: TaskId, value: TaskData) -> None: + await handle_redis_returns_union_types( + self._redis.hset( + self._get_redis_task_data_key(task_id), + mapping=_to_redis_hash_mapping(value.model_dump()), + ) + ) + + async def update_task_data( + self, + task_id: TaskId, + *, + updates: dict[str, Any], + ) -> None: + await handle_redis_returns_union_types( + self._redis.hset( + self._get_redis_task_data_key(task_id), + mapping=_to_redis_hash_mapping(updates), + ) + ) + + async def list_tasks_data(self) -> list[TaskData]: + hash_keys: list[str] = [ + x + async for x in self._redis.scan_iter(self._get_redis_key_task_data_match()) + ] + + result = await limited_gather( + *[ + handle_redis_returns_union_types(self._redis.hgetall(key)) + for key in hash_keys + ], + limit=_LIST_CONCURRENCY, + ) + + return [ + TypeAdapter(TaskData).validate_python(_load_from_redis_hash(item)) + for item in result + if item + ] + + async def delete_task_data(self, task_id: TaskId) -> None: + await handle_redis_returns_union_types( + 
self._redis.delete(self._get_redis_task_data_key(task_id)) + ) + + # to cancel + + async def mark_task_for_removal( + self, task_id: TaskId, with_task_context: TaskContext + ) -> None: + await handle_redis_returns_union_types( + self._redis.hset( + self._get_key_to_remove(), task_id, json_dumps(with_task_context) + ) + ) + + async def completed_task_removal(self, task_id: TaskId) -> None: + await handle_redis_returns_union_types( + self._redis.hdel(self._get_key_to_remove(), task_id) + ) + + async def list_tasks_to_remove(self) -> dict[TaskId, TaskContext]: + result: dict[str, str | None] = await handle_redis_returns_union_types( + self._redis.hgetall(self._get_key_to_remove()) + ) + return {task_id: json_loads(context) for task_id, context in result.items()} diff --git a/packages/service-library/src/servicelib/long_running_tasks/_rpc_client.py b/packages/service-library/src/servicelib/long_running_tasks/_rpc_client.py new file mode 100644 index 000000000000..3bc3caf5e804 --- /dev/null +++ b/packages/service-library/src/servicelib/long_running_tasks/_rpc_client.py @@ -0,0 +1,143 @@ +import logging +from datetime import timedelta +from typing import Any, Final + +from models_library.rabbitmq_basic_types import RPCMethodName +from pydantic import PositiveInt, TypeAdapter + +from ..logging_utils import log_decorator +from ..rabbitmq._client_rpc import RabbitMQRPCClient +from ._rabbit_namespace import get_rabbit_namespace +from ._serialization import loads +from .errors import RPCTransferrableTaskError +from .models import ( + LRTNamespace, + RegisteredTaskName, + TaskBase, + TaskContext, + TaskId, + TaskStatus, +) + +_logger = logging.getLogger(__name__) + +_RPC_TIMEOUT_SHORT_REQUESTS: Final[PositiveInt] = int( + timedelta(seconds=20).total_seconds() +) + + +@log_decorator(_logger, level=logging.DEBUG) +async def start_task( + rabbitmq_rpc_client: RabbitMQRPCClient, + namespace: LRTNamespace, + *, + registered_task_name: RegisteredTaskName, + unique: bool = False, + task_context: TaskContext | None = None, + task_name: str | None = None, + fire_and_forget: bool = False, + **task_kwargs: Any, +) -> TaskId: + result = await rabbitmq_rpc_client.request( + get_rabbit_namespace(namespace), + TypeAdapter(RPCMethodName).validate_python("start_task"), + registered_task_name=registered_task_name, + unique=unique, + task_context=task_context, + task_name=task_name, + fire_and_forget=fire_and_forget, + **task_kwargs, + timeout_s=_RPC_TIMEOUT_SHORT_REQUESTS, + ) + assert isinstance(result, TaskId) # nosec + return result + + +@log_decorator(_logger, level=logging.DEBUG) +async def list_tasks( + rabbitmq_rpc_client: RabbitMQRPCClient, + namespace: LRTNamespace, + *, + task_context: TaskContext, +) -> list[TaskBase]: + result = await rabbitmq_rpc_client.request( + get_rabbit_namespace(namespace), + TypeAdapter(RPCMethodName).validate_python("list_tasks"), + task_context=task_context, + timeout_s=_RPC_TIMEOUT_SHORT_REQUESTS, + ) + return TypeAdapter(list[TaskBase]).validate_python(result) + + +@log_decorator(_logger, level=logging.DEBUG) +async def get_task_status( + rabbitmq_rpc_client: RabbitMQRPCClient, + namespace: LRTNamespace, + *, + task_context: TaskContext, + task_id: TaskId, +) -> TaskStatus: + result = await rabbitmq_rpc_client.request( + get_rabbit_namespace(namespace), + TypeAdapter(RPCMethodName).validate_python("get_task_status"), + task_context=task_context, + task_id=task_id, + timeout_s=_RPC_TIMEOUT_SHORT_REQUESTS, + ) + assert isinstance(result, TaskStatus) # nosec + return result + + 
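Taken together, these thin RPC wrappers (with `get_task_result` and `remove_task`, just below) give callers a start/poll/result flow. A hedged usage sketch, assuming an already-connected `RabbitMQRPCClient` and a task registered under the illustrative name `"my_task"`:

```python
import asyncio

from servicelib.long_running_tasks import _rpc_client
from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient


async def run_to_completion(rpc_client: RabbitMQRPCClient, lrt_namespace: str):
    task_id = await _rpc_client.start_task(
        rpc_client, lrt_namespace, registered_task_name="my_task"
    )
    # poll until the TaskStatus reports done, then fetch the result
    # (fetching also removes the task server-side)
    while not (
        await _rpc_client.get_task_status(
            rpc_client, lrt_namespace, task_context={}, task_id=task_id
        )
    ).done:
        await asyncio.sleep(0.5)
    return await _rpc_client.get_task_result(
        rpc_client, lrt_namespace, task_context={}, task_id=task_id
    )
```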
+@log_decorator(_logger, level=logging.DEBUG) +async def get_task_result( + rabbitmq_rpc_client: RabbitMQRPCClient, + namespace: LRTNamespace, + *, + task_context: TaskContext, + task_id: TaskId, +) -> Any: + try: + serialized_result = await rabbitmq_rpc_client.request( + get_rabbit_namespace(namespace), + TypeAdapter(RPCMethodName).validate_python("get_task_result"), + task_context=task_context, + task_id=task_id, + timeout_s=_RPC_TIMEOUT_SHORT_REQUESTS, + ) + assert isinstance(serialized_result, str) # nosec + return loads(serialized_result) + except RPCTransferrableTaskError as e: + decoded_error = loads(f"{e}") + raise decoded_error from e + + +@log_decorator(_logger, level=logging.DEBUG) +async def remove_task( + rabbitmq_rpc_client: RabbitMQRPCClient, + namespace: LRTNamespace, + *, + task_context: TaskContext, + task_id: TaskId, + wait_for_removal: bool, + cancellation_timeout: timedelta | None, +) -> None: + timeout_s = ( + None + if cancellation_timeout is None + else int(cancellation_timeout.total_seconds()) + ) + + # NOTE: the task gets cancelled even when not waiting for it; + # the request then returns immediately, so a short timeout suffices + if wait_for_removal is False: + timeout_s = _RPC_TIMEOUT_SHORT_REQUESTS + + result = await rabbitmq_rpc_client.request( + get_rabbit_namespace(namespace), + TypeAdapter(RPCMethodName).validate_python("remove_task"), + task_context=task_context, + task_id=task_id, + wait_for_removal=wait_for_removal, + timeout_s=timeout_s, + ) + assert result is None # nosec diff --git a/packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py b/packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py new file mode 100644 index 000000000000..d63c5d370ce4 --- /dev/null +++ b/packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py @@ -0,0 +1,107 @@ +import logging +from contextlib import suppress +from typing import TYPE_CHECKING, Any + +from ..rabbitmq import RPCRouter +from .errors import BaseLongRunningError, RPCTransferrableTaskError, TaskNotFoundError +from .models import ( + RegisteredTaskName, + TaskBase, + TaskContext, + TaskId, + TaskStatus, +) + +_logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from .base_long_running_manager import BaseLongRunningManager + + +router = RPCRouter() + + +@router.expose(reraise_if_error_type=(BaseLongRunningError,)) +async def start_task( + long_running_manager: "BaseLongRunningManager", + *, + registered_task_name: RegisteredTaskName, + unique: bool = False, + task_context: TaskContext | None = None, + task_name: str | None = None, + fire_and_forget: bool = False, + **task_kwargs: Any, +) -> TaskId: + return await long_running_manager.tasks_manager.start_task( + registered_task_name, + unique=unique, + task_context=task_context, + task_name=task_name, + fire_and_forget=fire_and_forget, + **task_kwargs, + ) + + +@router.expose(reraise_if_error_type=(BaseLongRunningError,)) +async def list_tasks( + long_running_manager: "BaseLongRunningManager", *, task_context: TaskContext +) -> list[TaskBase]: + return await long_running_manager.tasks_manager.list_tasks( + with_task_context=task_context + ) + + +@router.expose(reraise_if_error_type=(BaseLongRunningError,)) +async def get_task_status( + long_running_manager: "BaseLongRunningManager", + *, + task_context: TaskContext, + task_id: TaskId, +) -> TaskStatus: + return await long_running_manager.tasks_manager.get_task_status( + task_id=task_id, with_task_context=task_context + ) + 
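Both halves of the result path lean on the `_serialization` helpers: the server-side `get_task_result` exposed just below encodes the task's result or exception as a string, while the client-side counterpart above decodes it with `loads` and re-raises. A minimal round-trip sketch (values are illustrative; `ValueError` is not a registered custom type, so `loads` simply returns the unpickled instance):

```python
from servicelib.long_running_tasks._serialization import dumps, loads

try:
    raise ValueError("boom")
except ValueError as err:
    wire = dumps(err)  # base64-encoded pickle, safe to ship inside an RPC error message
    decoded = loads(wire)  # reconstructs the original exception instance
    assert isinstance(decoded, ValueError) and str(decoded) == "boom"
```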
+@router.expose(reraise_if_error_type=(BaseLongRunningError, RPCTransferrableTaskError)) +async def get_task_result( + long_running_manager: "BaseLongRunningManager", + *, + task_context: TaskContext, + task_id: TaskId, +) -> str: + try: + result_field = await long_running_manager.tasks_manager.get_task_result( + task_id, with_task_context=task_context + ) + if result_field.str_error is not None: + raise RPCTransferrableTaskError(result_field.str_error) + + if result_field.str_result is not None: + return result_field.str_result + + msg = f"Please check {result_field=}, both fields should never be None" + raise ValueError(msg) + finally: + # Ensure the task is removed regardless of the result + with suppress(TaskNotFoundError): + await long_running_manager.tasks_manager.remove_task( + task_id, + with_task_context=task_context, + wait_for_removal=True, + ) + + +@router.expose(reraise_if_error_type=(BaseLongRunningError,)) +async def remove_task( + long_running_manager: "BaseLongRunningManager", + *, + task_context: TaskContext, + task_id: TaskId, + wait_for_removal: bool, +) -> None: + await long_running_manager.tasks_manager.remove_task( + task_id, + with_task_context=task_context, + wait_for_removal=wait_for_removal, + ) diff --git a/packages/service-library/src/servicelib/long_running_tasks/_redis_serialization.py b/packages/service-library/src/servicelib/long_running_tasks/_serialization.py similarity index 85% rename from packages/service-library/src/servicelib/long_running_tasks/_redis_serialization.py rename to packages/service-library/src/servicelib/long_running_tasks/_serialization.py index ae7125147a1f..472b7d80f840 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/_redis_serialization.py +++ b/packages/service-library/src/servicelib/long_running_tasks/_serialization.py @@ -1,12 +1,8 @@ import base64 -import logging import pickle from abc import ABC, abstractmethod from typing import Any, Final, Generic, TypeVar -_logger = logging.getLogger(__name__) - - T = TypeVar("T") @@ -42,26 +38,26 @@ def register_custom_serialization( _MODULE_FIELD: Final[str] = "__pickle__module__field__" -def object_to_string(e: Any) -> str: +def dumps(obj: Any) -> str: """Serialize object to base64-encoded string.""" - to_serialize: Any | dict = e - object_class = type(e) + to_serialize: Any | dict = obj + object_class = type(obj) for registered_class, object_serializer in _SERIALIZERS.items(): if issubclass(object_class, registered_class): to_serialize = { - _TYPE_FIELD: type(e).__name__, - _MODULE_FIELD: type(e).__module__, - **object_serializer.get_init_kwargs_from_object(e), + _TYPE_FIELD: type(obj).__name__, + _MODULE_FIELD: type(obj).__module__, + **object_serializer.get_init_kwargs_from_object(obj), } break return base64.b85encode(pickle.dumps(to_serialize)).decode("utf-8") -def string_to_object(error_str: str) -> Any: +def loads(obj_str: str) -> Any: """Deserialize object from base64-encoded string.""" - data = pickle.loads(base64.b85decode(error_str)) # noqa: S301 + data = pickle.loads(base64.b85decode(obj_str)) # noqa: S301 if isinstance(data, dict) and _TYPE_FIELD in data and _MODULE_FIELD in data: try: @@ -75,7 +71,7 @@ def string_to_object(error_str: str) -> Any: data.pop(_TYPE_FIELD) data.pop(_MODULE_FIELD) - return exception_class( + raise exception_class( **object_serializer.prepare_object_init_kwargs(data) ) except (ImportError, AttributeError, TypeError) as e: diff --git a/packages/service-library/src/servicelib/long_running_tasks/_store/base.py 
b/packages/service-library/src/servicelib/long_running_tasks/_store/base.py deleted file mode 100644 index 20829c01ca65..000000000000 --- a/packages/service-library/src/servicelib/long_running_tasks/_store/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import abstractmethod - -from ..models import TaskContext, TaskData, TaskId - - -class BaseStore: - - @abstractmethod - async def get_task_data(self, task_id: TaskId) -> TaskData | None: - """Retrieve a tracked task""" - - @abstractmethod - async def set_task_data(self, task_id: TaskId, value: TaskData) -> None: - """Set a tracked task's data""" - - @abstractmethod - async def list_tasks_data(self) -> list[TaskData]: - """List all tracked tasks.""" - - @abstractmethod - async def delete_task_data(self, task_id: TaskId) -> None: - """Delete a tracked task.""" - - @abstractmethod - async def set_as_cancelled( - self, task_id: TaskId, with_task_context: TaskContext - ) -> None: - """Mark a tracked task as cancelled.""" - - @abstractmethod - async def get_cancelled(self) -> dict[TaskId, TaskContext]: - """Get cancelled tasks.""" - - @abstractmethod - async def setup(self) -> None: - """Setup the store, if needed.""" - - @abstractmethod - async def shutdown(self) -> None: - """Shutdown the store, if needed.""" diff --git a/packages/service-library/src/servicelib/long_running_tasks/_store/redis.py b/packages/service-library/src/servicelib/long_running_tasks/_store/redis.py deleted file mode 100644 index 3ac1314e11c9..000000000000 --- a/packages/service-library/src/servicelib/long_running_tasks/_store/redis.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Any, Final - -import redis.asyncio as aioredis -from common_library.json_serialization import json_dumps, json_loads -from pydantic import TypeAdapter -from settings_library.redis import RedisDatabase, RedisSettings - -from ...redis._client import RedisClientSDK -from ...redis._utils import handle_redis_returns_union_types -from ..models import TaskContext, TaskData, TaskId -from .base import BaseStore - -_STORE_TYPE_TASK_DATA: Final[str] = "TD" -_STORE_TYPE_CANCELLED_TASKS: Final[str] = "CT" - - -class RedisStore(BaseStore): - def __init__(self, redis_settings: RedisSettings, namespace: str): - self.redis_settings = redis_settings - self.namespace = namespace.upper() - - self._client: RedisClientSDK | None = None - - async def setup(self) -> None: - self._client = RedisClientSDK( - self.redis_settings.build_redis_dsn(RedisDatabase.LONG_RUNNING_TASKS), - client_name=f"long_running_tasks_store_{self.namespace}", - ) - await self._client.setup() - - async def shutdown(self) -> None: - if self._client: - await self._client.shutdown() - - @property - def _redis(self) -> aioredis.Redis: - assert self._client # nosec - return self._client.redis - - def _get_redis_hash_key(self, store_type: str) -> str: - return f"{self.namespace}:{store_type}" - - def _get_key(self, store_type: str, name: str) -> str: - return f"{self.namespace}:{store_type}:{name}" - - async def get_task_data(self, task_id: TaskId) -> TaskData | None: - result: Any | None = await handle_redis_returns_union_types( - self._redis.hget(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), task_id) - ) - return TypeAdapter(TaskData).validate_json(result) if result else None - - async def set_task_data(self, task_id: TaskId, value: TaskData) -> None: - await handle_redis_returns_union_types( - self._redis.hset( - self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), - task_id, - value.model_dump_json(), - ) - ) - - async def 
list_tasks_data(self) -> list[TaskData]: - result: list[Any] = await handle_redis_returns_union_types( - self._redis.hvals(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA)) - ) - return [TypeAdapter(TaskData).validate_json(item) for item in result] - - async def delete_task_data(self, task_id: TaskId) -> None: - await handle_redis_returns_union_types( - self._redis.hdel(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), task_id) - ) - - async def set_as_cancelled( - self, task_id: TaskId, with_task_context: TaskContext - ) -> None: - await handle_redis_returns_union_types( - self._redis.hset( - self._get_redis_hash_key(_STORE_TYPE_CANCELLED_TASKS), - task_id, - json_dumps(with_task_context), - ) - ) - - async def get_cancelled(self) -> dict[TaskId, TaskContext]: - result: dict[str, str | None] = await handle_redis_returns_union_types( - self._redis.hgetall(self._get_redis_hash_key(_STORE_TYPE_CANCELLED_TASKS)) - ) - return {task_id: json_loads(context) for task_id, context in result.items()} diff --git a/packages/service-library/src/servicelib/long_running_tasks/base_long_running_manager.py b/packages/service-library/src/servicelib/long_running_tasks/base_long_running_manager.py index d09428d3aa22..2090df4e1adb 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/base_long_running_manager.py +++ b/packages/service-library/src/servicelib/long_running_tasks/base_long_running_manager.py @@ -1,5 +1,13 @@ +import datetime from abc import ABC, abstractmethod +from settings_library.rabbit import RabbitSettings +from settings_library.redis import RedisSettings + +from ..rabbitmq._client_rpc import RabbitMQRPCClient +from ._rabbit_namespace import get_rabbit_namespace +from ._rpc_server import router +from .models import LRTNamespace, TaskContext from .task import TasksManager @@ -8,15 +16,72 @@ class BaseLongRunningManager(ABC): Provides a common interface for aiohttp and fastapi services """ + def __init__( + self, + stale_task_check_interval: datetime.timedelta, + stale_task_detect_timeout: datetime.timedelta, + redis_settings: RedisSettings, + rabbit_settings: RabbitSettings, + lrt_namespace: LRTNamespace, + ): + self._tasks_manager = TasksManager( + stale_task_check_interval=stale_task_check_interval, + stale_task_detect_timeout=stale_task_detect_timeout, + redis_settings=redis_settings, + lrt_namespace=lrt_namespace, + ) + self._lrt_namespace = lrt_namespace + self.rabbit_settings = rabbit_settings + self._rpc_server: RabbitMQRPCClient | None = None + self._rpc_client: RabbitMQRPCClient | None = None + @property - @abstractmethod def tasks_manager(self) -> TasksManager: - pass + return self._tasks_manager + + @property + def rpc_server(self) -> RabbitMQRPCClient: + assert self._rpc_server is not None # nosec + return self._rpc_server + + @property + def rpc_client(self) -> RabbitMQRPCClient: + assert self._rpc_client is not None # nosec + return self._rpc_client + + @property + def lrt_namespace(self) -> LRTNamespace: + return self._lrt_namespace - @abstractmethod async def setup(self) -> None: - pass + await self._tasks_manager.setup() + self._rpc_server = await RabbitMQRPCClient.create( + client_name=f"lrt-server-{self.lrt_namespace}", + settings=self.rabbit_settings, + ) + self._rpc_client = await RabbitMQRPCClient.create( + client_name=f"lrt-client-{self.lrt_namespace}", + settings=self.rabbit_settings, + ) + + await self.rpc_server.register_router( + router, + get_rabbit_namespace(self.lrt_namespace), + self, + ) - @abstractmethod async def teardown(self) -> None: - pass
+ await self._tasks_manager.teardown() + + if self._rpc_server is not None: + await self._rpc_server.close() + self._rpc_server = None + + if self._rpc_client is not None: + await self._rpc_client.close() + self._rpc_client = None + + @staticmethod + @abstractmethod + def get_task_context(request) -> TaskContext: + """return the task context based on the current request""" diff --git a/packages/service-library/src/servicelib/long_running_tasks/errors.py b/packages/service-library/src/servicelib/long_running_tasks/errors.py index 75e46da5b0c2..c95a228720b7 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/errors.py +++ b/packages/service-library/src/servicelib/long_running_tasks/errors.py @@ -7,7 +7,7 @@ class BaseLongRunningError(OsparcErrorMixin, Exception): class TaskNotRegisteredError(BaseLongRunningError): msg_template: str = ( - "no task with task_name='{task_name}' was found in the task registry. " + "no task with task_name='{task_name}' was found in the task registry tasks={tasks}. " "Make sure it's registered before starting it." ) @@ -44,3 +44,10 @@ class GenericClientError(BaseLongRunningError): msg_template: str = ( "Unexpected error while '{action}' for '{task_id}': status={status} body={body}" ) + + +class RPCTransferrableTaskError(Exception): + """ + The message contains the task's exception serialized as string. + Decode it and raise to obtain the task's original exception. + """ diff --git a/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py b/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py index 6f732d49e49c..73fdebb4cfa9 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py +++ b/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py @@ -1,18 +1,21 @@ -import logging +from datetime import timedelta from typing import Any -from common_library.error_codes import create_error_code - -from ..logging_errors import create_troubleshootting_log_kwargs -from .errors import TaskNotCompletedError, TaskNotFoundError -from .models import TaskBase, TaskContext, TaskId, TaskStatus -from .task import RegisteredTaskName, TasksManager - -_logger = logging.getLogger(__name__) +from ..rabbitmq._client_rpc import RabbitMQRPCClient +from . 
import _rpc_client +from .models import ( + LRTNamespace, + RegisteredTaskName, + TaskBase, + TaskContext, + TaskId, + TaskStatus, +) async def start_task( - tasks_manager: TasksManager, + rabbitmq_rpc_client: RabbitMQRPCClient, + lrt_namespace: LRTNamespace, registered_task_name: RegisteredTaskName, *, unique: bool = False, @@ -46,8 +49,11 @@ async def start_task( Returns: TaskId: the task unique identifier """ - return await tasks_manager.start_task( - registered_task_name, + + return await _rpc_client.start_task( + rabbitmq_rpc_client, + lrt_namespace, + registered_task_name=registered_task_name, unique=unique, task_context=task_context, task_name=task_name, @@ -57,51 +63,59 @@ async def list_tasks( - tasks_manager: TasksManager, task_context: TaskContext + rabbitmq_rpc_client: RabbitMQRPCClient, + lrt_namespace: LRTNamespace, + task_context: TaskContext, ) -> list[TaskBase]: - return await tasks_manager.list_tasks(with_task_context=task_context) + return await _rpc_client.list_tasks( + rabbitmq_rpc_client, lrt_namespace, task_context=task_context + ) async def get_task_status( - tasks_manager: TasksManager, task_context: TaskContext, task_id: TaskId + rabbitmq_rpc_client: RabbitMQRPCClient, + lrt_namespace: LRTNamespace, + task_context: TaskContext, + task_id: TaskId, ) -> TaskStatus: """returns the status of a task""" - return await tasks_manager.get_task_status( - task_id=task_id, with_task_context=task_context + return await _rpc_client.get_task_status( + rabbitmq_rpc_client, lrt_namespace, task_id=task_id, task_context=task_context ) async def get_task_result( - tasks_manager: TasksManager, task_context: TaskContext, task_id: TaskId + rabbitmq_rpc_client: RabbitMQRPCClient, + lrt_namespace: LRTNamespace, + task_context: TaskContext, + task_id: TaskId, ) -> Any: - try: - task_result = await tasks_manager.get_task_result( - task_id, with_task_context=task_context - ) - await tasks_manager.remove_task( - task_id, with_task_context=task_context, reraise_errors=False - ) - return task_result - except (TaskNotFoundError, TaskNotCompletedError): - raise - except Exception as exc: - _logger.exception( - **create_troubleshootting_log_kwargs( - user_error_msg=f"{task_id=} raised an exception while getting its result", - error=exc, - error_code=create_error_code(exc), - error_context={"task_context": task_context, "task_id": task_id}, - ), - ) - # the task shall be removed in this case - await tasks_manager.remove_task( - task_id, with_task_context=task_context, reraise_errors=False - ) - raise + return await _rpc_client.get_task_result( + rabbitmq_rpc_client, + lrt_namespace, + task_context=task_context, + task_id=task_id, + ) async def remove_task( - tasks_manager: TasksManager, task_context: TaskContext, task_id: TaskId + rabbitmq_rpc_client: RabbitMQRPCClient, + lrt_namespace: LRTNamespace, + task_context: TaskContext, + task_id: TaskId, + *, + wait_for_removal: bool, + cancellation_timeout: timedelta | None = None, ) -> None: - """cancels and removes the task""" - await tasks_manager.remove_task(task_id, with_task_context=task_context) + """cancels and removes a task + + When `wait_for_removal` is False, the request returns immediately and the RPC + timeout is set to _RPC_TIMEOUT_SHORT_REQUESTS regardless of `cancellation_timeout` + """ + await _rpc_client.remove_task( + rabbitmq_rpc_client, + lrt_namespace, + task_id=task_id, + task_context=task_context, + wait_for_removal=wait_for_removal, + cancellation_timeout=cancellation_timeout, + ) diff --git a/packages/service-library/src/servicelib/long_running_tasks/models.py
b/packages/service-library/src/servicelib/long_running_tasks/models.py index a8c626714c1b..7a99f9a5cf34 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/models.py +++ b/packages/service-library/src/servicelib/long_running_tasks/models.py @@ -28,20 +28,25 @@ RequestBody: TypeAlias = Any TaskContext: TypeAlias = dict[str, Any] +LRTNamespace: TypeAlias = str + +RegisteredTaskName: TypeAlias = str + class ResultField(BaseModel): - result: str | None = None - error: str | None = None + str_result: str | None = None + str_error: str | None = None @model_validator(mode="after") def validate_mutually_exclusive(self) -> "ResultField": - if self.result is not None and self.error is not None: + if self.str_result is not None and self.str_error is not None: msg = "Cannot set both 'result' and 'error' - they are mutually exclusive" raise ValueError(msg) return self class TaskData(BaseModel): + registered_task_name: RegisteredTaskName task_id: str task_progress: TaskProgress # NOTE: this context lifetime is with the tracked task (similar to aiohttp storage concept) @@ -79,6 +84,7 @@ class TaskData(BaseModel): json_schema_extra={ "examples": [ { + "registered_task_name": "a-task-name", "task_id": "1a119618-7186-4bc1-b8de-7e3ff314cb7e", "task_name": "running-task", "task_status": "running", diff --git a/packages/service-library/src/servicelib/long_running_tasks/task.py b/packages/service-library/src/servicelib/long_running_tasks/task.py index ca22d9ff6af3..e76b486954cc 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/task.py +++ b/packages/service-library/src/servicelib/long_running_tasks/task.py @@ -4,6 +4,7 @@ import inspect import logging import urllib.parse +from contextlib import suppress from typing import Any, ClassVar, Final, Protocol, TypeAlias from uuid import uuid4 @@ -13,17 +14,17 @@ from settings_library.redis import RedisDatabase, RedisSettings from tenacity import ( AsyncRetrying, - TryAgain, - retry_if_exception_type, + retry_unless_exception_type, stop_after_delay, wait_exponential, ) from ..background_task import create_periodic_task +from ..logging_errors import create_troubleshootting_log_kwargs +from ..logging_utils import log_context from ..redis import RedisClientSDK, exclusive -from ._redis_serialization import object_to_string, string_to_object -from ._store.base import BaseStore -from ._store.redis import RedisStore +from ._redis_store import RedisStore +from ._serialization import dumps from .errors import ( TaskAlreadyRunningError, TaskCancelledError, @@ -31,19 +32,26 @@ TaskNotFoundError, TaskNotRegisteredError, ) -from .models import ResultField, TaskBase, TaskContext, TaskData, TaskId, TaskStatus +from .models import ( + LRTNamespace, + RegisteredTaskName, + ResultField, + TaskBase, + TaskContext, + TaskData, + TaskId, + TaskStatus, +) _logger = logging.getLogger(__name__) _CANCEL_TASKS_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5) _STATUS_UPDATE_CHECK_INTERNAL: Final[datetime.timedelta] = datetime.timedelta(seconds=1) - +_MAX_EXCLUSIVE_TASK_CANCEL_TIMEOUT: Final[NonNegativeFloat] = 5 _TASK_REMOVAL_MAX_WAIT: Final[NonNegativeFloat] = 60 - -RegisteredTaskName: TypeAlias = str -RedisNamespace: TypeAlias = str +AllowedErrrors: TypeAlias = tuple[type[BaseException], ...] class TaskProtocol(Protocol): @@ -56,20 +64,44 @@ def __name__(self) -> str: ... 
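The reworked `TaskRegistry` below stores an `(allowed_errors, partial)` pair per task, so expected failures are not logged as unexpected and fixed kwargs can be baked in at registration time. A hedged sketch of how a task might be registered (task name, error type, and kwargs are illustrative; the signature follows `TaskProtocol`):

```python
from servicelib.long_running_tasks.task import TaskRegistry


class ExpectedError(Exception): ...


async def my_task(progress, *, level: int) -> str:  # first arg receives the TaskProgress
    return f"compressed with {level=}"


# `level` is frozen via functools.partial; raising ExpectedError inside the
# task will not be reported as an unexpected failure by _tasks_monitor
TaskRegistry.register(my_task, allowed_errors=(ExpectedError,), level=9)
assert TaskRegistry.get_task("my_task").__name__ == "my_task"
```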
class TaskRegistry: - REGISTERED_TASKS: ClassVar[dict[RegisteredTaskName, TaskProtocol]] = {} + _REGISTERED_TASKS: ClassVar[ + dict[RegisteredTaskName, tuple[AllowedErrrors, TaskProtocol]] + ] = {} + + @classmethod + def register( + cls, + task: TaskProtocol, + allowed_errors: AllowedErrrors = (), + **partial_kwargs, + ) -> None: + partial_task = functools.partial(task, **partial_kwargs) + # allows calling the partial via its original name + partial_task.__name__ = task.__name__ # type: ignore[attr-defined] + cls._REGISTERED_TASKS[task.__name__] = (allowed_errors, partial_task) # type: ignore[assignment] + + @classmethod + def get_registered_tasks( + cls, + ) -> dict[RegisteredTaskName, tuple[AllowedErrrors, TaskProtocol]]: + return cls._REGISTERED_TASKS + + @classmethod + def get_task(cls, task_name: RegisteredTaskName) -> TaskProtocol: + return cls._REGISTERED_TASKS[task_name][1] @classmethod - def register(cls, task: TaskProtocol) -> None: - cls.REGISTERED_TASKS[task.__name__] = task + def get_allowed_errors(cls, task_name: RegisteredTaskName) -> AllowedErrrors: + return cls._REGISTERED_TASKS[task_name][0] @classmethod def unregister(cls, task: TaskProtocol) -> None: - if task.__name__ in cls.REGISTERED_TASKS: - del cls.REGISTERED_TASKS[task.__name__] + if task.__name__ in cls._REGISTERED_TASKS: + del cls._REGISTERED_TASKS[task.__name__] async def _get_tasks_to_remove( - tracked_tasks: BaseStore, + tracked_tasks: RedisStore, stale_task_detect_timeout_s: PositiveFloat, ) -> list[tuple[TaskId, TaskContext]]: utc_now = datetime.datetime.now(tz=datetime.UTC) @@ -82,14 +114,16 @@ async def _get_tasks_to_remove( if tracked_task.last_status_check is None: # the task just added or never received a poll request - elapsed_from_start = (utc_now - tracked_task.started).seconds + elapsed_from_start = (utc_now - tracked_task.started).total_seconds() if elapsed_from_start > stale_task_detect_timeout_s: tasks_to_remove.append( (tracked_task.task_id, tracked_task.task_context) ) else: # the task status was already queried by the client - elapsed_from_last_poll = (utc_now - tracked_task.last_status_check).seconds + elapsed_from_last_poll = ( + utc_now - tracked_task.last_status_check + ).total_seconds() if elapsed_from_last_poll > stale_task_detect_timeout_s: tasks_to_remove.append( (tracked_task.task_id, tracked_task.task_context) @@ -107,17 +141,17 @@ def __init__( redis_settings: RedisSettings, stale_task_check_interval: datetime.timedelta, stale_task_detect_timeout: datetime.timedelta, - redis_namespace: RedisNamespace, + lrt_namespace: LRTNamespace, ): # Task groups: Every taskname maps to multiple asyncio.Task within TrackedTask model - self._tasks_data: BaseStore = RedisStore(redis_settings, redis_namespace) + self._tasks_data = RedisStore(redis_settings, lrt_namespace) self._created_tasks: dict[TaskId, asyncio.Task] = {} self.stale_task_check_interval = stale_task_check_interval self.stale_task_detect_timeout_s: PositiveFloat = ( stale_task_detect_timeout.total_seconds() ) - self.redis_namespace = redis_namespace + self.lrt_namespace = lrt_namespace self.redis_settings = redis_settings self.locks_redis_client_sdk: RedisClientSDK | None = None @@ -130,16 +164,16 @@ def __init__( self._task_cancelled_tasks_removal: asyncio.Task | None = None self._started_event_task_cancelled_tasks_removal = asyncio.Event() - # status_update - self._task_status_update: asyncio.Task | None = None - self._started_event_task_status_update = asyncio.Event() + # tasks_monitor + self._task_tasks_monitor: asyncio.Task |

     async def setup(self) -> None:
         await self._tasks_data.setup()

         self.locks_redis_client_sdk = RedisClientSDK(
             self.redis_settings.build_redis_dsn(RedisDatabase.LOCKS),
-            client_name=f"long_running_tasks_store_{self.redis_namespace}_lock",
+            client_name=f"{__name__}_{self.lrt_namespace}_lock",
         )
         await self.locks_redis_client_sdk.setup()

@@ -147,7 +181,7 @@ async def setup(self) -> None:
         self._task_stale_tasks_monitor = create_periodic_task(
             task=exclusive(
                 self.locks_redis_client_sdk,
-                lock_key=f"{__name__}_{self.redis_namespace}_stale_tasks_monitor",
+                lock_key=f"{__name__}_{self.lrt_namespace}_stale_tasks_monitor",
             )(self._stale_tasks_monitor),
             interval=self.stale_task_check_interval,
             task_name=f"{__name__}.{self._stale_tasks_monitor.__name__}",
@@ -162,23 +196,23 @@ async def setup(self) -> None:
         )
         await self._started_event_task_cancelled_tasks_removal.wait()

-        # status_update
-        self._task_status_update = create_periodic_task(
-            task=self._status_update,
+        # tasks_monitor
+        self._task_tasks_monitor = create_periodic_task(
+            task=self._tasks_monitor,
             interval=_STATUS_UPDATE_CHECK_INTERNAL,
-            task_name=f"{__name__}.{self._status_update.__name__}",
+            task_name=f"{__name__}.{self._tasks_monitor.__name__}",
         )
-        await self._started_event_task_status_update.wait()
+        await self._started_event_task_tasks_monitor.wait()

     async def teardown(self) -> None:
         # ensure all created tasks are cancelled
         for tracked_task in await self._tasks_data.list_tasks_data():
-            await self.remove_task(
-                tracked_task.task_id,
-                tracked_task.task_context,
-                # when closing we do not care about pending errors
-                reraise_errors=False,
-            )
+            with suppress(TaskNotFoundError):
+                await self.remove_task(
+                    tracked_task.task_id,
+                    tracked_task.task_context,
+                    wait_for_removal=True,
+                )

         for task in self._created_tasks.values():
             _logger.warning(
@@ -189,15 +223,18 @@ async def teardown(self) -> None:

         # stale_tasks_monitor
         if self._task_stale_tasks_monitor:
-            await cancel_wait_task(self._task_stale_tasks_monitor)
+            await cancel_wait_task(
+                self._task_stale_tasks_monitor,
+                max_delay=_MAX_EXCLUSIVE_TASK_CANCEL_TIMEOUT,
+            )

         # cancelled_tasks_removal
         if self._task_cancelled_tasks_removal:
             await cancel_wait_task(self._task_cancelled_tasks_removal)

-        # status_update
-        if self._task_status_update:
-            await cancel_wait_task(self._task_status_update)
+        # tasks_monitor
+        if self._task_tasks_monitor:
+            await cancel_wait_task(self._task_tasks_monitor)

         if self.locks_redis_client_sdk is not None:
             await self.locks_redis_client_sdk.shutdown()
@@ -234,36 +271,36 @@ async def _stale_tasks_monitor(self) -> None:
             #  - finished with a result
             #  - finished with errors
             # we just print the status from where one can infer the above
-            _logger.warning(
-                "Removing stale task '%s' with status '%s'",
-                task_id,
-                (
-                    await self.get_task_status(task_id, with_task_context=task_context)
-                ).model_dump_json(),
-            )
-            await self.remove_task(
-                task_id, with_task_context=task_context, reraise_errors=False
-            )
+            with suppress(TaskNotFoundError):
+                task_status = await self.get_task_status(
+                    task_id, with_task_context=task_context
+                )
+                with log_context(
+                    _logger,
+                    logging.WARNING,
+                    f"Removing stale task '{task_id}' with status '{task_status.model_dump_json()}'",
+                ):
+                    await self.remove_task(
+                        task_id, with_task_context=task_context, wait_for_removal=True
+                    )
     async def _cancelled_tasks_removal(self) -> None:
         """
-        A task can be cancelled by the client, which implies it does not for sure
-        run in the same process as the one processing the request.
-
-        This is a periodic task that ensures the cancellation occurred.
+        Periodically checks which tasks are marked for removal and attempts to remove the
+        task if it's handled by this process.
         """
         self._started_event_task_cancelled_tasks_removal.set()

-        cancelled_tasks = await self._tasks_data.get_cancelled()
-        for task_id, task_context in cancelled_tasks.items():
-            await self.remove_task(task_id, task_context, reraise_errors=False)
+        to_remove = await self._tasks_data.list_tasks_to_remove()
+        for task_id in to_remove:
+            await self._attempt_to_remove_local_task(task_id)

-    async def _status_update(self) -> None:
+    async def _tasks_monitor(self) -> None:
         """
         A task which monitors locally running tasks and updates their status
         in the Redis store when they are done.
         """
-        self._started_event_task_status_update.set()
+        self._started_event_task_tasks_monitor.set()
         task_id: TaskId
         for task_id in set(self._created_tasks.keys()):
             if task := self._created_tasks.get(task_id, None):
@@ -278,25 +315,43 @@
                 # already done and updated data in redis
                 continue

-            # update and store in Redis
-            task_data.is_done = is_done
-
+            result_field: ResultField | None = None
             # get task result
             try:
-                task_data.result_field = ResultField(
-                    result=object_to_string(task.result())
-                )
+                result_field = ResultField(str_result=dumps(task.result()))
             except asyncio.InvalidStateError:
-                # task was not completed try again next time and see if it is done
+                # task not yet completed, try again next time and see if it is done
                 continue
             except asyncio.CancelledError:
-                task_data.result_field = ResultField(
-                    error=object_to_string(TaskCancelledError(task_id=task_id))
+                result_field = ResultField(
+                    str_error=dumps(TaskCancelledError(task_id=task_id))
                 )
             except Exception as e:  # pylint:disable=broad-except
-                task_data.result_field = ResultField(error=object_to_string(e))
+                allowed_errors = TaskRegistry.get_allowed_errors(
+                    task_data.registered_task_name
+                )
+                if type(e) not in allowed_errors:
+                    _logger.exception(
+                        **create_troubleshootting_log_kwargs(
+                            (
+                                f"Execution of {task_id=} finished with unexpected error, "
+                                f"only the following {allowed_errors=} are permitted"
+                            ),
+                            error=e,
+                            error_context={
+                                "task_id": task_id,
+                                "task_data": task_data,
+                                "namespace": self.lrt_namespace,
+                            },
+                        ),
+                    )
+                result_field = ResultField(str_error=dumps(e))

-            await self._tasks_data.set_task_data(task_id, task_data)
+            # update and store in Redis
+            updates = {"is_done": is_done}
+            if result_field is not None:
+                updates["result_field"] = result_field
+            await self._tasks_data.update_task_data(task_id, updates=updates)

     async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBase]:
         if not with_task_context:
@@ -333,10 +388,12 @@

         raises TaskNotFoundError if the task cannot be found
         """
-        task_data: TaskData = await self._get_tracked_task(task_id, with_task_context)
-        task_data.last_status_check = datetime.datetime.now(tz=datetime.UTC)
-        await self._tasks_data.set_task_data(task_id, task_data)
+        task_data = await self._get_tracked_task(task_id, with_task_context)
+        await self._tasks_data.update_task_data(
+            task_id,
+            updates={"last_status_check": datetime.datetime.now(tz=datetime.UTC)},
+        )
         return TaskStatus.model_validate(
             {
                 "task_progress": task_data.task_progress,
@@ -345,14 +402,24 @@
             }
         )
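The monitor above stores results and errors as strings via `dumps(...)` from `servicelib.long_running_tasks._serialization`, and consumers rebuild them with `loads(...)`; the renamed serialization test further down in this diff exercises exactly this. A rough sketch of the round trip, assuming only the `dumps`/`loads` semantics visible here (`MyError` is illustrative):

    from servicelib.long_running_tasks._serialization import dumps, loads

    class MyError(Exception):
        ...

    # a plain result should survive the string round trip
    assert loads(dumps(42)) == 42

    # an error is stored serialized; the reader rebuilds it and may re-raise it
    rebuilt = loads(dumps(MyError("boom")))
    assert isinstance(rebuilt, MyError)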
+    async def get_allowed_errors(
+        self, task_id: TaskId, with_task_context: TaskContext
+    ) -> AllowedErrors:
+        """
+        returns: the allowed errors for the task
+
+        raises TaskNotFoundError if the task cannot be found
+        """
+        task_data = await self._get_tracked_task(task_id, with_task_context)
+        return TaskRegistry.get_allowed_errors(task_data.registered_task_name)
+
     async def get_task_result(
         self, task_id: TaskId, with_task_context: TaskContext
-    ) -> Any:
+    ) -> ResultField:
         """
-        returns: the result of the task
+        returns: the result of the task wrapped in ResultField

         raises TaskNotFoundError if the task cannot be found
-        raises TaskCancelledError if the task was cancelled
         raises TaskNotCompletedError if the task is not completed
         """
         tracked_task = await self._get_tracked_task(task_id, with_task_context)
@@ -360,70 +427,53 @@
         if not tracked_task.is_done or tracked_task.result_field is None:
             raise TaskNotCompletedError(task_id=task_id)

-        if tracked_task.result_field.error is not None:
-            raise string_to_object(tracked_task.result_field.error)
-
-        if tracked_task.result_field.result is None:
-            return None
+        return tracked_task.result_field

-        return string_to_object(tracked_task.result_field.result)
+    async def _attempt_to_remove_local_task(self, task_id: TaskId) -> None:
+        """if the task is running in the local process, try to remove it"""

-    async def _cancel_tracked_task(
-        self, task: asyncio.Task, task_id: TaskId, with_task_context: TaskContext
-    ) -> None:
-        try:
-            await self._tasks_data.set_as_cancelled(task_id, with_task_context)
-            await cancel_wait_task(task)
-        except Exception as e:  # pylint:disable=broad-except
-            _logger.info(
-                "Task %s cancellation failed with error: %s",
-                task_id,
-                e,
-                stack_info=True,
-            )
+        task_to_cancel = self._created_tasks.pop(task_id, None)
+        if task_to_cancel is not None:
+            await cancel_wait_task(task_to_cancel)
+            await self._tasks_data.completed_task_removal(task_id)
+            await self._tasks_data.delete_task_data(task_id)
     async def remove_task(
         self,
         task_id: TaskId,
         with_task_context: TaskContext,
         *,
-        reraise_errors: bool = True,
+        wait_for_removal: bool,
     ) -> None:
-        """cancels and removes task"""
-        try:
-            tracked_task = await self._get_tracked_task(task_id, with_task_context)
-        except TaskNotFoundError:
-            if reraise_errors:
-                raise
-            return
+        """
+        cancels and removes the task
+        raises TaskNotFoundError if the task cannot be found
+        """
+        tracked_task = await self._get_tracked_task(task_id, with_task_context)

-        if tracked_task.task_id in self._created_tasks:
-            task_to_cancel = self._created_tasks.pop(tracked_task.task_id, None)
-            if task_to_cancel is not None:
-                # canceling the task affects the worker that started it.
-                # awaiting the cancelled task is a must since if the CancelledError
-                # was intercepted, those actions need to finish
-                await cancel_wait_task(task_to_cancel)
+        await self._tasks_data.mark_task_for_removal(
+            tracked_task.task_id, tracked_task.task_context
+        )

-        await self._tasks_data.delete_task_data(task_id)
+        if not wait_for_removal:
+            return

         # wait for task to be removed since it might not have been running
         # in this process
-        async for attempt in AsyncRetrying(
-            wait=wait_exponential(max=2),
-            stop=stop_after_delay(_TASK_REMOVAL_MAX_WAIT),
-            retry=retry_if_exception_type(TryAgain),
-        ):
-            with attempt:
-                try:
-                    await self._get_tracked_task(task_id, with_task_context)
-                    raise TryAgain
-                except TaskNotFoundError:
-                    pass
+        with suppress(TaskNotFoundError):
+            async for attempt in AsyncRetrying(
+                wait=wait_exponential(max=1),
+                stop=stop_after_delay(_TASK_REMOVAL_MAX_WAIT),
+                retry=retry_unless_exception_type(TaskNotFoundError),
+            ):
+                with attempt:
+                    await self._get_tracked_task(
+                        tracked_task.task_id, tracked_task.task_context
+                    )

     def _get_task_id(self, task_name: str, *, is_unique: bool) -> TaskId:
-        unique_part = "unique" if is_unique else f"{uuid4()}"
-        return f"{self.redis_namespace}.{task_name}.{unique_part}"
+        suffix = "unique" if is_unique else f"{uuid4()}"
+        return f"{self.lrt_namespace}.{task_name}.{suffix}"

     async def _update_progress(
         self,
@@ -436,7 +486,9 @@
         try:
             tracked_data = await self._get_tracked_task(task_id, task_context)
             tracked_data.task_progress = task_progress
-            await self._tasks_data.set_task_data(task_id=task_id, value=tracked_data)
+            await self._tasks_data.update_task_data(
+                task_id, updates={"task_progress": task_progress.model_dump()}
+            )
         except TaskNotFoundError:
             _logger.debug(
                 "Task '%s' not found while updating progress %s",
@@ -454,10 +506,13 @@
         fire_and_forget: bool,
         **task_kwargs: Any,
     ) -> TaskId:
-        if registered_task_name not in TaskRegistry.REGISTERED_TASKS:
-            raise TaskNotRegisteredError(task_name=registered_task_name)
+        registered_tasks = TaskRegistry.get_registered_tasks()
+        if registered_task_name not in registered_tasks:
+            raise TaskNotRegisteredError(
+                task_name=registered_task_name, tasks=registered_tasks
+            )

-        task = TaskRegistry.REGISTERED_TASKS[registered_task_name]
+        task = TaskRegistry.get_task(registered_task_name)

-        # NOTE: If not task name is given, it will be composed of the handler's module and it's name
+        # NOTE: if no task name is given, it is composed of the handler's module and its name
         # to keep the urls shorter and more meaningful.
@@ -495,12 +550,13 @@ async def _task_with_progress(progress: TaskProgress, handler: TaskProtocol):
         )

         tracked_task = TaskData(
+            registered_task_name=registered_task_name,
             task_id=task_id,
             task_progress=task_progress,
             task_context=context_to_use,
             fire_and_forget=fire_and_forget,
         )
-        await self._tasks_data.set_task_data(task_id, tracked_task)
+        await self._tasks_data.add_task_data(task_id, tracked_task)

         return tracked_task.task_id
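Taken together, registration and start now look roughly like this from a service's point of view. This is a sketch assembled from the call sites visible in this diff; `my_task`, `MyServiceError`, and `start_it` are illustrative, and the rpc client and namespace come from the service's long-running manager:

    from models_library.api_schemas_long_running_tasks.base import TaskProgress
    from servicelib.long_running_tasks import lrt_api
    from servicelib.long_running_tasks.task import TaskRegistry

    class MyServiceError(Exception):
        ...

    async def my_task(progress: TaskProgress, items: int) -> int:
        # errors listed in allowed_errors below are stored but not logged
        # as unexpected by the tasks monitor
        if items < 0:
            raise MyServiceError
        return items

    TaskRegistry.register(my_task, allowed_errors=(MyServiceError,))

    async def start_it(long_running_manager) -> str:
        # returns a TaskId of the form "<lrt_namespace>.<task_name>.<suffix>"
        return await lrt_api.start_task(
            long_running_manager.rpc_client,
            long_running_manager.lrt_namespace,
            my_task.__name__,
            items=5,
        )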
diff --git a/packages/service-library/src/servicelib/rabbitmq/_rpc_router.py b/packages/service-library/src/servicelib/rabbitmq/_rpc_router.py
index 49cab08f79b2..10dbf26a4497 100644
--- a/packages/service-library/src/servicelib/rabbitmq/_rpc_router.py
+++ b/packages/service-library/src/servicelib/rabbitmq/_rpc_router.py
@@ -65,8 +65,8 @@ async def _wrapper(*args, **kwargs):
                 raise

             _logger.exception(
-                "Unhandled exception on the rpc-server side."
-                " Re-raising as RPCServerError."
+                "Unhandled exception on the rpc-server side. Re-raising as %s.",
+                RPCServerError.__name__,
             )
             # NOTE: we do not return internal exceptions over RPC
             formatted_traceback = "\n".join(
diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py b/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py
index 917cd335c65c..3bf527ab2c82 100644
--- a/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py
+++ b/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py
@@ -26,6 +26,10 @@
 from tenacity.wait import wait_fixed


+class _TestingError(Exception):
+    pass
+
+
 async def _string_list_task(
     progress: TaskProgress,
     num_strings: int,
@@ -39,7 +43,7 @@ async def _string_list_task(
         await progress.update(message="generated item", percent=index / num_strings)
         if fail:
             msg = "We were asked to fail!!"
-            raise RuntimeError(msg)
+            raise _TestingError(msg)

     # NOTE: this code is used just for the sake of not returning the default 200
     return web.json_response(
@@ -47,7 +51,7 @@ async def _string_list_task(
     )


-TaskRegistry.register(_string_list_task)
+TaskRegistry.register(_string_list_task, allowed_errors=(_TestingError,))


 @pytest.fixture
diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py
index 94eaecef7e30..49604fd3a15e 100644
--- a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py
+++ b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py
@@ -23,6 +23,7 @@
 from servicelib.aiohttp.rest_middlewares import append_rest_middlewares
 from servicelib.long_running_tasks.models import TaskGet, TaskId, TaskStatus
 from servicelib.long_running_tasks.task import TaskContext
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings
 from tenacity.asyncio import AsyncRetrying
 from tenacity.retry import retry_if_exception_type
@@ -30,17 +31,15 @@
 from tenacity.wait import wait_fixed

 pytest_simcore_core_services_selection = [
-    "redis",
-]
-
-pytest_simcore_ops_services_selection = [
-    "redis-commander",
+    "rabbit",
 ]


 @pytest.fixture
 def app(
-    server_routes: web.RouteTableDef, redis_service: RedisSettings
+    server_routes: web.RouteTableDef,
+    use_in_memory_redis: RedisSettings,
+    rabbit_service: RabbitSettings,
 ) -> web.Application:
     app = web.Application()
     app.add_routes(server_routes)
@@ -48,8 +47,9 @@ def app(
     append_rest_middlewares(app, api_version="")
     long_running_tasks.server.setup(
         app,
-        redis_settings=redis_service,
-        redis_namespace="test",
+        redis_settings=use_in_memory_redis,
+        rabbit_settings=rabbit_service,
+        lrt_namespace="test",
         router_prefix="/futures",
     )

@@ -127,7 +127,7 @@ async def test_workflow(
     [
         ("GET", "get_task_status"),
         ("GET", "get_task_result"),
-        ("DELETE", "cancel_and_delete_task"),
+        ("DELETE", "remove_task"),
     ],
 )
 async def test_get_task_wrong_task_id_raises_not_found(
@@ -164,7 +164,7 @@ async def test_failing_task_returns_error(
     # The actual error details should be logged, not returned in response
     log_messages = caplog.text
     assert "OEC" in log_messages
-    assert "RuntimeError" in log_messages
+    assert "_TestingError" in log_messages
     assert "We were asked to fail!!" in log_messages
@@ -188,7 +188,7 @@ async def test_cancel_task(
     task_id = await start_long_running_task(client)

     # cancel the task
-    delete_url = client.app.router["cancel_and_delete_task"].url_for(task_id=task_id)
+    delete_url = client.app.router["remove_task"].url_for(task_id=task_id)
     result = await client.delete(f"{delete_url}")
     data, error = await assert_status(result, status.HTTP_204_NO_CONTENT)
     assert not data
diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py
index 0d6265197345..9e8c9204acef 100644
--- a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py
+++ b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py
@@ -15,13 +15,20 @@
     long_running_task_request,
 )
 from servicelib.aiohttp.rest_middlewares import append_rest_middlewares
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings
 from yarl import URL

+pytest_simcore_core_services_selection = [
+    "rabbit",
+]
+

 @pytest.fixture
 def app(
-    server_routes: web.RouteTableDef, use_in_memory_redis: RedisSettings
+    server_routes: web.RouteTableDef,
+    use_in_memory_redis: RedisSettings,
+    rabbit_service: RabbitSettings,
 ) -> web.Application:
     app = web.Application()
     app.add_routes(server_routes)
@@ -30,7 +37,8 @@ def app(
     long_running_tasks.server.setup(
         app,
         redis_settings=use_in_memory_redis,
-        redis_namespace="test",
+        rabbit_settings=rabbit_service,
+        lrt_namespace="test",
         router_prefix="/futures",
     )

@@ -58,7 +66,7 @@ async def test_long_running_task_request_raises_400(
     client: TestClient, long_running_task_url: URL
 ):
     # missing parameters raises
-    with pytest.raises(ClientResponseError):
+    with pytest.raises(ClientResponseError):  # noqa: PT012
         async for _ in long_running_task_request(
             client.session, long_running_task_url, None
         ):
@@ -95,7 +103,7 @@ async def test_long_running_task_request_timeout(
 ):
     assert client.app
     task: LRTask | None = None
-    with pytest.raises(asyncio.TimeoutError):
+    with pytest.raises(asyncio.TimeoutError):  # noqa: PT012
         async for task in long_running_task_request(
             client.session,
             long_running_task_url.with_query(num_strings=10, sleep_time=1),
diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py
index e303d9362ada..cef4a845ab8d 100644
--- a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py
+++ b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py
@@ -27,10 +27,11 @@
 from servicelib.aiohttp.typing_extension import Handler
 from servicelib.long_running_tasks.models import TaskGet, TaskId
 from servicelib.long_running_tasks.task import TaskContext
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings

 pytest_simcore_core_services_selection = [
-    "redis",
+    "rabbit",
 ]

 # WITH TASK CONTEXT
 # NOTE: as the long running task framework may be used in any number of services
@@ -67,7 +68,8 @@ async def _test_task_context_decorator(
 def app_with_task_context(
     server_routes: web.RouteTableDef,
     task_context_decorator,
-    redis_service: RedisSettings,
+    use_in_memory_redis: RedisSettings,
+    rabbit_service: RabbitSettings,
 ) -> web.Application:
     app = web.Application()
     app.add_routes(server_routes)
@@ -75,8 +77,9 @@ def app_with_task_context(
     append_rest_middlewares(app, api_version="")
     long_running_tasks.server.setup(
         app,
-        redis_settings=redis_service,
-        redis_namespace="test",
+        redis_settings=use_in_memory_redis,
+        rabbit_settings=rabbit_service,
+        lrt_namespace="test",
         router_prefix="/futures_with_task_context",
         task_request_context_decorator=task_context_decorator,
     )
@@ -169,7 +172,7 @@ async def test_cancel_task(
 ):
     assert client_with_task_context.app
     task_id = await start_long_running_task(client_with_task_context)
-    cancel_url = client_with_task_context.app.router["cancel_and_delete_task"].url_for(
+    cancel_url = client_with_task_context.app.router["remove_task"].url_for(
         task_id=task_id
     )
     # calling cancel without task context should find nothing
diff --git a/packages/service-library/tests/conftest.py b/packages/service-library/tests/conftest.py
index 845a8565d226..c4f63a18a1ba 100644
--- a/packages/service-library/tests/conftest.py
+++ b/packages/service-library/tests/conftest.py
@@ -3,9 +3,10 @@
 # pylint: disable=unused-argument
 # pylint: disable=unused-import

+import asyncio
 import sys
 from collections.abc import AsyncIterable, AsyncIterator, Callable
-from contextlib import AbstractAsyncContextManager, asynccontextmanager
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
 from copy import deepcopy
 from pathlib import Path
 from typing import Any
@@ -25,6 +26,7 @@
     "pytest_simcore.environment_configs",
     "pytest_simcore.file_extra",
     "pytest_simcore.logging",
+    "pytest_simcore.long_running_tasks",
     "pytest_simcore.pytest_global_environs",
     "pytest_simcore.rabbit_service",
     "pytest_simcore.redis_service",
@@ -91,7 +93,8 @@ async def _(

     yield client

-    await client.shutdown()
+    with suppress(TimeoutError):
+        await asyncio.wait_for(client.shutdown(), timeout=5.0)

 async def _cleanup_redis_data(clients_manager: RedisClientsManager) -> None:
     for db in RedisDatabase:
diff --git a/packages/service-library/tests/deferred_tasks/test__base_deferred_handler.py b/packages/service-library/tests/deferred_tasks/test__base_deferred_handler.py
index 6a3a872c8616..a26156f2c3f2 100644
--- a/packages/service-library/tests/deferred_tasks/test__base_deferred_handler.py
+++ b/packages/service-library/tests/deferred_tasks/test__base_deferred_handler.py
@@ -50,10 +50,10 @@ class MockKeys(StrAutoEnum):

 @pytest.fixture
 async def redis_client_sdk(
-    redis_service: RedisSettings,
+    use_in_memory_redis: RedisSettings,
 ) -> AsyncIterable[RedisClientSDK]:
     sdk = RedisClientSDK(
-        redis_service.build_redis_dsn(RedisDatabase.DEFERRED_TASKS),
+        use_in_memory_redis.build_redis_dsn(RedisDatabase.DEFERRED_TASKS),
         decode_responses=False,
         client_name="pytest",
     )
diff --git a/packages/service-library/tests/fastapi/long_running_tasks/conftest.py b/packages/service-library/tests/fastapi/long_running_tasks/conftest.py
index 0cab1161c096..f10a27c322ac 100644
--- a/packages/service-library/tests/fastapi/long_running_tasks/conftest.py
+++ b/packages/service-library/tests/fastapi/long_running_tasks/conftest.py
@@ -9,17 +9,22 @@
 from fastapi import FastAPI
 from httpx import ASGITransport, AsyncClient
 from servicelib.fastapi import long_running_tasks
+from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings


 @pytest.fixture
-async def bg_task_app(router_prefix: str, redis_service: RedisSettings) -> FastAPI:
+async def bg_task_app(
+    router_prefix: str, redis_service: RedisSettings, rabbit_service: RabbitSettings
+) -> FastAPI:
     app = FastAPI()

     long_running_tasks.server.setup(
         app,
         redis_settings=redis_service,
-        redis_namespace="test",
+        rabbit_settings=rabbit_service,
+        lrt_namespace="test",
         router_prefix=router_prefix,
     )
     return app
@@ -33,3 +38,14 @@ async def async_client(bg_task_app: FastAPI) -> AsyncIterable[AsyncClient]:
         headers={"Content-Type": "application/json"},
     ) as client:
         yield client
+
+
+@pytest.fixture
+async def rabbitmq_rpc_client(
+    rabbit_service: RabbitSettings,
+) -> AsyncIterable[RabbitMQRPCClient]:
+    client = await RabbitMQRPCClient.create(
+        client_name="test-lrt-rpc-client", settings=rabbit_service
+    )
+    yield client
+    await client.close()
diff --git a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_client.py b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_client.py
index 02d392126cbf..42f76a58f724 100644
--- a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_client.py
+++ b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_client.py
@@ -8,7 +8,7 @@
 @pytest.mark.parametrize(
     "error_class, error_args",
     [
-        (HTTPError, dict(message="")),
+        (HTTPError, {"message": ""}),
     ],
 )
 async def test_retry_on_errors(
diff --git a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py
index 0f10b7a165f5..1b72713dbd5c 100644
--- a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py
+++ b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py
@@ -33,6 +33,7 @@
     TaskStatus,
 )
 from servicelib.long_running_tasks.task import TaskContext, TaskRegistry
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings
 from tenacity.asyncio import AsyncRetrying
 from tenacity.retry import retry_if_exception_type
@@ -40,9 +41,17 @@
 from tenacity.wait import wait_fixed
 from yarl import URL

+pytest_simcore_core_services_selection = [
+    "rabbit",
+]
+
 ITEM_PUBLISH_SLEEP: Final[float] = 0.1


+class _TestingError(Exception):
+    pass
+
+
 async def _string_list_task(
     progress: TaskProgress,
     num_strings: int,
@@ -56,12 +65,12 @@ async def _string_list_task(
         await progress.update(message="generated item", percent=index / num_strings)
         if fail:
             msg = "We were asked to fail!!"
-            raise RuntimeError(msg)
+            raise _TestingError(msg)

     return generated_strings


-TaskRegistry.register(_string_list_task)
+TaskRegistry.register(_string_list_task, allowed_errors=(_TestingError,))


 @pytest.fixture
@@ -81,7 +90,8 @@ async def create_string_list_task(
         fail: bool = False,
     ) -> TaskId:
         return await lrt_api.start_task(
-            long_running_manager.tasks_manager,
+            long_running_manager.rpc_client,
+            long_running_manager.lrt_namespace,
             _string_list_task.__name__,
             num_strings=num_strings,
             sleep_time=sleep_time,
@@ -93,14 +103,21 @@

 @pytest.fixture
 async def app(
-    server_routes: APIRouter, use_in_memory_redis: RedisSettings
+    server_routes: APIRouter,
+    use_in_memory_redis: RedisSettings,
+    rabbit_service: RabbitSettings,
 ) -> AsyncIterator[FastAPI]:
     # overrides fastapi/conftest.py:app
     app = FastAPI(title="test app")
     app.include_router(server_routes)
-    setup_server(app, redis_settings=use_in_memory_redis, redis_namespace="test")
+    setup_server(
+        app,
+        redis_settings=use_in_memory_redis,
+        rabbit_settings=rabbit_service,
+        lrt_namespace="test",
+    )
     setup_client(app)
-    async with LifespanManager(app):
+    async with LifespanManager(app, startup_timeout=30, shutdown_timeout=30):
         yield app

@@ -205,7 +222,7 @@ async def test_workflow(
     [
         ("GET", "get_task_status"),
         ("GET", "get_task_result"),
-        ("DELETE", "cancel_and_delete_task"),
+        ("DELETE", "remove_task"),
     ],
 )
 async def test_get_task_wrong_task_id_raises_not_found(
@@ -229,7 +246,8 @@ async def test_failing_task_returns_error(
     await wait_for_task(app, client, task_id, {})
     # get the result
     result_url = app.url_path_for("get_task_result", task_id=task_id)
-    with pytest.raises(RuntimeError) as exec_info:
+
+    with pytest.raises(_TestingError) as exec_info:
         await client.get(f"{result_url}")
     assert f"{exec_info.value}" == "We were asked to fail!!"
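The rename from `cancel_and_delete_task` to `remove_task` in the hunks around here reflects the two-step removal visible in `TasksManager` above: the caller only marks the task for removal in the Redis store, and whichever process actually owns the `asyncio.Task` cancels and deletes it from its periodic `_cancelled_tasks_removal` sweep. From the RPC side the call shape is (a sketch based on `test_remove_task` later in this diff; `cancel_it` is illustrative):

    from servicelib.long_running_tasks import lrt_api

    async def cancel_it(rpc_client, lrt_namespace, task_context, task_id) -> None:
        # wait_for_removal=True polls until the owning process has dropped the task
        await lrt_api.remove_task(
            rpc_client,
            lrt_namespace,
            task_context,
            task_id,
            wait_for_removal=True,
        )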
@@ -254,7 +272,7 @@ async def test_cancel_task(
     task_id = await start_long_running_task(app, client)

     # cancel the task
-    delete_url = app.url_path_for("cancel_and_delete_task", task_id=task_id)
+    delete_url = app.url_path_for("remove_task", task_id=task_id)
     result = await client.delete(f"{delete_url}")
     assert result.status_code == status.HTTP_204_NO_CONTENT
diff --git a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py
index 179c967088f1..30418fd922a3 100644
--- a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py
+++ b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py
@@ -28,8 +28,13 @@
     TaskProgress,
 )
 from servicelib.long_running_tasks.task import TaskRegistry
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings

+pytest_simcore_core_services_selection = [
+    "rabbit",
+]
+
 TASK_SLEEP_INTERVAL: Final[PositiveFloat] = 0.1

 # UTILS
@@ -51,11 +56,15 @@ async def a_test_task(progress: TaskProgress) -> int:
 TaskRegistry.register(a_test_task)


+class _TestingError(Exception):
+    pass
+
+
 async def a_failing_test_task(progress: TaskProgress) -> None:
     _ = progress
     await asyncio.sleep(TASK_SLEEP_INTERVAL)
     msg = "I am failing as requested"
-    raise RuntimeError(msg)
+    raise _TestingError(msg)


 TaskRegistry.register(a_failing_test_task)
@@ -72,7 +81,9 @@ async def create_task_user_defined_route(
         ],
     ) -> TaskId:
         return await lrt_api.start_task(
-            long_running_manager.tasks_manager, a_test_task.__name__
+            long_running_manager.rpc_client,
+            long_running_manager.lrt_namespace,
+            a_test_task.__name__,
         )

     @router.get("/api/failing", status_code=status.HTTP_200_OK)
@@ -82,7 +93,9 @@ async def create_task_which_fails(
         ],
     ) -> TaskId:
         return await lrt_api.start_task(
-            long_running_manager.tasks_manager, a_failing_test_task.__name__
+            long_running_manager.rpc_client,
+            long_running_manager.lrt_namespace,
+            a_failing_test_task.__name__,
         )

     return router

@@ -90,7 +103,10 @@

 @pytest.fixture
 async def bg_task_app(
-    user_routes: APIRouter, router_prefix: str, use_in_memory_redis: RedisSettings
+    user_routes: APIRouter,
+    router_prefix: str,
+    use_in_memory_redis: RedisSettings,
+    rabbit_service: RabbitSettings,
 ) -> AsyncIterable[FastAPI]:
     app = FastAPI()

@@ -100,11 +116,12 @@
         app,
         router_prefix=router_prefix,
         redis_settings=use_in_memory_redis,
-        redis_namespace="test",
+        rabbit_settings=rabbit_service,
+        lrt_namespace="test",
     )
     setup_client(app, router_prefix=router_prefix)

-    async with LifespanManager(app):
+    async with LifespanManager(app, startup_timeout=30, shutdown_timeout=30):
         yield app
diff --git a/packages/service-library/tests/long_running_tasks/conftest.py b/packages/service-library/tests/long_running_tasks/conftest.py
new file mode 100644
index 000000000000..2d1d6d2d6377
--- /dev/null
+++ b/packages/service-library/tests/long_running_tasks/conftest.py
@@ -0,0 +1,85 @@
+# pylint: disable=protected-access
+# pylint: disable=redefined-outer-name
+# pylint: disable=unused-argument
+
+import asyncio
+import logging
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
+from datetime import timedelta
+
+import pytest
+from faker import Faker
+from pytest_mock import MockerFixture
+from servicelib.logging_utils import log_catch
+from servicelib.long_running_tasks.base_long_running_manager import (
+    BaseLongRunningManager,
+)
+from servicelib.long_running_tasks.models import LRTNamespace, TaskContext
+from servicelib.long_running_tasks.task import TasksManager
+from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
+from settings_library.rabbit import RabbitSettings
+from settings_library.redis import RedisSettings
+from utils import TEST_CHECK_STALE_INTERVAL_S
+
+_logger = logging.getLogger(__name__)
+
+
+class _TestingLongRunningManager(BaseLongRunningManager):
+    @staticmethod
+    def get_task_context(request) -> TaskContext:
+        _ = request
+        return {}
+
+
+@pytest.fixture
+async def get_long_running_manager(
+    fast_long_running_tasks_cancellation: None, faker: Faker
+) -> AsyncIterator[
+    Callable[
+        [RedisSettings, RabbitSettings, LRTNamespace | None],
+        Awaitable[BaseLongRunningManager],
+    ]
+]:
+    managers: list[BaseLongRunningManager] = []
+
+    async def _(
+        redis_settings: RedisSettings,
+        rabbit_settings: RabbitSettings,
+        lrt_namespace: LRTNamespace | None,
+    ) -> BaseLongRunningManager:
+        manager = _TestingLongRunningManager(
+            stale_task_check_interval=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
+            stale_task_detect_timeout=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
+            redis_settings=redis_settings,
+            rabbit_settings=rabbit_settings,
+            lrt_namespace=lrt_namespace or f"test{faker.uuid4()}",
+        )
+        await manager.setup()
+        managers.append(manager)
+        return manager
+
+    yield _
+
+    for manager in managers:
+        with log_catch(_logger, reraise=False):
+            await asyncio.wait_for(manager.teardown(), timeout=5)
+
+
+@pytest.fixture
+async def rabbitmq_rpc_client(
+    rabbit_service: RabbitSettings,
+) -> AsyncIterable[RabbitMQRPCClient]:
+    client = await RabbitMQRPCClient.create(
+        client_name="test-lrt-rpc-client", settings=rabbit_service
+    )
+    yield client
+    await client.close()
+
+
+@pytest.fixture
+def disable_stale_tasks_monitor(mocker: MockerFixture) -> None:
+    # no need to autoremove stale tasks in these tests
+    async def _to_replace(self: TasksManager) -> None:
+        self._started_event_task_stale_tasks_monitor.set()
+
+    mocker.patch.object(TasksManager, "_stale_tasks_monitor", _to_replace)
diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__error_serialization.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__error_serialization.py
deleted file mode 100644
index f0d0d14f165b..000000000000
--- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__error_serialization.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Any
-
-import pytest
-from aiohttp.web import HTTPException, HTTPInternalServerError
-from servicelib.aiohttp.long_running_tasks._server import AiohttpHTTPExceptionSerializer
-from servicelib.long_running_tasks._redis_serialization import (
-    object_to_string,
-    register_custom_serialization,
-    string_to_object,
-)
-
-register_custom_serialization(HTTPException, AiohttpHTTPExceptionSerializer)
-
-
-class PositionalArguments:
-    def __init__(self, arg1, arg2, *args):
-        self.arg1 = arg1
-        self.arg2 = arg2
-        self.args = args
-
-
-class MixedArguments:
-    def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None):
-        self.arg1 = arg1
-        self.arg2 = arg2
-        self.kwarg1 = kwarg1
-        self.kwarg2 = kwarg2
-
-
-@pytest.mark.parametrize(
-    "obj",
-    [
-        HTTPInternalServerError(reason="Uh-oh!", text="Failure!"),
-        PositionalArguments("arg1", "arg2", "arg3", "arg4"),
kwarg2="kwarg2"), - "a_string", - 1, - ], -) -def test_serialization(obj: Any): - str_data = object_to_string(obj) - - reconstructed_obj = string_to_object(str_data) - - assert type(reconstructed_obj) is type(obj) - if hasattr(obj, "__dict__"): - assert reconstructed_obj.__dict__ == obj.__dict__ diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_serialization.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__serialization.py similarity index 80% rename from packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_serialization.py rename to packages/service-library/tests/long_running_tasks/test_long_running_tasks__serialization.py index f0d0d14f165b..3b7562e55503 100644 --- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_serialization.py +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__serialization.py @@ -3,10 +3,10 @@ import pytest from aiohttp.web import HTTPException, HTTPInternalServerError from servicelib.aiohttp.long_running_tasks._server import AiohttpHTTPExceptionSerializer -from servicelib.long_running_tasks._redis_serialization import ( - object_to_string, +from servicelib.long_running_tasks._serialization import ( + dumps, + loads, register_custom_serialization, - string_to_object, ) register_custom_serialization(HTTPException, AiohttpHTTPExceptionSerializer) @@ -38,9 +38,12 @@ def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None): ], ) def test_serialization(obj: Any): - str_data = object_to_string(obj) + str_data = dumps(obj) - reconstructed_obj = string_to_object(str_data) + try: + reconstructed_obj = loads(str_data) + except Exception as exc: # pylint:disable=broad-exception-caught + reconstructed_obj = exc assert type(reconstructed_obj) is type(obj) if hasattr(obj, "__dict__"): diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py index bd4586be6487..218af7a9aaae 100644 --- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py @@ -5,8 +5,7 @@ import pytest from pydantic import TypeAdapter -from servicelib.long_running_tasks._store.base import BaseStore -from servicelib.long_running_tasks._store.redis import RedisStore +from servicelib.long_running_tasks._redis_store import RedisStore from servicelib.long_running_tasks.models import TaskData from servicelib.redis._client import RedisClientSDK from settings_library.redis import RedisDatabase, RedisSettings @@ -25,7 +24,7 @@ async def store( get_redis_client_sdk: Callable[ [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] ], -) -> AsyncIterable[BaseStore]: +) -> AsyncIterable[RedisStore]: store = RedisStore(redis_settings=use_in_memory_redis, namespace="test") await store.setup() @@ -37,12 +36,12 @@ async def store( pass -async def test_workflow(store: BaseStore, task_data: TaskData) -> None: +async def test_workflow(store: RedisStore, task_data: TaskData) -> None: # task data assert await store.list_tasks_data() == [] assert await store.get_task_data("missing") is None - await store.set_task_data(task_data.task_id, task_data) + await store.add_task_data(task_data.task_id, task_data) assert await store.list_tasks_data() == [task_data] @@ -51,11 +50,13 @@ async def test_workflow(store: BaseStore, task_data: 
@@ -51,11 +50,13 @@ async def test_workflow(store: RedisStore, task_data: TaskData) -> None:
     assert await store.list_tasks_data() == []

     # cancelled tasks
-    assert await store.get_cancelled() == {}
+    assert await store.list_tasks_to_remove() == {}

-    await store.set_as_cancelled(task_data.task_id, task_data.task_context)
+    await store.mark_task_for_removal(task_data.task_id, task_data.task_context)

-    assert await store.get_cancelled() == {task_data.task_id: task_data.task_context}
+    assert await store.list_tasks_to_remove() == {
+        task_data.task_id: task_data.task_context
+    }


 @pytest.fixture
@@ -88,15 +89,15 @@ async def test_workflow_multiple_redis_stores_with_different_namespaces(

     for store in redis_stores:
         assert await store.list_tasks_data() == []
-        assert await store.get_cancelled() == {}
+        assert await store.list_tasks_to_remove() == {}

     for store in redis_stores:
-        await store.set_task_data(task_data.task_id, task_data)
-        await store.set_as_cancelled(task_data.task_id, None)
+        await store.add_task_data(task_data.task_id, task_data)
+        await store.mark_task_for_removal(task_data.task_id, {})

     for store in redis_stores:
         assert await store.list_tasks_data() == [task_data]
-        assert await store.get_cancelled() == {task_data.task_id: None}
+        assert await store.list_tasks_to_remove() == {task_data.task_id: {}}

     for store in redis_stores:
         await store.delete_task_data(task_data.task_id)
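The store tests above pin down the renamed `RedisStore` API: task data is written once with `add_task_data`, removal is requested with `mark_task_for_removal`, pending removals are read back with `list_tasks_to_remove` (a mapping of task id to task context), and the entry finally disappears with `delete_task_data`. In sketch form, using only the methods exercised by these tests (teardown is elided, and `lifecycle` is an illustrative name):

    from servicelib.long_running_tasks._redis_store import RedisStore

    async def lifecycle(redis_settings, task_data) -> None:
        store = RedisStore(redis_settings=redis_settings, namespace="example")
        await store.setup()

        await store.add_task_data(task_data.task_id, task_data)
        await store.mark_task_for_removal(task_data.task_id, task_data.task_context)

        # every task awaiting removal shows up here, keyed by task id
        assert task_data.task_id in await store.list_tasks_to_remove()

        await store.delete_task_data(task_data.task_id)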
diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_lrt_api.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_lrt_api.py
new file mode 100644
index 000000000000..e1742a17013b
--- /dev/null
+++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_lrt_api.py
@@ -0,0 +1,327 @@
+# pylint: disable=redefined-outer-name
+# pylint: disable=unused-argument
+
+import asyncio
+import secrets
+from collections.abc import Awaitable, Callable
+from typing import Any, Final
+
+import pytest
+from models_library.api_schemas_long_running_tasks.base import TaskProgress
+from pydantic import NonNegativeInt
+from servicelib.long_running_tasks import lrt_api
+from servicelib.long_running_tasks.base_long_running_manager import (
+    BaseLongRunningManager,
+)
+from servicelib.long_running_tasks.errors import TaskNotFoundError
+from servicelib.long_running_tasks.models import LRTNamespace, TaskContext
+from servicelib.long_running_tasks.task import TaskId, TaskRegistry
+from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
+from settings_library.rabbit import RabbitSettings
+from settings_library.redis import RedisSettings
+from tenacity import (
+    AsyncRetrying,
+    TryAgain,
+    retry_if_exception_type,
+    stop_after_delay,
+    wait_fixed,
+)
+
+pytest_simcore_core_services_selection = [
+    "rabbit",
+]
+
+_RETRY_PARAMS: dict[str, Any] = {
+    "reraise": True,
+    "wait": wait_fixed(0.1),
+    "stop": stop_after_delay(60),
+    "retry": retry_if_exception_type((AssertionError, TryAgain)),
+}
+
+
+async def _task_echo_input(progress: TaskProgress, to_return: Any) -> Any:
+    return to_return
+
+
+class _TestingError(Exception):
+    pass
+
+
+async def _task_always_raise(progress: TaskProgress) -> None:
+    msg = "This task always raises an error"
+    raise _TestingError(msg)
+
+
+async def _task_takes_too_long(progress: TaskProgress) -> None:
+    # simulate a long-running task that takes too much time
+    await asyncio.sleep(1e9)
+
+
+TaskRegistry.register(_task_echo_input)
+TaskRegistry.register(_task_always_raise, allowed_errors=(_TestingError,))
+TaskRegistry.register(_task_takes_too_long)
+
+
+@pytest.fixture
+def managers_count() -> NonNegativeInt:
+    return 5
+
+
+@pytest.fixture
+async def long_running_managers(
+    disable_stale_tasks_monitor: None,
+    managers_count: NonNegativeInt,
+    use_in_memory_redis: RedisSettings,
+    rabbit_service: RabbitSettings,
+    get_long_running_manager: Callable[
+        [RedisSettings, RabbitSettings, LRTNamespace | None],
+        Awaitable[BaseLongRunningManager],
+    ],
+) -> list[BaseLongRunningManager]:
+    managers: list[BaseLongRunningManager] = []
+    for _ in range(managers_count):
+        long_running_manager = await get_long_running_manager(
+            use_in_memory_redis, rabbit_service, "some-service"
+        )
+        managers.append(long_running_manager)
+
+    return managers
+
+
+def _get_long_running_manager(
+    long_running_managers: list[BaseLongRunningManager],
+) -> BaseLongRunningManager:
+    return secrets.choice(long_running_managers)
+
+
+async def _assert_task_status(
+    rabbitmq_rpc_client: RabbitMQRPCClient,
+    long_running_manager: BaseLongRunningManager,
+    task_id: TaskId,
+    *,
+    is_done: bool
+) -> None:
+    result = await lrt_api.get_task_status(
+        rabbitmq_rpc_client, long_running_manager.lrt_namespace, TaskContext(), task_id
+    )
+    assert result.done is is_done
+
+
+async def _assert_task_status_on_random_manager(
+    rabbitmq_rpc_client: RabbitMQRPCClient,
+    long_running_managers: list[BaseLongRunningManager],
+    task_ids: list[TaskId],
+    *,
+    is_done: bool = True
+) -> None:
+    for task_id in task_ids:
+        result = await lrt_api.get_task_status(
+            rabbitmq_rpc_client,
+            _get_long_running_manager(long_running_managers).lrt_namespace,
+            TaskContext(),
+            task_id,
+        )
+        assert result.done is is_done
+
+
+async def _assert_task_status_done_on_all_managers(
+    rabbitmq_rpc_client: RabbitMQRPCClient,
+    long_running_managers: list[BaseLongRunningManager],
+    task_id: TaskId,
+    *,
+    is_done: bool = True
+) -> None:
+    async for attempt in AsyncRetrying(**_RETRY_PARAMS):
+        with attempt:
+            await _assert_task_status(
+                rabbitmq_rpc_client,
+                _get_long_running_manager(long_running_managers),
+                task_id,
+                is_done=is_done,
+            )
+
+    # check this can be done from any task manager
+    for manager in long_running_managers:
+        await _assert_task_status(
+            rabbitmq_rpc_client, manager, task_id, is_done=is_done
+        )
+
+
+async def _assert_list_tasks_from_all_managers(
+    rabbitmq_rpc_client: RabbitMQRPCClient,
+    long_running_managers: list[BaseLongRunningManager],
+    task_context: TaskContext,
+    task_count: int,
+) -> None:
+    for manager in long_running_managers:
+        tasks = await lrt_api.list_tasks(
+            rabbitmq_rpc_client, manager.lrt_namespace, task_context
+        )
+        assert len(tasks) == task_count
+
+
+async def _assert_task_is_no_longer_present(
+    rabbitmq_rpc_client: RabbitMQRPCClient,
+    long_running_managers: list[BaseLongRunningManager],
+    task_context: TaskContext,
+    task_id: TaskId,
+) -> None:
+    with pytest.raises(TaskNotFoundError):
+        await lrt_api.get_task_status(
+            rabbitmq_rpc_client,
+            _get_long_running_manager(long_running_managers).lrt_namespace,
+            task_context,
+            task_id,
+        )
+
+
+_TASK_CONTEXT: Final[list[TaskContext | None]] = [{"a": "context"}, None]
+_IS_UNIQUE: Final[list[bool]] = [False, True]
+_TASK_COUNT: Final[list[int]] = [5]
+
+
+@pytest.mark.parametrize("task_count", _TASK_COUNT)
+@pytest.mark.parametrize("task_context", _TASK_CONTEXT)
+@pytest.mark.parametrize("is_unique", _IS_UNIQUE)
+@pytest.mark.parametrize("to_return", [{"key": "value"}])
+async def test_workflow_with_result(
+    long_running_managers: list[BaseLongRunningManager],
+    rabbitmq_rpc_client: RabbitMQRPCClient,
+    task_count: int,
+    is_unique: bool,
+    task_context: TaskContext | None,
+    to_return: Any,
+):
+    saved_context = task_context or {}
+    task_count = 1 if is_unique else task_count
+
+    task_ids: list[TaskId] = []
+    for _ in range(task_count):
+        task_id = await lrt_api.start_task(
+            _get_long_running_manager(long_running_managers).rpc_client,
+            _get_long_running_manager(long_running_managers).lrt_namespace,
+            _task_echo_input.__name__,
+            unique=is_unique,
+            task_name=None,
+            task_context=task_context,
+            fire_and_forget=False,
+            to_return=to_return,
+        )
+        task_ids.append(task_id)
+
+    for task_id in task_ids:
+        await _assert_task_status_done_on_all_managers(
+            rabbitmq_rpc_client, long_running_managers, task_id
+        )
+
+    await _assert_list_tasks_from_all_managers(
+        rabbitmq_rpc_client, long_running_managers, saved_context, task_count=task_count
+    )
+
+    # avoids tasks getting garbage collected
+    await _assert_task_status_on_random_manager(
+        rabbitmq_rpc_client, long_running_managers, task_ids, is_done=True
+    )
+
+    for task_id in task_ids:
+        result = await lrt_api.get_task_result(
+            rabbitmq_rpc_client,
+            _get_long_running_manager(long_running_managers).lrt_namespace,
+            saved_context,
+            task_id,
+        )
+        assert result == to_return
+
+        await _assert_task_is_no_longer_present(
+            rabbitmq_rpc_client, long_running_managers, saved_context, task_id
+        )
+
+
+@pytest.mark.parametrize("task_count", _TASK_COUNT)
+@pytest.mark.parametrize("task_context", _TASK_CONTEXT)
+@pytest.mark.parametrize("is_unique", _IS_UNIQUE)
+async def test_workflow_raises_error(
+    long_running_managers: list[BaseLongRunningManager],
+    rabbitmq_rpc_client: RabbitMQRPCClient,
+    task_count: int,
+    is_unique: bool,
+    task_context: TaskContext | None,
+):
+    saved_context = task_context or {}
+    task_count = 1 if is_unique else task_count
+
+    task_ids: list[TaskId] = []
+    for _ in range(task_count):
+        task_id = await lrt_api.start_task(
+            _get_long_running_manager(long_running_managers).rpc_client,
+            _get_long_running_manager(long_running_managers).lrt_namespace,
+            _task_always_raise.__name__,
+            unique=is_unique,
+            task_name=None,
+            task_context=task_context,
+            fire_and_forget=False,
+        )
+        task_ids.append(task_id)
+
+    for task_id in task_ids:
+        await _assert_task_status_done_on_all_managers(
+            rabbitmq_rpc_client, long_running_managers, task_id
+        )
+
+    await _assert_list_tasks_from_all_managers(
+        rabbitmq_rpc_client, long_running_managers, saved_context, task_count=task_count
+    )
+
+    # avoids tasks getting garbage collected
+    await _assert_task_status_on_random_manager(
+        rabbitmq_rpc_client, long_running_managers, task_ids, is_done=True
+    )
+
+    for task_id in task_ids:
+        with pytest.raises(_TestingError, match="This task always raises an error"):
+            await lrt_api.get_task_result(
+                rabbitmq_rpc_client,
+                _get_long_running_manager(long_running_managers).lrt_namespace,
+                saved_context,
+                task_id,
+            )
+
+        await _assert_task_is_no_longer_present(
+            rabbitmq_rpc_client, long_running_managers, saved_context, task_id
+        )
+
+
+@pytest.mark.parametrize("task_context", _TASK_CONTEXT)
+@pytest.mark.parametrize("is_unique", _IS_UNIQUE)
+async def test_remove_task(
+    long_running_managers: list[BaseLongRunningManager],
+    rabbitmq_rpc_client: RabbitMQRPCClient,
+    is_unique: bool,
+    task_context: TaskContext | None,
+):
+    task_id = await lrt_api.start_task(
+        _get_long_running_manager(long_running_managers).rpc_client,
+        _get_long_running_manager(long_running_managers).lrt_namespace,
+        _task_takes_too_long.__name__,
+        unique=is_unique,
+        task_name=None,
+        task_context=task_context,
+        fire_and_forget=False,
+    )
+    saved_context = task_context or {}
+
+    await _assert_task_status_done_on_all_managers(
+        rabbitmq_rpc_client, long_running_managers, task_id, is_done=False
+    )
+
+    await lrt_api.remove_task(
+        rabbitmq_rpc_client,
+        _get_long_running_manager(long_running_managers).lrt_namespace,
+        saved_context,
+        task_id,
+        wait_for_removal=True,
+    )
+
+    await _assert_task_is_no_longer_present(
+        rabbitmq_rpc_client, long_running_managers, saved_context, task_id
+    )
diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py
index 902061cbbf2c..78b3ac74a0b6 100644
--- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py
+++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py
@@ -6,33 +6,47 @@

 import asyncio
 import urllib.parse
-from collections.abc import AsyncIterator, Awaitable, Callable
-from contextlib import suppress
-from datetime import datetime, timedelta
-from typing import Any, Final
+from collections.abc import Awaitable, Callable
+from datetime import datetime
+from typing import Any

 import pytest
 from faker import Faker
 from models_library.api_schemas_long_running_tasks.base import ProgressMessage
 from servicelib.long_running_tasks import lrt_api
+from servicelib.long_running_tasks._serialization import (
+    loads,
+)
+from servicelib.long_running_tasks.base_long_running_manager import (
+    BaseLongRunningManager,
+)
 from servicelib.long_running_tasks.errors import (
     TaskAlreadyRunningError,
     TaskNotCompletedError,
     TaskNotFoundError,
     TaskNotRegisteredError,
 )
-from servicelib.long_running_tasks.models import TaskContext, TaskProgress, TaskStatus
-from servicelib.long_running_tasks.task import (
-    RedisNamespace,
-    TaskRegistry,
-    TasksManager,
+from servicelib.long_running_tasks.models import (
+    LRTNamespace,
+    ResultField,
+    TaskContext,
+    TaskProgress,
+    TaskStatus,
 )
+from servicelib.long_running_tasks.task import TaskRegistry
+from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings
 from tenacity import TryAgain
 from tenacity.asyncio import AsyncRetrying
 from tenacity.retry import retry_if_exception_type
 from tenacity.stop import stop_after_delay
 from tenacity.wait import wait_fixed
+from utils import TEST_CHECK_STALE_INTERVAL_S
+
+pytest_simcore_core_services_selection = [
+    "rabbit",
+]

 _RETRY_PARAMS: dict[str, Any] = {
     "reraise": True,
@@ -42,6 +56,10 @@
 }


+class _TestingError(Exception):
+    pass
+
+
 async def a_background_task(
     progress: TaskProgress,
     raise_when_finished: bool,
@@ -53,7 +71,7 @@ async def a_background_task(
         await progress.update(percent=(i + 1) / total_sleep)
     if raise_when_finished:
         msg = "raised this error as instructed"
-        raise RuntimeError(msg)
+        raise _TestingError(msg)

     return 42

@@ -66,15 +84,13 @@ async def fast_background_task(progress: TaskProgress) -> int:
 async def failing_background_task(progress: TaskProgress):
     """this task fails as soon as it starts"""
     msg = "failing asap"
-    raise RuntimeError(msg)
+    raise _TestingError(msg)


 TaskRegistry.register(a_background_task)
 TaskRegistry.register(fast_background_task)
 TaskRegistry.register(failing_background_task)

-TEST_CHECK_STALE_INTERVAL_S: Final[float] = 1


 @pytest.fixture
 def empty_context() -> TaskContext:
@@ -82,51 +98,28 @@


 @pytest.fixture
-async def get_tasks_manager(
-    faker: Faker,
-) -> AsyncIterator[
-    Callable[[RedisSettings, RedisNamespace | None], Awaitable[TasksManager]]
-]:
-    managers: list[TasksManager] = []
-
-    async def _(
-        redis_settings: RedisSettings, namespace: RedisNamespace | None
-    ) -> TasksManager:
-        tasks_manager = TasksManager(
-            stale_task_check_interval=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
-            stale_task_detect_timeout=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
-            redis_settings=redis_settings,
-            redis_namespace=namespace or f"test{faker.uuid4()}",
-        )
-        await tasks_manager.setup()
-        managers.append(tasks_manager)
-        return tasks_manager
-
-    yield _
-
-    for manager in managers:
-        with suppress(TimeoutError):  # avoids tets hanging on teardown
-            await asyncio.wait_for(manager.teardown(), timeout=10)
-
-
-@pytest.fixture
-async def tasks_manager(
+async def long_running_manager(
     use_in_memory_redis: RedisSettings,
-    get_tasks_manager: Callable[
-        [RedisSettings, RedisNamespace | None], Awaitable[TasksManager]
+    rabbit_service: RabbitSettings,
+    get_long_running_manager: Callable[
+        [RedisSettings, RabbitSettings, LRTNamespace | None],
+        Awaitable[BaseLongRunningManager],
     ],
-) -> TasksManager:
-    return await get_tasks_manager(use_in_memory_redis, None)
+) -> BaseLongRunningManager:
+    return await get_long_running_manager(
+        use_in_memory_redis, rabbit_service, "rabbit-namespace"
+    )


 @pytest.mark.parametrize("check_task_presence_before", [True, False])
 async def test_task_is_auto_removed(
-    tasks_manager: TasksManager,
+    long_running_manager: BaseLongRunningManager,
     check_task_presence_before: bool,
     empty_context: TaskContext,
 ):
     task_id = await lrt_api.start_task(
-        tasks_manager,
+        long_running_manager.rpc_client,
+        long_running_manager.lrt_namespace,
         a_background_task.__name__,
         raise_when_finished=False,
         total_sleep=10 * TEST_CHECK_STALE_INTERVAL_S,
@@ -135,7 +128,7 @@
     if check_task_presence_before:
         # immediately after starting the task is still there
-        task_status = await tasks_manager.get_task_status(
+        task_status = await long_running_manager.tasks_manager.get_task_status(
             task_id, with_task_context=empty_context
         )
         assert task_status
@@ -145,45 +138,61 @@
     async for attempt in AsyncRetrying(**_RETRY_PARAMS):
         with attempt:
             if (
-                await tasks_manager._tasks_data.get_task_data(task_id)  # noqa: SLF001
+                await long_running_manager.tasks_manager._tasks_data.get_task_data(  # noqa: SLF001
+                    task_id
+                )
                 is not None
             ):
                 msg = "wait till no element is found any longer"
                 raise TryAgain(msg)

     with pytest.raises(TaskNotFoundError):
-        await tasks_manager.get_task_status(task_id, with_task_context=empty_context)
+        await long_running_manager.tasks_manager.get_task_status(
+            task_id, with_task_context=empty_context
+        )
     with pytest.raises(TaskNotFoundError):
-        await tasks_manager.get_task_result(task_id, with_task_context=empty_context)
+        await long_running_manager.tasks_manager.get_task_result(
+            task_id, with_task_context=empty_context
+        )


+@pytest.mark.parametrize("wait_multiplier", [1, 2, 3, 4, 5, 6])
 async def test_checked_task_is_not_auto_removed(
-    tasks_manager: TasksManager, empty_context: TaskContext
+    long_running_manager: BaseLongRunningManager,
+    empty_context: TaskContext,
+    wait_multiplier: int,
 ):
     task_id = await lrt_api.start_task(
-        tasks_manager,
+        long_running_manager.rpc_client,
+        long_running_manager.lrt_namespace,
         a_background_task.__name__,
         raise_when_finished=False,
-        total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S,
+        total_sleep=wait_multiplier * TEST_CHECK_STALE_INTERVAL_S,
         task_context=empty_context,
     )

     async for attempt in AsyncRetrying(**_RETRY_PARAMS):
         with attempt:
-            status = await tasks_manager.get_task_status(
+            status = await long_running_manager.tasks_manager.get_task_status(
                 task_id, with_task_context=empty_context
             )
             assert status.done, f"task {task_id} not complete"
-    result = await tasks_manager.get_task_result(
+    result = await long_running_manager.tasks_manager.get_task_result(
         task_id, with_task_context=empty_context
     )
     assert result


+def _get_result(result_field: ResultField) -> Any:
+    assert result_field.str_result
+    return loads(result_field.str_result)
+
+
 async def test_fire_and_forget_task_is_not_auto_removed(
-    tasks_manager: TasksManager, empty_context: TaskContext
+    long_running_manager: BaseLongRunningManager, empty_context: TaskContext
 ):
     task_id = await lrt_api.start_task(
-        tasks_manager,
+        long_running_manager.rpc_client,
+        long_running_manager.lrt_namespace,
         a_background_task.__name__,
         raise_when_finished=False,
         total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S,
@@ -192,35 +201,38 @@
     )
     await asyncio.sleep(3 * TEST_CHECK_STALE_INTERVAL_S)
     # the task shall still be present even if we did not check the status before
-    status = await tasks_manager.get_task_status(
+    status = await long_running_manager.tasks_manager.get_task_status(
         task_id, with_task_context=empty_context
     )
     assert not status.done, "task was removed although it is fire and forget"
     # the task shall finish
     await asyncio.sleep(4 * TEST_CHECK_STALE_INTERVAL_S)
     # get the result
-    task_result = await tasks_manager.get_task_result(
+    task_result = await long_running_manager.tasks_manager.get_task_result(
         task_id, with_task_context=empty_context
     )
-    assert task_result == 42
+    assert _get_result(task_result) == 42


 async def test_get_result_of_unfinished_task_raises(
-    tasks_manager: TasksManager, empty_context: TaskContext
+    long_running_manager: BaseLongRunningManager, empty_context: TaskContext
 ):
     task_id = await lrt_api.start_task(
-        tasks_manager,
+        long_running_manager.rpc_client,
+        long_running_manager.lrt_namespace,
         a_background_task.__name__,
         raise_when_finished=False,
         total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S,
         task_context=empty_context,
     )
     with pytest.raises(TaskNotCompletedError):
-        await tasks_manager.get_task_result(task_id, with_task_context=empty_context)
+        await long_running_manager.tasks_manager.get_task_result(
+            task_id, with_task_context=empty_context
+        )
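`_get_result` above unwraps the success side of a `ResultField`; the error side is symmetrical, as `test_get_result_finished_with_error` further down shows. A hypothetical combined helper (`unwrap` is not part of this diff, and it assumes the `loads` semantics used by these tests):

    from typing import Any

    from servicelib.long_running_tasks._serialization import loads
    from servicelib.long_running_tasks.models import ResultField

    def unwrap(result_field: ResultField) -> Any:
        # a stored error is rebuilt on the consumer side and re-raised
        if result_field.str_error is not None:
            raise loads(result_field.str_error)
        assert result_field.str_result is not None
        return loads(result_field.str_result)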
f"{exec_info.value}" @@ -243,7 +263,7 @@ async def unique_task(progress: TaskProgress): async def test_start_multiple_not_unique_tasks( - tasks_manager: TasksManager, empty_context: TaskContext + long_running_manager: BaseLongRunningManager, empty_context: TaskContext ): async def not_unique_task(progress: TaskProgress): await asyncio.sleep(1) @@ -252,28 +272,40 @@ async def not_unique_task(progress: TaskProgress): for _ in range(5): await lrt_api.start_task( - tasks_manager, not_unique_task.__name__, task_context=empty_context + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + not_unique_task.__name__, + task_context=empty_context, ) TaskRegistry.unregister(not_unique_task) @pytest.mark.parametrize("is_unique", [True, False]) -async def test_get_task_id(tasks_manager: TasksManager, faker: Faker, is_unique: bool): - obj1 = tasks_manager._get_task_id(faker.word(), is_unique=is_unique) # noqa: SLF001 - obj2 = tasks_manager._get_task_id(faker.word(), is_unique=is_unique) # noqa: SLF001 +async def test_get_task_id( + long_running_manager: BaseLongRunningManager, faker: Faker, is_unique: bool +): + obj1 = long_running_manager.tasks_manager._get_task_id( # noqa: SLF001 + faker.word(), is_unique=is_unique + ) + obj2 = long_running_manager.tasks_manager._get_task_id( # noqa: SLF001 + faker.word(), is_unique=is_unique + ) assert obj1 != obj2 -async def test_get_status(tasks_manager: TasksManager, empty_context: TaskContext): +async def test_get_status( + long_running_manager: BaseLongRunningManager, empty_context: TaskContext +): task_id = await lrt_api.start_task( - tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_context=empty_context, ) - task_status = await tasks_manager.get_task_status( + task_status = await long_running_manager.tasks_manager.get_task_status( task_id, with_task_context=empty_context ) assert isinstance(task_status, TaskStatus) @@ -284,75 +316,97 @@ async def test_get_status(tasks_manager: TasksManager, empty_context: TaskContex async def test_get_status_missing( - tasks_manager: TasksManager, empty_context: TaskContext + long_running_manager: BaseLongRunningManager, empty_context: TaskContext ): with pytest.raises(TaskNotFoundError) as exec_info: - await tasks_manager.get_task_status( + await long_running_manager.tasks_manager.get_task_status( "missing_task_id", with_task_context=empty_context ) assert f"{exec_info.value}" == "No task with missing_task_id found" -async def test_get_result(tasks_manager: TasksManager, empty_context: TaskContext): +async def test_get_result( + long_running_manager: BaseLongRunningManager, empty_context: TaskContext +): task_id = await lrt_api.start_task( - tasks_manager, fast_background_task.__name__, task_context=empty_context + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + fast_background_task.__name__, + task_context=empty_context, ) async for attempt in AsyncRetrying(**_RETRY_PARAMS): with attempt: - status = await tasks_manager.get_task_status( + status = await long_running_manager.tasks_manager.get_task_status( task_id, with_task_context=empty_context ) assert status.done is True - result = await tasks_manager.get_task_result( + result = await long_running_manager.tasks_manager.get_task_result( task_id, with_task_context=empty_context ) - assert result == 42 + assert _get_resutlt(result) == 42 async def test_get_result_missing( - tasks_manager: TasksManager, empty_context: 
TaskContext + long_running_manager: BaseLongRunningManager, empty_context: TaskContext ): with pytest.raises(TaskNotFoundError) as exec_info: - await tasks_manager.get_task_result( + await long_running_manager.tasks_manager.get_task_result( "missing_task_id", with_task_context=empty_context ) assert f"{exec_info.value}" == "No task with missing_task_id found" async def test_get_result_finished_with_error( - tasks_manager: TasksManager, empty_context: TaskContext + long_running_manager: BaseLongRunningManager, empty_context: TaskContext ): task_id = await lrt_api.start_task( - tasks_manager, failing_background_task.__name__, task_context=empty_context + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + failing_background_task.__name__, + task_context=empty_context, ) # wait for result async for attempt in AsyncRetrying(**_RETRY_PARAMS): with attempt: assert ( - await tasks_manager.get_task_status( + await long_running_manager.tasks_manager.get_task_status( task_id, with_task_context=empty_context ) ).done - with pytest.raises(RuntimeError, match="failing asap"): - await tasks_manager.get_task_result(task_id, with_task_context=empty_context) + result = await long_running_manager.tasks_manager.get_task_result( + task_id, with_task_context=empty_context + ) + assert result.str_error is not None # nosec + error = loads(result.str_error) + with pytest.raises(_TetingError, match="failing asap"): + raise error async def test_cancel_task_from_different_manager( + rabbit_service: RabbitSettings, use_in_memory_redis: RedisSettings, - get_tasks_manager: Callable[ - [RedisSettings, RedisNamespace | None], Awaitable[TasksManager] + get_long_running_manager: Callable[ + [RedisSettings, RabbitSettings, LRTNamespace | None], + Awaitable[BaseLongRunningManager], ], empty_context: TaskContext, ): - manager_1 = await get_tasks_manager(use_in_memory_redis, "test-namespace") - manager_2 = await get_tasks_manager(use_in_memory_redis, "test-namespace") - manager_3 = await get_tasks_manager(use_in_memory_redis, "test-namespace") + manager_1 = await get_long_running_manager( + use_in_memory_redis, rabbit_service, "test-namespace" + ) + manager_2 = await get_long_running_manager( + use_in_memory_redis, rabbit_service, "test-namespace" + ) + manager_3 = await get_long_running_manager( + use_in_memory_redis, rabbit_service, "test-namespace" + ) task_id = await lrt_api.start_task( - manager_1, + manager_1.rpc_client, + manager_1.lrt_namespace, a_background_task.__name__, raise_when_finished=False, total_sleep=1, @@ -361,42 +415,58 @@ async def test_cancel_task_from_different_manager( # wati for task to complete for manager in (manager_1, manager_2, manager_3): - status = await manager.get_task_status(task_id, empty_context) + status = await manager.tasks_manager.get_task_status(task_id, empty_context) assert status.done is False async for attempt in AsyncRetrying(**_RETRY_PARAMS): with attempt: for manager in (manager_1, manager_2, manager_3): - status = await manager.get_task_status(task_id, empty_context) + status = await manager.tasks_manager.get_task_status( + task_id, empty_context + ) assert status.done is True # check all provide the same result for manager in (manager_1, manager_2, manager_3): - task_result = await manager.get_task_result(task_id, empty_context) - assert task_result == 42 + task_result = await manager.tasks_manager.get_task_result( + task_id, empty_context + ) + assert _get_resutlt(task_result) == 42 -async def test_remove_task(tasks_manager: TasksManager, empty_context: 
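The tests above capture the API change that drives most of this diff: `lrt_api.start_task` now receives an RPC client plus an `LRTNamespace` instead of a `TasksManager`, and results travel serialized inside a `ResultField` (payload in `str_result`, errors in `str_error`). A minimal client-side sketch of that flow, assuming the names used in these tests; the import path for `TaskContext`/`BaseLongRunningManager` and the simple polling loop are illustrative, not lifted from the repo:

```python
# Sketch only: mirrors the calling convention exercised by the tests above.
import asyncio
from json import loads
from typing import Any

from servicelib.long_running_tasks import lrt_api
from servicelib.long_running_tasks.base_long_running_manager import (
    BaseLongRunningManager,
)
from servicelib.long_running_tasks.models import TaskContext


async def run_registered_task(
    manager: BaseLongRunningManager, task_context: TaskContext
) -> Any:
    # tasks are started over RPC: client + namespace instead of a TasksManager
    task_id = await lrt_api.start_task(
        manager.rpc_client,
        manager.lrt_namespace,
        "a_background_task",  # must have been registered via TaskRegistry
        task_context=task_context,
    )
    # status and result queries still go through the local tasks_manager
    while not (
        await manager.tasks_manager.get_task_status(
            task_id, with_task_context=task_context
        )
    ).done:
        await asyncio.sleep(0.1)
    result = await manager.tasks_manager.get_task_result(
        task_id, with_task_context=task_context
    )
    # the payload arrives JSON-encoded; errors would land in str_error instead
    assert result.str_result is not None
    return loads(result.str_result)
```

The fixed-interval polling here is a placeholder; the tests themselves wrap the same calls in tenacity's `AsyncRetrying` with `_RETRY_PARAMS`.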
-async def test_remove_task(tasks_manager: TasksManager, empty_context: TaskContext):
+async def test_remove_task(
+    long_running_manager: BaseLongRunningManager, empty_context: TaskContext
+):
     task_id = await lrt_api.start_task(
-        tasks_manager,
+        long_running_manager.rpc_client,
+        long_running_manager.lrt_namespace,
         a_background_task.__name__,
         raise_when_finished=False,
         total_sleep=10,
         task_context=empty_context,
     )
-    await tasks_manager.get_task_status(task_id, with_task_context=empty_context)
-    await tasks_manager.remove_task(task_id, with_task_context=empty_context)
+    await long_running_manager.tasks_manager.get_task_status(
+        task_id, with_task_context=empty_context
+    )
+    await long_running_manager.tasks_manager.remove_task(
+        task_id, with_task_context=empty_context, wait_for_removal=True
+    )
     with pytest.raises(TaskNotFoundError):
-        await tasks_manager.get_task_status(task_id, with_task_context=empty_context)
+        await long_running_manager.tasks_manager.get_task_status(
+            task_id, with_task_context=empty_context
+        )
     with pytest.raises(TaskNotFoundError):
-        await tasks_manager.get_task_result(task_id, with_task_context=empty_context)
+        await long_running_manager.tasks_manager.get_task_result(
+            task_id, with_task_context=empty_context
+        )
 
 
 async def test_remove_task_with_task_context(
-    tasks_manager: TasksManager, empty_context: TaskContext
+    long_running_manager: BaseLongRunningManager, empty_context: TaskContext
 ):
     task_id = await lrt_api.start_task(
-        tasks_manager,
+        long_running_manager.rpc_client,
+        long_running_manager.lrt_namespace,
         a_background_task.__name__,
         raise_when_finished=False,
         total_sleep=10,
@@ -404,41 +474,44 @@ async def test_remove_task_with_task_context(
     )
     # getting status fails if wrong task context given
     with pytest.raises(TaskNotFoundError):
-        await tasks_manager.get_task_status(
+        await long_running_manager.tasks_manager.get_task_status(
             task_id, with_task_context={"wrong_task_context": 12}
         )
-    await tasks_manager.get_task_status(task_id, with_task_context=empty_context)
+    await long_running_manager.tasks_manager.get_task_status(
+        task_id, with_task_context=empty_context
+    )
     # removing task fails if wrong task context given
     with pytest.raises(TaskNotFoundError):
-        await tasks_manager.remove_task(
-            task_id, with_task_context={"wrong_task_context": 12}
+        await long_running_manager.tasks_manager.remove_task(
+            task_id, with_task_context={"wrong_task_context": 12}, wait_for_removal=True
         )
-    await tasks_manager.remove_task(task_id, with_task_context=empty_context)
+    await long_running_manager.tasks_manager.remove_task(
+        task_id, with_task_context=empty_context, wait_for_removal=True
+    )
 
 
 async def test_remove_unknown_task(
-    tasks_manager: TasksManager, empty_context: TaskContext
+    long_running_manager: BaseLongRunningManager, empty_context: TaskContext
 ):
     with pytest.raises(TaskNotFoundError):
-        await tasks_manager.remove_task("invalid_id", with_task_context=empty_context)
-
-    await tasks_manager.remove_task(
-        "invalid_id", with_task_context=empty_context, reraise_errors=False
-    )
+        await long_running_manager.tasks_manager.remove_task(
+            "invalid_id", with_task_context=empty_context, wait_for_removal=True
+        )
 
 
 async def test__cancelled_tasks_worker_equivalent_of_cancellation_from_a_different_process(
-    tasks_manager: TasksManager, empty_context: TaskContext
+    long_running_manager: BaseLongRunningManager, empty_context: TaskContext
 ):
     task_id = await lrt_api.start_task(
-        tasks_manager,
+        long_running_manager.rpc_client,
+        long_running_manager.lrt_namespace,
         a_background_task.__name__,
         raise_when_finished=False,
         total_sleep=10,
task_context=empty_context, ) - await tasks_manager._tasks_data.set_as_cancelled( # noqa: SLF001 + await long_running_manager.tasks_manager._tasks_data.mark_task_for_removal( # noqa: SLF001 task_id, with_task_context=empty_context ) @@ -446,19 +519,32 @@ async def test__cancelled_tasks_worker_equivalent_of_cancellation_from_a_differe with attempt: # noqa: SIM117 with pytest.raises(TaskNotFoundError): assert ( - await tasks_manager.get_task_status(task_id, empty_context) is None + await long_running_manager.tasks_manager.get_task_status( + task_id, empty_context + ) + is None ) -async def test_list_tasks(tasks_manager: TasksManager, empty_context: TaskContext): - assert await tasks_manager.list_tasks(with_task_context=empty_context) == [] +async def test_list_tasks( + disable_stale_tasks_monitor: None, + long_running_manager: BaseLongRunningManager, + empty_context: TaskContext, +): + assert ( + await long_running_manager.tasks_manager.list_tasks( + with_task_context=empty_context + ) + == [] + ) # start a bunch of tasks NUM_TASKS = 10 task_ids = [] for _ in range(NUM_TASKS): task_ids.append( # noqa: PERF401 await lrt_api.start_task( - tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, a_background_task.__name__, raise_when_finished=False, total_sleep=10, @@ -466,45 +552,70 @@ async def test_list_tasks(tasks_manager: TasksManager, empty_context: TaskContex ) ) assert ( - len(await tasks_manager.list_tasks(with_task_context=empty_context)) + len( + await long_running_manager.tasks_manager.list_tasks( + with_task_context=empty_context + ) + ) == NUM_TASKS ) for task_index, task_id in enumerate(task_ids): - await tasks_manager.remove_task(task_id, with_task_context=empty_context) + await long_running_manager.tasks_manager.remove_task( + task_id, with_task_context=empty_context, wait_for_removal=True + ) assert len( - await tasks_manager.list_tasks(with_task_context=empty_context) + await long_running_manager.tasks_manager.list_tasks( + with_task_context=empty_context + ) ) == NUM_TASKS - (task_index + 1) async def test_list_tasks_filtering( - tasks_manager: TasksManager, empty_context: TaskContext + long_running_manager: BaseLongRunningManager, empty_context: TaskContext ): await lrt_api.start_task( - tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_context=empty_context, ) await lrt_api.start_task( - tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_context={"user_id": 213}, ) await lrt_api.start_task( - tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_context={"user_id": 213, "product": "osparc"}, ) - assert len(await tasks_manager.list_tasks(with_task_context=empty_context)) == 3 - assert len(await tasks_manager.list_tasks(with_task_context={"user_id": 213})) == 1 assert ( len( - await tasks_manager.list_tasks( + await long_running_manager.tasks_manager.list_tasks( + with_task_context=empty_context + ) + ) + == 3 + ) + assert ( + len( + await long_running_manager.tasks_manager.list_tasks( + with_task_context={"user_id": 213} + ) + ) + == 1 + ) + assert ( + len( + await long_running_manager.tasks_manager.list_tasks( with_task_context={"user_id": 213, "product": "osparc"} ) ) @@ -512,7 +623,7 @@ async def 
test_list_tasks_filtering( ) assert ( len( - await tasks_manager.list_tasks( + await long_running_manager.tasks_manager.list_tasks( with_task_context={"user_id": 120, "product": "osparc"} ) ) @@ -520,10 +631,13 @@ async def test_list_tasks_filtering( ) -async def test_define_task_name(tasks_manager: TasksManager, faker: Faker): +async def test_define_task_name( + long_running_manager: BaseLongRunningManager, faker: Faker +): task_name = faker.name() task_id = await lrt_api.start_task( - tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, a_background_task.__name__, raise_when_finished=False, total_sleep=10, @@ -532,6 +646,13 @@ async def test_define_task_name(tasks_manager: TasksManager, faker: Faker): assert urllib.parse.quote(task_name, safe="") in task_id -async def test_start_not_registered_task(tasks_manager: TasksManager): +async def test_start_not_registered_task( + rabbitmq_rpc_client: RabbitMQRPCClient, + long_running_manager: BaseLongRunningManager, +): with pytest.raises(TaskNotRegisteredError): - await lrt_api.start_task(tasks_manager, "not_registered_task") + await lrt_api.start_task( + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, + "not_registered_task", + ) diff --git a/packages/service-library/tests/long_running_tasks/utils.py b/packages/service-library/tests/long_running_tasks/utils.py new file mode 100644 index 000000000000..e473dd7e1daf --- /dev/null +++ b/packages/service-library/tests/long_running_tasks/utils.py @@ -0,0 +1,3 @@ +from typing import Final + +TEST_CHECK_STALE_INTERVAL_S: Final[float] = 1 diff --git a/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py b/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py index 2c84fd0bb363..53aaac235044 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py @@ -116,7 +116,8 @@ async def _progress_callback( try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, _task_remove_service_containers.__name__, unique=True, node_uuid=node_uuid, @@ -181,7 +182,8 @@ async def _progress_callback( try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, _task_save_service_state.__name__, unique=True, node_uuid=node_uuid, @@ -228,7 +230,8 @@ async def _progress_callback( try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, _task_push_service_outputs.__name__, unique=True, node_uuid=node_uuid, @@ -270,7 +273,8 @@ async def _task_cleanup_service_docker_resources( try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, _task_cleanup_service_docker_resources.__name__, unique=True, node_uuid=node_uuid, diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/module_setup.py b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/module_setup.py index 253f9be601df..5381566045c4 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/module_setup.py +++ 
b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/module_setup.py @@ -1,12 +1,10 @@ from fastapi import FastAPI from servicelib.fastapi import long_running_tasks -from servicelib.long_running_tasks.task import RedisNamespace +from ..._meta import APP_NAME from ...core.settings import AppSettings from . import api_client, scheduler -_LONG_RUNNING_TASKS_NAMESPACE: RedisNamespace = "director-v2" - def setup(app: FastAPI) -> None: settings: AppSettings = app.state.settings @@ -15,7 +13,8 @@ def setup(app: FastAPI) -> None: long_running_tasks.server.setup( app, redis_settings=settings.REDIS, - redis_namespace=_LONG_RUNNING_TASKS_NAMESPACE, + rabbit_settings=settings.DIRECTOR_V2_RABBITMQ, + lrt_namespace=APP_NAME, ) async def on_startup() -> None: diff --git a/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py b/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py index aa0eea7c7d7c..b573f4292fa8 100644 --- a/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py +++ b/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py @@ -7,6 +7,8 @@ import pytest import respx +from common_library.json_serialization import json_dumps +from common_library.serialization import model_dump_with_secrets from faker import Faker from fastapi import status from httpx import Response @@ -15,6 +17,7 @@ from models_library.service_settings_labels import SimcoreServiceLabels from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.typing_env import EnvVarsDict +from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from simcore_service_director_v2.models.dynamic_services_scheduler import SchedulerData from simcore_service_director_v2.modules.dynamic_sidecar.errors import ( @@ -25,12 +28,16 @@ ) from starlette.testclient import TestClient +pytest_simcore_core_services_selection = [ + "rabbit", +] + @pytest.fixture def mock_env( use_in_memory_redis: RedisSettings, mock_exclusive: None, - disable_rabbitmq: None, + rabbit_service: RabbitSettings, disable_postgres: None, mock_env: EnvVarsDict, monkeypatch: pytest.MonkeyPatch, @@ -50,6 +57,10 @@ def mock_env( monkeypatch.setenv("S3_REGION", faker.pystr()) monkeypatch.setenv("S3_SECRET_KEY", faker.pystr()) monkeypatch.setenv("S3_BUCKET_NAME", faker.pystr()) + monkeypatch.setenv( + "DIRECTOR_V2_RABBITMQ", + json_dumps(model_dump_with_secrets(rabbit_service, show_secrets=True)), + ) @pytest.fixture @@ -203,10 +214,7 @@ async def test_409_response( ) assert response.status_code == status.HTTP_202_ACCEPTED task_id = response.json() - assert ( - f"simcore_service_director_v2.api.routes.dynamic_scheduler.{task_name}" - in task_id - ) + assert f"director-v2.functools.{task_name}" in task_id response = client.request( method, diff --git a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler.py b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler.py index 07deb1aeb8e6..e2c79a8287d1 100644 --- a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler.py +++ b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler.py @@ -12,6 +12,8 @@ import pytest import respx +from common_library.json_serialization import json_dumps +from common_library.serialization import model_dump_with_secrets from faker import Faker from fastapi import FastAPI from models_library.api_schemas_directorv2.dynamic_services_service import ( @@ -23,6 +25,7 @@ from pytest_mock.plugin import MockerFixture 
from pytest_simcore.helpers.typing_env import EnvVarsDict from respx.router import MockRouter +from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from simcore_service_director_v2.models.dynamic_services_scheduler import ( DockerContainerInspect, @@ -54,10 +57,12 @@ # and ensure faster tests _TEST_SCHEDULER_INTERVAL_SECONDS: Final[NonNegativeFloat] = 0.1 -log = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) -pytest_simcore_core_services_selection = ["postgres"] +pytest_simcore_core_services_selection = [ + "rabbit", +] pytest_simcore_ops_services_selection = ["adminer"] @@ -128,7 +133,7 @@ def mock_env( use_in_memory_redis: RedisSettings, mock_exclusive: None, disable_postgres: None, - disable_rabbitmq: None, + rabbit_service: RabbitSettings, mock_env: EnvVarsDict, monkeypatch: pytest.MonkeyPatch, simcore_services_network_name: str, @@ -146,6 +151,10 @@ def mock_env( monkeypatch.setenv("S3_REGION", faker.pystr()) monkeypatch.setenv("S3_SECRET_KEY", faker.pystr()) monkeypatch.setenv("S3_BUCKET_NAME", faker.pystr()) + monkeypatch.setenv( + "DIRECTOR_V2_RABBITMQ", + json_dumps(model_dump_with_secrets(rabbit_service, show_secrets=True)), + ) @pytest.fixture @@ -166,7 +175,7 @@ async def action( scheduler_data: SchedulerData, # noqa: ARG003 ) -> None: message = f"{cls.__name__} action triggered" - log.warning(message) + _logger.warning(message) # replace REGISTERED EVENTS mocker.patch( diff --git a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler_task.py b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler_task.py index e8ed258bbea5..4ce2d40ed545 100644 --- a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler_task.py +++ b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_scheduler_task.py @@ -12,6 +12,8 @@ import httpx import pytest import respx +from common_library.json_serialization import json_dumps +from common_library.serialization import model_dump_with_secrets from faker import Faker from fastapi import FastAPI from models_library.docker import DockerNodeID @@ -20,6 +22,7 @@ from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.typing_env import EnvVarsDict from respx.router import MockRouter +from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from simcore_service_director_v2.models.dynamic_services_scheduler import SchedulerData from simcore_service_director_v2.modules.dynamic_sidecar.api_client._public import ( @@ -39,6 +42,10 @@ DynamicSidecarsScheduler, ) +pytest_simcore_core_services_selection = [ + "rabbit", +] + SCHEDULER_INTERVAL_SECONDS: Final[float] = 0.1 @@ -46,7 +53,7 @@ def mock_env( use_in_memory_redis: RedisSettings, disable_postgres: None, - disable_rabbitmq: None, + rabbit_service: RabbitSettings, mock_env: EnvVarsDict, monkeypatch: pytest.MonkeyPatch, simcore_services_network_name: str, @@ -64,7 +71,11 @@ def mock_env( "POSTGRES_USER": "", "POSTGRES_PASSWORD": "", "POSTGRES_DB": "", + "DIRECTOR_V2_RABBITMQ": json_dumps( + model_dump_with_secrets(rabbit_service, show_secrets=True) + ), } + setenvs_from_dict(monkeypatch, disabled_services_envs) monkeypatch.setenv("DIRECTOR_V2_DYNAMIC_SCHEDULER_ENABLED", "true") diff --git a/services/director-v2/tests/unit/with_dbs/test_api_route_dynamic_services.py b/services/director-v2/tests/unit/with_dbs/test_api_route_dynamic_services.py index ba858c5940cc..8a70f85eb170 100644 --- 
a/services/director-v2/tests/unit/with_dbs/test_api_route_dynamic_services.py
+++ b/services/director-v2/tests/unit/with_dbs/test_api_route_dynamic_services.py
@@ -40,6 +40,7 @@
     X_DYNAMIC_SIDECAR_REQUEST_SCHEME,
     X_SIMCORE_USER_AGENT,
 )
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings
 from simcore_service_director_v2.models.dynamic_services_scheduler import SchedulerData
 from simcore_service_director_v2.modules.dynamic_sidecar.errors import (
@@ -53,6 +54,7 @@
 
 pytest_simcore_core_services_selection = [
     "postgres",
+    "rabbit",
     "redis",
 ]
 pytest_simcore_ops_services_selection = [
@@ -72,7 +74,6 @@ class ServiceParams(NamedTuple):
 
 @pytest.fixture
 def minimal_config(
-    disable_rabbitmq: None,
     mock_env: EnvVarsDict,
     postgres_host_config: dict[str, str],
     monkeypatch: pytest.MonkeyPatch,
@@ -100,8 +101,8 @@ def mock_env(
     mock_env: EnvVarsDict,
     mock_exclusive: None,
     disable_postgres: None,
-    disable_rabbitmq: None,
     redis_service: RedisSettings,
+    rabbit_service: RabbitSettings,
     monkeypatch: pytest.MonkeyPatch,
     faker: Faker,
 ) -> None:
@@ -126,11 +127,6 @@ def mock_env(
     monkeypatch.setenv("COMPUTATIONAL_BACKEND_DEFAULT_CLUSTER_AUTH", "{}")
     monkeypatch.setenv("DIRECTOR_V2_DYNAMIC_SCHEDULER_ENABLED", "true")
 
-    monkeypatch.setenv("RABBIT_HOST", "mocked_host")
-    monkeypatch.setenv("RABBIT_SECURE", "false")
-    monkeypatch.setenv("RABBIT_USER", "mocked_user")
-    monkeypatch.setenv("RABBIT_PASSWORD", "mocked_password")
-
     monkeypatch.setenv("REGISTRY_AUTH", "false")
     monkeypatch.setenv("REGISTRY_USER", "test")
     monkeypatch.setenv("REGISTRY_PW", "test")
diff --git a/services/director-v2/tests/unit/with_dbs/test_cli.py b/services/director-v2/tests/unit/with_dbs/test_cli.py
index cce3c67968dc..a4a231984a4c 100644
--- a/services/director-v2/tests/unit/with_dbs/test_cli.py
+++ b/services/director-v2/tests/unit/with_dbs/test_cli.py
@@ -23,6 +23,7 @@
 from pytest_mock.plugin import MockerFixture
 from pytest_simcore.helpers.typing_env import EnvVarsDict
 from servicelib.long_running_tasks.models import ProgressCallback
+from settings_library.rabbit import RabbitSettings
 from settings_library.redis import RedisSettings
 from simcore_service_director_v2.cli import DEFAULT_NODE_SAVE_ATTEMPTS, main
 from simcore_service_director_v2.cli._close_and_save_service import (
@@ -32,6 +33,7 @@
 
 pytest_simcore_core_services_selection = [
     "postgres",
+    "rabbit",
     "redis",
 ]
 pytest_simcore_ops_services_selection = [
@@ -44,6 +46,7 @@ def minimal_configuration(
     mock_env: EnvVarsDict,
     postgres_host_config: dict[str, str],
     redis_service: RedisSettings,
+    rabbit_service: RabbitSettings,
     monkeypatch: pytest.MonkeyPatch,
     faker: Faker,
     with_product: dict[str, Any],
diff --git a/services/docker-compose.yml b/services/docker-compose.yml
index a6f8d3df8553..7e263ed2139e 100644
--- a/services/docker-compose.yml
+++ b/services/docker-compose.yml
@@ -699,6 +699,9 @@ services:
       WEBSERVER_LOG_FORMAT_LOCAL_DEV_ENABLED: ${LOG_FORMAT_LOCAL_DEV_ENABLED}
       WEBSERVER_LOG_FILTER_MAPPING: ${LOG_FILTER_MAPPING}
 
+      # NOTE: keep in sync with the prefix from the hostname
+      LONG_RUNNING_TASKS_NAMESPACE_SUFFIX: wb
+
       # WEBSERVER_SERVER_HOST
       WEBSERVER_HOST: ${WEBSERVER_HOST}
 
@@ -929,6 +932,9 @@
       WEBSERVER_STATICWEB: "null"
       WEBSERVER_FUNCTIONS: ${WEBSERVER_FUNCTIONS} # needed for api-server
 
+      # NOTE: keep in sync with the prefix from the hostname
+      LONG_RUNNING_TASKS_NAMESPACE_SUFFIX: api
+
     networks: *webserver_networks
 
   wb-db-event-listener:
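The compose entries above and below pin a distinct `LONG_RUNNING_TASKS_NAMESPACE_SUFFIX` per container, so webserver variants built from the same image do not collide on one long-running-tasks namespace. This section does not show how the suffix is consumed; by analogy with `lrt_namespace=APP_NAME` in director-v2 and `f"{APP_NAME}-{app_settings.DY_SIDECAR_RUN_ID}"` in the dynamic sidecar elsewhere in this diff, a hypothetical derivation could look like the following (the settings class and the `"{app}-{suffix}"` scheme are assumptions, not code from this PR):

```python
# Hypothetical sketch: folding the compose-provided suffix into an LRT namespace.
from pydantic import Field
from pydantic_settings import BaseSettings


class _LRTNamespaceSettings(BaseSettings):
    # set per container in docker-compose.yml: wb, api, db, gc, auth, ...
    LONG_RUNNING_TASKS_NAMESPACE_SUFFIX: str = Field(default="local")


def build_lrt_namespace(app_name: str) -> str:
    suffix = _LRTNamespaceSettings().LONG_RUNNING_TASKS_NAMESPACE_SUFFIX
    # one namespace per deployed variant keeps RPC queues and Redis keys apart
    return f"{app_name}-{suffix}"
```

Under the first compose entry above, `build_lrt_namespace("webserver")` would yield something like `"webserver-wb"`; again, the exact scheme used by the webserver is not shown in this section.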
@@ -938,6 +944,9 @@
     environment:
       WEBSERVER_LOGLEVEL: ${WB_DB_EL_LOGLEVEL}
 
+      # NOTE: keep in sync with the prefix from the hostname
+      LONG_RUNNING_TASKS_NAMESPACE_SUFFIX: db
+
       WEBSERVER_HOST: ${WEBSERVER_HOST}
       WEBSERVER_PORT: ${WEBSERVER_PORT}
 
@@ -1034,6 +1043,9 @@
       LOG_FILTER_MAPPING: ${LOG_FILTER_MAPPING}
       LOG_FORMAT_LOCAL_DEV_ENABLED: ${LOG_FORMAT_LOCAL_DEV_ENABLED}
 
+      # NOTE: keep in sync with the prefix from the hostname
+      LONG_RUNNING_TASKS_NAMESPACE_SUFFIX: gc
+
       # WEBSERVER_DB
       POSTGRES_DB: ${POSTGRES_DB}
       POSTGRES_ENDPOINT: ${POSTGRES_ENDPOINT}
@@ -1126,6 +1138,9 @@
       WEBSERVER_APP_FACTORY_NAME: WEBSERVER_AUTHZ_APP_FACTORY
       WEBSERVER_LOGLEVEL: ${WB_AUTH_LOGLEVEL}
 
+      # NOTE: keep in sync with the prefix from the hostname
+      LONG_RUNNING_TASKS_NAMESPACE_SUFFIX: auth
+
       GUNICORN_CMD_ARGS: ${WEBSERVER_GUNICORN_CMD_ARGS}
 
       # WEBSERVER_DB
diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/api/rest/containers_long_running_tasks.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/api/rest/containers_long_running_tasks.py
index 7e9fdb3d0b8f..a8f7c3e69ccc 100644
--- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/api/rest/containers_long_running_tasks.py
+++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/api/rest/containers_long_running_tasks.py
@@ -1,7 +1,7 @@
 from textwrap import dedent
 from typing import Annotated, cast
 
-from fastapi import APIRouter, Depends, FastAPI, Request, status
+from fastapi import APIRouter, Depends, Request, status
 from servicelib.fastapi.long_running_tasks._manager import FastAPILongRunningManager
 from servicelib.fastapi.long_running_tasks.server import get_long_running_manager
 from servicelib.fastapi.requests_decorators import cancel_on_disconnect
@@ -12,7 +12,6 @@
 from ...core.settings import ApplicationSettings
 from ...models.schemas.application_health import ApplicationHealth
 from ...models.schemas.containers import ContainersCreate
-from ...models.shared_store import SharedStore
 from ...modules.inputs import InputsState
 from ...modules.long_running_tasks import (
     task_containers_restart,
@@ -25,16 +24,10 @@
     task_runs_docker_compose_down,
     task_save_state,
 )
-from ...modules.mounted_fs import MountedVolumes
-from ...modules.outputs import OutputsManager
 from ._dependencies import (
-    get_application,
     get_application_health,
     get_inputs_state,
-    get_mounted_volumes,
-    get_outputs_manager,
     get_settings,
-    get_shared_store,
 )
 
 router = APIRouter()
@@ -52,18 +45,15 @@ async def pull_user_servcices_docker_images(
     long_running_manager: Annotated[
         FastAPILongRunningManager, Depends(get_long_running_manager)
     ],
-    shared_store: Annotated[SharedStore, Depends(get_shared_store)],
-    app: Annotated[FastAPI, Depends(get_application)],
 ) -> TaskId:
     assert request  # nosec
 
     try:
         return await lrt_api.start_task(
-            long_running_manager.tasks_manager,
+            long_running_manager.rpc_client,
+            long_running_manager.lrt_namespace,
             task_pull_user_servcices_docker_images.__name__,
             unique=True,
-            app=app,
-            shared_store=shared_store,
         )
     except TaskAlreadyRunningError as e:
         return cast(str, e.managed_task.task_id)  # type: ignore[attr-defined] # pylint:disable=no-member
@@ -92,21 +82,18 @@ async def create_service_containers_task(  # pylint: disable=too-many-arguments
         FastAPILongRunningManager, Depends(get_long_running_manager)
     ],
     settings: Annotated[ApplicationSettings, Depends(get_settings)],
-    shared_store: Annotated[SharedStore, Depends(get_shared_store)],
-    app: Annotated[FastAPI, Depends(get_application)],
     application_health: Annotated[ApplicationHealth,
Depends(get_application_health)], ) -> TaskId: assert request # nosec try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, task_create_service_containers.__name__, unique=True, settings=settings, containers_create=containers_create, - shared_store=shared_store, - app=app, application_health=application_health, ) except TaskAlreadyRunningError as e: @@ -126,21 +113,16 @@ async def runs_docker_compose_down_task( FastAPILongRunningManager, Depends(get_long_running_manager) ], settings: Annotated[ApplicationSettings, Depends(get_settings)], - shared_store: Annotated[SharedStore, Depends(get_shared_store)], - app: Annotated[FastAPI, Depends(get_application)], - mounted_volumes: Annotated[MountedVolumes, Depends(get_mounted_volumes)], ) -> TaskId: assert request # nosec try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, task_runs_docker_compose_down.__name__, unique=True, - app=app, - shared_store=shared_store, settings=settings, - mounted_volumes=mounted_volumes, ) except TaskAlreadyRunningError as e: return cast(str, e.managed_task.task_id) # type: ignore[attr-defined] # pylint:disable=no-member @@ -159,19 +141,16 @@ async def state_restore_task( FastAPILongRunningManager, Depends(get_long_running_manager) ], settings: Annotated[ApplicationSettings, Depends(get_settings)], - mounted_volumes: Annotated[MountedVolumes, Depends(get_mounted_volumes)], - app: Annotated[FastAPI, Depends(get_application)], ) -> TaskId: assert request # nosec try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, task_restore_state.__name__, unique=True, settings=settings, - mounted_volumes=mounted_volumes, - app=app, ) except TaskAlreadyRunningError as e: return cast(str, e.managed_task.task_id) # type: ignore[attr-defined] # pylint:disable=no-member @@ -189,20 +168,17 @@ async def state_save_task( long_running_manager: Annotated[ FastAPILongRunningManager, Depends(get_long_running_manager) ], - app: Annotated[FastAPI, Depends(get_application)], - mounted_volumes: Annotated[MountedVolumes, Depends(get_mounted_volumes)], settings: Annotated[ApplicationSettings, Depends(get_settings)], ) -> TaskId: assert request # nosec try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, task_save_state.__name__, unique=True, settings=settings, - mounted_volumes=mounted_volumes, - app=app, ) except TaskAlreadyRunningError as e: return cast(str, e.managed_task.task_id) # type: ignore[attr-defined] # pylint:disable=no-member @@ -220,9 +196,7 @@ async def ports_inputs_pull_task( long_running_manager: Annotated[ FastAPILongRunningManager, Depends(get_long_running_manager) ], - app: Annotated[FastAPI, Depends(get_application)], settings: Annotated[ApplicationSettings, Depends(get_settings)], - mounted_volumes: Annotated[MountedVolumes, Depends(get_mounted_volumes)], inputs_state: Annotated[InputsState, Depends(get_inputs_state)], port_keys: list[str] | None = None, ) -> TaskId: @@ -230,12 +204,11 @@ async def ports_inputs_pull_task( try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, task_ports_inputs_pull.__name__, unique=True, port_keys=port_keys, - 
mounted_volumes=mounted_volumes, - app=app, settings=settings, inputs_pulling_enabled=inputs_state.inputs_pulling_enabled, ) @@ -255,20 +228,17 @@ async def ports_outputs_pull_task( long_running_manager: Annotated[ FastAPILongRunningManager, Depends(get_long_running_manager) ], - app: Annotated[FastAPI, Depends(get_application)], - mounted_volumes: Annotated[MountedVolumes, Depends(get_mounted_volumes)], port_keys: list[str] | None = None, ) -> TaskId: assert request # nosec try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, task_ports_outputs_pull.__name__, unique=True, port_keys=port_keys, - mounted_volumes=mounted_volumes, - app=app, ) except TaskAlreadyRunningError as e: return cast(str, e.managed_task.task_id) # type: ignore[attr-defined] # pylint:disable=no-member @@ -286,18 +256,15 @@ async def ports_outputs_push_task( long_running_manager: Annotated[ FastAPILongRunningManager, Depends(get_long_running_manager) ], - outputs_manager: Annotated[OutputsManager, Depends(get_outputs_manager)], - app: Annotated[FastAPI, Depends(get_application)], ) -> TaskId: assert request # nosec try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, task_ports_outputs_push.__name__, unique=True, - outputs_manager=outputs_manager, - app=app, ) except TaskAlreadyRunningError as e: return cast(str, e.managed_task.task_id) # type: ignore[attr-defined] # pylint:disable=no-member @@ -315,20 +282,17 @@ async def containers_restart_task( long_running_manager: Annotated[ FastAPILongRunningManager, Depends(get_long_running_manager) ], - app: Annotated[FastAPI, Depends(get_application)], settings: Annotated[ApplicationSettings, Depends(get_settings)], - shared_store: Annotated[SharedStore, Depends(get_shared_store)], ) -> TaskId: assert request # nosec try: return await lrt_api.start_task( - long_running_manager.tasks_manager, + long_running_manager.rpc_client, + long_running_manager.lrt_namespace, task_containers_restart.__name__, unique=True, - app=app, settings=settings, - shared_store=shared_store, ) except TaskAlreadyRunningError as e: return cast(str, e.managed_task.task_id) # type: ignore[attr-defined] # pylint:disable=no-member diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/application.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/application.py index 12a1b77a7259..4c82ec3279ca 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/application.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/core/application.py @@ -5,7 +5,6 @@ from common_library.json_serialization import json_dumps from fastapi import FastAPI from servicelib.async_utils import cancel_sequential_workers -from servicelib.fastapi import long_running_tasks from servicelib.fastapi.logging_lifespan import create_logging_shutdown_event from servicelib.fastapi.openapi import ( get_common_oas_options, @@ -24,6 +23,7 @@ from ..models.shared_store import SharedStore, setup_shared_store from ..modules.attribute_monitor import setup_attribute_monitor from ..modules.inputs import setup_inputs +from ..modules.long_running_tasks import setup_long_running_tasks from ..modules.mounted_fs import MountedVolumes, setup_mounted_fs from ..modules.notifications import setup_notifications from ..modules.outputs import setup_outputs @@ -146,12 +146,6 @@ def create_base_app() 
-> FastAPI: override_fastapi_openapi_method(app) app.state.settings = app_settings - long_running_tasks.server.setup( - app, - redis_settings=app_settings.REDIS_SETTINGS, - redis_namespace=f"dy_sidecar-{app_settings.DY_SIDECAR_RUN_ID}", - ) - app.include_router(get_main_router(app)) setup_reserved_space(app) @@ -191,6 +185,8 @@ def create_app() -> FastAPI: setup_inputs(app) setup_outputs(app) + setup_long_running_tasks(app) + setup_attribute_monitor(app) setup_user_services_preferences(app) diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/models/schemas/application_health.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/models/schemas/application_health.py index 4da644858b9d..72413188e4b7 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/models/schemas/application_health.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/models/schemas/application_health.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field @@ -7,6 +5,6 @@ class ApplicationHealth(BaseModel): is_healthy: bool = Field( default=True, description="returns True if the service sis running correctly" ) - error_message: Optional[str] = Field( + error_message: str | None = Field( default=None, description="in case of error this gets set" ) diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py index 254ef2968b32..94fb33ea5c7a 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py @@ -3,7 +3,7 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path -from typing import Final +from typing import Any, Final from fastapi import FastAPI from models_library.api_schemas_long_running_tasks.base import TaskProgress @@ -11,9 +11,10 @@ from models_library.rabbitmq_messages import ProgressType, SimcorePlatformStatus from models_library.service_settings_labels import LegacyState from pydantic import PositiveInt +from servicelib.fastapi import long_running_tasks from servicelib.file_utils import log_directory_changes from servicelib.logging_utils import log_context -from servicelib.long_running_tasks.task import TaskRegistry +from servicelib.long_running_tasks.task import TaskProtocol, TaskRegistry from servicelib.progress_bar import ProgressBarData from servicelib.utils import logged_gather from simcore_sdk.node_data import data_manager @@ -629,15 +630,60 @@ async def task_containers_restart( await progress.update(message="started log fetching", percent=0.99) -for task in ( - task_pull_user_servcices_docker_images, - task_create_service_containers, - task_runs_docker_compose_down, - task_restore_state, - task_save_state, - task_ports_inputs_pull, - task_ports_outputs_pull, - task_ports_outputs_push, - task_containers_restart, -): - TaskRegistry.register(task) +def setup_long_running_tasks(app: FastAPI) -> None: + app_settings: ApplicationSettings = app.state.settings + long_running_tasks.server.setup( + app, + redis_settings=app_settings.REDIS_SETTINGS, + rabbit_settings=app_settings.RABBIT_SETTINGS, + lrt_namespace=f"{APP_NAME}-{app_settings.DY_SIDECAR_RUN_ID}", + ) + + task_context: dict[TaskProtocol, dict[str, Any]] = {} + + async def on_startup() -> None: + shared_store: SharedStore = 
app.state.shared_store + mounted_volumes: MountedVolumes = app.state.mounted_volumes + outputs_manager: OutputsManager = app.state.outputs_manager + + context_app_store: dict[str, Any] = { + "app": app, + "shared_store": shared_store, + } + context_app_store_volumes: dict[str, Any] = { + "app": app, + "shared_store": shared_store, + "mounted_volumes": mounted_volumes, + } + context_app_volumes: dict[str, Any] = { + "app": app, + "mounted_volumes": mounted_volumes, + } + context_app_outputs: dict[str, Any] = { + "app": app, + "outputs_manager": outputs_manager, + } + + task_context.update( + { + task_pull_user_servcices_docker_images: context_app_store, + task_create_service_containers: context_app_store, + task_runs_docker_compose_down: context_app_store_volumes, + task_restore_state: context_app_volumes, + task_save_state: context_app_volumes, + task_ports_inputs_pull: context_app_volumes, + task_ports_outputs_pull: context_app_volumes, + task_ports_outputs_push: context_app_outputs, + task_containers_restart: context_app_store, + } + ) + + for handler, context in task_context.items(): + TaskRegistry.register(handler, **context) + + async def _on_shutdown() -> None: + for handler in task_context: + TaskRegistry.unregister(handler) + + app.add_event_handler("startup", on_startup) + app.add_event_handler("shutdown", _on_shutdown) diff --git a/services/dynamic-sidecar/tests/conftest.py b/services/dynamic-sidecar/tests/conftest.py index 1bf41e23c834..3943692efdb9 100644 --- a/services/dynamic-sidecar/tests/conftest.py +++ b/services/dynamic-sidecar/tests/conftest.py @@ -42,6 +42,7 @@ "pytest_simcore.docker_swarm", "pytest_simcore.faker_users_data", "pytest_simcore.logging", + "pytest_simcore.long_running_tasks", "pytest_simcore.minio_service", "pytest_simcore.postgres_service", "pytest_simcore.pytest_global_environs", @@ -169,6 +170,7 @@ def mock_rabbit_check(mocker: MockerFixture) -> None: @pytest.fixture def base_mock_envs( + fast_long_running_tasks_cancellation: None, use_in_memory_redis: RedisSettings, dy_volumes: Path, shared_store_dir: Path, @@ -211,6 +213,7 @@ def base_mock_envs( @pytest.fixture def mock_environment( + fast_long_running_tasks_cancellation: None, use_in_memory_redis: RedisSettings, mock_storage_check: None, mock_postgres_check: None, @@ -357,9 +360,7 @@ def mock_stop_heart_beat_task(mocker: MockerFixture) -> AsyncMock: @pytest.fixture def mock_metrics_params(faker: Faker) -> CreateServiceMetricsAdditionalParams: return TypeAdapter(CreateServiceMetricsAdditionalParams).validate_python( - CreateServiceMetricsAdditionalParams.model_config["json_schema_extra"][ - "example" - ], + CreateServiceMetricsAdditionalParams.model_json_schema()["example"] ) diff --git a/services/dynamic-sidecar/tests/unit/api/rest/test_disk.py b/services/dynamic-sidecar/tests/unit/api/rest/test_disk.py index 3d6bda8d8f1b..fa466827e0dd 100644 --- a/services/dynamic-sidecar/tests/unit/api/rest/test_disk.py +++ b/services/dynamic-sidecar/tests/unit/api/rest/test_disk.py @@ -1,5 +1,7 @@ # pylint:disable=unused-argument +from unittest.mock import AsyncMock + from async_asgi_testclient import TestClient from fastapi import status from simcore_service_dynamic_sidecar._meta import API_VTAG @@ -9,7 +11,9 @@ async def test_reserved_disk_space_freed( - cleanup_reserved_disk_space: None, test_client: TestClient + mock_core_rabbitmq: dict[str, AsyncMock], + cleanup_reserved_disk_space: None, + test_client: TestClient, ): assert _RESERVED_DISK_SPACE_NAME.exists() response = await 
test_client.post(f"/{API_VTAG}/disk/reserved:free") diff --git a/services/dynamic-sidecar/tests/unit/api/rest/test_volumes.py b/services/dynamic-sidecar/tests/unit/api/rest/test_volumes.py index 40eab12336a3..5bf729dfe0db 100644 --- a/services/dynamic-sidecar/tests/unit/api/rest/test_volumes.py +++ b/services/dynamic-sidecar/tests/unit/api/rest/test_volumes.py @@ -1,6 +1,7 @@ # pylint: disable=unused-argument from pathlib import Path +from unittest.mock import AsyncMock import pytest from async_asgi_testclient import TestClient @@ -20,6 +21,7 @@ ], ) async def test_volumes_state_saved_ok( + mock_core_rabbitmq: dict[str, AsyncMock], ensure_shared_store_dir: Path, test_client: TestClient, volume_category: VolumeCategory, @@ -46,6 +48,7 @@ async def test_volumes_state_saved_ok( @pytest.mark.parametrize("invalid_volume_category", ["outputs", "outputS"]) async def test_volumes_state_saved_error( + mock_core_rabbitmq: dict[str, AsyncMock], ensure_shared_store_dir: Path, test_client: TestClient, invalid_volume_category: VolumeCategory, diff --git a/services/dynamic-sidecar/tests/unit/conftest.py b/services/dynamic-sidecar/tests/unit/conftest.py index 75b9d316c103..fc113f1c674a 100644 --- a/services/dynamic-sidecar/tests/unit/conftest.py +++ b/services/dynamic-sidecar/tests/unit/conftest.py @@ -12,7 +12,7 @@ from async_asgi_testclient import TestClient from fastapi import FastAPI from pytest_mock.plugin import MockerFixture -from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict +from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict from simcore_service_dynamic_sidecar.core.application import AppState, create_app from simcore_service_dynamic_sidecar.core.docker_compose_utils import ( docker_compose_down, @@ -40,11 +40,7 @@ @pytest.fixture -def app( - mock_environment: EnvVarsDict, - mock_registry_service: AsyncMock, - mock_core_rabbitmq: dict[str, AsyncMock], -) -> FastAPI: +def app(mock_environment: EnvVarsDict, mock_registry_service: AsyncMock) -> FastAPI: """creates app with registry and rabbitMQ services mocked""" return create_app() @@ -131,24 +127,6 @@ async def cleanup_containers(app: FastAPI) -> AsyncIterator[None]: await docker_compose_down(app_state.compose_spec, app_state.settings) -@pytest.fixture -def mock_rabbitmq_envs( - mock_core_rabbitmq: dict[str, AsyncMock], - monkeypatch: pytest.MonkeyPatch, - mock_environment: EnvVarsDict, -) -> EnvVarsDict: - setenvs_from_dict( - monkeypatch, - { - "RABBIT_HOST": "mocked_host", - "RABBIT_SECURE": "false", - "RABBIT_USER": "mocked_user", - "RABBIT_PASSWORD": "mocked_password", - }, - ) - return mock_environment - - @pytest.fixture def port_notifier(app: FastAPI) -> PortNotifier: settings: ApplicationSettings = app.state.settings diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_containers.py b/services/dynamic-sidecar/tests/unit/test_api_rest_containers.py index 970f8aeb67e4..ba0e3f629ceb 100644 --- a/services/dynamic-sidecar/tests/unit/test_api_rest_containers.py +++ b/services/dynamic-sidecar/tests/unit/test_api_rest_containers.py @@ -19,6 +19,7 @@ from aiodocker.volumes import DockerVolume from aiofiles.os import mkdir from async_asgi_testclient import TestClient +from common_library.serialization import model_dump_with_secrets from faker import Faker from fastapi import FastAPI, status from models_library.api_schemas_dynamic_sidecar.containers import ActivityInfo @@ -29,6 +30,7 @@ from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict from 
servicelib.docker_constants import SUFFIX_EGRESS_PROXY_NAME from servicelib.long_running_tasks.models import TaskId +from settings_library.rabbit import RabbitSettings from simcore_service_dynamic_sidecar._meta import API_VTAG from simcore_service_dynamic_sidecar.api.rest.containers import _INACTIVE_FOR_LONG_TIME from simcore_service_dynamic_sidecar.core.application import AppState @@ -47,6 +49,10 @@ from tenacity.stop import stop_after_delay from tenacity.wait import wait_fixed +pytest_simcore_core_services_selection = [ + "rabbit", +] + WAIT_FOR_OUTPUTS_WATCHER: Final[float] = 0.1 FAST_POLLING_INTERVAL: Final[float] = 0.1 @@ -162,9 +168,19 @@ async def _assert_compose_spec_pulled(compose_spec: str, settings: ApplicationSe @pytest.fixture def mock_environment( - mock_environment: EnvVarsDict, mock_rabbitmq_envs: EnvVarsDict + monkeypatch: pytest.MonkeyPatch, + rabbit_service: RabbitSettings, + mock_environment: EnvVarsDict, ) -> EnvVarsDict: - return mock_rabbitmq_envs + return setenvs_from_dict( + monkeypatch, + { + **mock_environment, + "RABBIT_SETTINGS": json.dumps( + model_dump_with_secrets(rabbit_service, show_secrets=True) + ), + }, + ) @pytest.fixture @@ -267,10 +283,10 @@ def not_started_containers() -> list[str]: def mock_outputs_labels() -> dict[str, ServiceOutput]: return { "output_port_1": TypeAdapter(ServiceOutput).validate_python( - ServiceOutput.model_config["json_schema_extra"]["examples"][3] + ServiceOutput.model_json_schema()["examples"][3] ), "output_port_2": TypeAdapter(ServiceOutput).validate_python( - ServiceOutput.model_config["json_schema_extra"]["examples"][3] + ServiceOutput.model_json_schema()["examples"][3] ), } diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py b/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py index 75180ce4a00b..15e78861aca9 100644 --- a/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py +++ b/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py @@ -17,6 +17,7 @@ from aiodocker.containers import DockerContainer from aiodocker.volumes import DockerVolume from asgi_lifespan import LifespanManager +from common_library.serialization import model_dump_with_secrets from fastapi import FastAPI from fastapi.routing import APIRoute from httpx import ASGITransport, AsyncClient @@ -28,12 +29,13 @@ from models_library.services_creation import CreateServiceMetricsAdditionalParams from pydantic import AnyHttpUrl, TypeAdapter from pytest_mock.plugin import MockerFixture -from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict +from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict from servicelib.fastapi.long_running_tasks.client import Client, periodic_task_result from servicelib.fastapi.long_running_tasks.client import setup as client_setup from servicelib.long_running_tasks.errors import TaskExceptionError from servicelib.long_running_tasks.models import TaskId from servicelib.long_running_tasks.task import TaskRegistry +from settings_library.rabbit import RabbitSettings from simcore_sdk.node_ports_common.exceptions import NodeNotFound from simcore_service_dynamic_sidecar._meta import API_VTAG from simcore_service_dynamic_sidecar.api.rest import containers_long_running_tasks @@ -53,6 +55,10 @@ wait_fixed, ) +pytest_simcore_core_services_selection = [ + "rabbit", +] + FAST_STATUS_POLL: Final[float] = 0.1 CREATE_SERVICE_CONTAINERS_TIMEOUT: Final[float] = 60 DEFAULT_COMMAND_TIMEOUT: 
Final[int] = 5 @@ -106,7 +112,7 @@ async def auto_remove_task(client: Client, task_id: TaskId) -> AsyncIterator[Non try: yield finally: - await client.cancel_and_delete_task(task_id, timeout=10) + await client.remove_task(task_id, timeout=10) async def _get_container_timestamps( @@ -171,8 +177,20 @@ def backend_url() -> AnyHttpUrl: @pytest.fixture -def mock_environment(mock_rabbitmq_envs: EnvVarsDict) -> EnvVarsDict: - return mock_rabbitmq_envs +def mock_environment( + monkeypatch: pytest.MonkeyPatch, + rabbit_service: RabbitSettings, + mock_environment: EnvVarsDict, +) -> EnvVarsDict: + return setenvs_from_dict( + monkeypatch, + { + **mock_environment, + "RABBIT_SETTINGS": json.dumps( + model_dump_with_secrets(rabbit_service, show_secrets=True) + ), + }, + ) @pytest.fixture diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_health.py b/services/dynamic-sidecar/tests/unit/test_api_rest_health.py index 987ddbf1e636..a5542917b117 100644 --- a/services/dynamic-sidecar/tests/unit/test_api_rest_health.py +++ b/services/dynamic-sidecar/tests/unit/test_api_rest_health.py @@ -1,6 +1,8 @@ # pylint: disable=redefined-outer-name # pylint: disable=unused-argument +from unittest.mock import AsyncMock + from async_asgi_testclient import TestClient from fastapi import status from simcore_service_dynamic_sidecar.models.schemas.application_health import ( @@ -8,14 +10,18 @@ ) -async def test_is_healthy(test_client: TestClient) -> None: +async def test_is_healthy( + mock_core_rabbitmq: dict[str, AsyncMock], test_client: TestClient +) -> None: test_client.application.state.application_health.is_healthy = True response = await test_client.get("/health") assert response.status_code == status.HTTP_200_OK, response assert response.json() == ApplicationHealth(is_healthy=True).model_dump() -async def test_is_unhealthy(test_client: TestClient) -> None: +async def test_is_unhealthy( + mock_core_rabbitmq: dict[str, AsyncMock], test_client: TestClient +) -> None: test_client.application.state.application_health.is_healthy = False response = await test_client.get("/health") assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE, response @@ -24,7 +30,9 @@ async def test_is_unhealthy(test_client: TestClient) -> None: } -async def test_is_unhealthy_via_rabbitmq(test_client: TestClient) -> None: +async def test_is_unhealthy_via_rabbitmq( + mock_core_rabbitmq: dict[str, AsyncMock], test_client: TestClient +) -> None: # pylint: disable=protected-access test_client.application.state.rabbitmq_client._healthy_state = False # noqa: SLF001 response = await test_client.get("/health") diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_prometheus_metrics.py b/services/dynamic-sidecar/tests/unit/test_api_rest_prometheus_metrics.py index 78e5b22046ea..8a653fad77d1 100644 --- a/services/dynamic-sidecar/tests/unit/test_api_rest_prometheus_metrics.py +++ b/services/dynamic-sidecar/tests/unit/test_api_rest_prometheus_metrics.py @@ -5,11 +5,11 @@ import json from collections.abc import AsyncIterable from typing import Final -from unittest.mock import AsyncMock import pytest from aiodocker.volumes import DockerVolume from asgi_lifespan import LifespanManager +from common_library.serialization import model_dump_with_secrets from fastapi import FastAPI, status from httpx import ASGITransport, AsyncClient from models_library.api_schemas_dynamic_sidecar.containers import DockerComposeYamlStr @@ -20,6 +20,7 @@ from servicelib.fastapi.long_running_tasks.client import Client, periodic_task_result from 
servicelib.fastapi.long_running_tasks.client import setup as client_setup from servicelib.long_running_tasks.models import TaskId +from settings_library.rabbit import RabbitSettings from simcore_service_dynamic_sidecar._meta import API_VTAG from simcore_service_dynamic_sidecar.models.schemas.containers import ( ContainersComposeSpec, @@ -30,10 +31,31 @@ UserServicesMetrics, ) +pytest_simcore_core_services_selection = [ + "rabbit", +] + _FAST_STATUS_POLL: Final[float] = 0.1 _CREATE_SERVICE_CONTAINERS_TIMEOUT: Final[float] = 60 +@pytest.fixture +def mock_environment( + monkeypatch: pytest.MonkeyPatch, + rabbit_service: RabbitSettings, + mock_environment: EnvVarsDict, +) -> EnvVarsDict: + return setenvs_from_dict( + monkeypatch, + { + **mock_environment, + "RABBIT_SETTINGS": json.dumps( + model_dump_with_secrets(rabbit_service, show_secrets=True) + ), + }, + ) + + @pytest.fixture async def enable_prometheus_metrics( monkeypatch: pytest.MonkeyPatch, mock_environment: EnvVarsDict @@ -42,14 +64,14 @@ async def enable_prometheus_metrics( monkeypatch, { "DY_SIDECAR_CALLBACKS_MAPPING": json.dumps( - CallbacksMapping.model_config["json_schema_extra"]["examples"][2] - ) + CallbacksMapping.model_json_schema()["examples"][2] + ), }, ) @pytest.fixture -async def app(mock_rabbitmq_envs: EnvVarsDict, app: FastAPI) -> AsyncIterable[FastAPI]: +async def app(app: FastAPI) -> AsyncIterable[FastAPI]: client_setup(app) async with LifespanManager(app): yield app @@ -118,17 +140,13 @@ async def _get_task_id_create_service_containers( return task_id -async def test_metrics_disabled( - mock_core_rabbitmq: dict[str, AsyncMock], httpx_async_client: AsyncClient -) -> None: +async def test_metrics_disabled(httpx_async_client: AsyncClient) -> None: response = await httpx_async_client.get("/metrics") assert response.status_code == status.HTTP_404_NOT_FOUND, response async def test_metrics_enabled_no_containers_running( - enable_prometheus_metrics: None, - mock_core_rabbitmq: dict[str, AsyncMock], - httpx_async_client: AsyncClient, + enable_prometheus_metrics: None, httpx_async_client: AsyncClient ) -> None: response = await httpx_async_client.get("/metrics") assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR, response @@ -137,7 +155,6 @@ async def test_metrics_enabled_no_containers_running( async def test_metrics_enabled_containers_will_start( enable_prometheus_metrics: None, - mock_core_rabbitmq: dict[str, AsyncMock], app: FastAPI, httpx_async_client: AsyncClient, client: Client, diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py b/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py index 145fd791fd3a..520a01ad46c3 100644 --- a/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py +++ b/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py @@ -16,6 +16,7 @@ from aiodocker.utils import clean_filters from aiodocker.volumes import DockerVolume from asgi_lifespan import LifespanManager +from common_library.serialization import model_dump_with_secrets from fastapi import FastAPI from httpx import ASGITransport, AsyncClient from models_library.api_schemas_dynamic_sidecar.containers import DockerComposeYamlStr @@ -36,7 +37,9 @@ from servicelib.fastapi.long_running_tasks.client import setup as client_setup from servicelib.long_running_tasks.errors import TaskExceptionError from servicelib.long_running_tasks.models import TaskId +from settings_library.rabbit import RabbitSettings from 
simcore_service_dynamic_sidecar._meta import API_VTAG +from simcore_service_dynamic_sidecar.core.application import create_app from simcore_service_dynamic_sidecar.core.docker_utils import get_container_states from simcore_service_dynamic_sidecar.models.schemas.containers import ( ContainersComposeSpec, @@ -47,6 +50,10 @@ from tenacity.stop import stop_after_delay from tenacity.wait import wait_fixed +pytest_simcore_core_services_selection = [ + "rabbit", +] + _FAST_STATUS_POLL: Final[float] = 0.1 _CREATE_SERVICE_CONTAINERS_TIMEOUT: Final[float] = 60 _BASE_HEART_BEAT_INTERVAL: Final[float] = 0.1 @@ -81,26 +88,35 @@ def backend_url() -> AnyHttpUrl: @pytest.fixture -def mock_environment( +async def mock_environment( mock_postgres_check: None, + mock_registry_service: AsyncMock, + mock_environment: EnvVarsDict, monkeypatch: pytest.MonkeyPatch, - mock_rabbitmq_envs: EnvVarsDict, + rabbit_service: RabbitSettings, ) -> EnvVarsDict: - setenvs_from_dict( + return setenvs_from_dict( monkeypatch, - {"RESOURCE_TRACKING_HEARTBEAT_INTERVAL": f"{_BASE_HEART_BEAT_INTERVAL}"}, + { + **mock_environment, + "RESOURCE_TRACKING_HEARTBEAT_INTERVAL": f"{_BASE_HEART_BEAT_INTERVAL}", + "RABBIT_SETTINGS": json.dumps( + model_dump_with_secrets(rabbit_service, show_secrets=True) + ), + }, ) - return mock_rabbitmq_envs @pytest.fixture -async def app(app: FastAPI) -> AsyncIterable[FastAPI]: +async def app(mock_environment: EnvVarsDict) -> AsyncIterable[FastAPI]: + local_app = create_app() # add the client setup to the same application # this is only required for testing, in reality # this will be in a different process - client_setup(app) - async with LifespanManager(app): - yield app + client_setup(local_app) + + async with LifespanManager(local_app): + yield local_app @pytest.fixture @@ -122,7 +138,7 @@ async def httpx_async_client( @pytest.fixture -def client( +async def client( app: FastAPI, httpx_async_client: AsyncClient, backend_url: AnyHttpUrl ) -> Client: return Client(app=app, async_client=httpx_async_client, base_url=f"{backend_url}") @@ -144,6 +160,15 @@ def mock_user_services_fail_to_stop(mocker: MockerFixture) -> None: ) +@pytest.fixture +def mock_post_rabbit_message(mocker: MockerFixture) -> AsyncMock: + return mocker.patch( + "simcore_service_dynamic_sidecar.core.rabbitmq._post_rabbit_message", + return_value=None, + autospec=True, + ) + + async def _get_task_id_create_service_containers( httpx_async_client: AsyncClient, compose_spec: DockerComposeYamlStr, @@ -173,11 +198,11 @@ async def _get_task_id_docker_compose_down(httpx_async_client: AsyncClient) -> T def _get_resource_tracking_messages( - mock_core_rabbitmq: dict[str, AsyncMock], + mock_post_rabbit_message: AsyncMock, ) -> list[RabbitResourceTrackingMessages]: return [ x[0][1] - for x in mock_core_rabbitmq["post_rabbit_message"].call_args_list + for x in mock_post_rabbit_message.call_args_list if isinstance(x[0][1], RabbitResourceTrackingMessages) ] @@ -201,7 +226,7 @@ async def _wait_for_containers_to_be_running(app: FastAPI) -> None: async def test_service_starts_and_closes_as_expected( - mock_core_rabbitmq: dict[str, AsyncMock], + mock_post_rabbit_message: AsyncMock, app: FastAPI, httpx_async_client: AsyncClient, client: Client, @@ -235,7 +260,9 @@ async def test_service_starts_and_closes_as_expected( await asyncio.sleep(_BASE_HEART_BEAT_INTERVAL * 10) # Ensure messages arrive in the expected order - resource_tracking_messages = _get_resource_tracking_messages(mock_core_rabbitmq) + resource_tracking_messages = _get_resource_tracking_messages( 
+ mock_post_rabbit_message + ) assert len(resource_tracking_messages) >= 3 start_message = resource_tracking_messages[0] @@ -252,7 +279,7 @@ async def test_service_starts_and_closes_as_expected( @pytest.mark.parametrize("with_compose_down", [True, False]) async def test_user_services_fail_to_start( - mock_core_rabbitmq: dict[str, AsyncMock], + mock_post_rabbit_message: AsyncMock, app: FastAPI, httpx_async_client: AsyncClient, client: Client, @@ -284,12 +311,14 @@ async def test_user_services_fail_to_start( assert result is None # no messages were sent - resource_tracking_messages = _get_resource_tracking_messages(mock_core_rabbitmq) + resource_tracking_messages = _get_resource_tracking_messages( + mock_post_rabbit_message + ) assert len(resource_tracking_messages) == 0 async def test_user_services_fail_to_stop_or_save_data( - mock_core_rabbitmq: dict[str, AsyncMock], + mock_post_rabbit_message: AsyncMock, app: FastAPI, httpx_async_client: AsyncClient, client: Client, @@ -327,7 +356,9 @@ async def test_user_services_fail_to_stop_or_save_data( ... # Ensure messages arrive in the expected order - resource_tracking_messages = _get_resource_tracking_messages(mock_core_rabbitmq) + resource_tracking_messages = _get_resource_tracking_messages( + mock_post_rabbit_message + ) assert len(resource_tracking_messages) >= 3 start_message = resource_tracking_messages[0] @@ -384,7 +415,7 @@ async def _mocked_get_container_states( @pytest.mark.parametrize("expected_platform_state", SimcorePlatformStatus) async def test_user_services_crash_when_running( - mock_core_rabbitmq: dict[str, AsyncMock], + mock_post_rabbit_message: AsyncMock, app: FastAPI, httpx_async_client: AsyncClient, client: Client, @@ -419,7 +450,9 @@ async def test_user_services_crash_when_running( await _simulate_container_crash(container_names) # check only start and heartbeats are present - resource_tracking_messages = _get_resource_tracking_messages(mock_core_rabbitmq) + resource_tracking_messages = _get_resource_tracking_messages( + mock_post_rabbit_message + ) assert len(resource_tracking_messages) >= 2 start_message = resource_tracking_messages[0] @@ -431,11 +464,13 @@ async def test_user_services_crash_when_running( # reset mock await asyncio.sleep(_BASE_HEART_BEAT_INTERVAL * 2) - mock_core_rabbitmq["post_rabbit_message"].reset_mock() + mock_post_rabbit_message.reset_mock() # wait a bit more and check no further heartbeats are sent await asyncio.sleep(_BASE_HEART_BEAT_INTERVAL * 2) - new_resource_tracking_messages = _get_resource_tracking_messages(mock_core_rabbitmq) + new_resource_tracking_messages = _get_resource_tracking_messages( + mock_post_rabbit_message + ) assert len(new_resource_tracking_messages) == 0 # sending stop events, and since there was an issue multiple stops @@ -450,7 +485,9 @@ async def test_user_services_crash_when_running( ) as result: assert result is None - resource_tracking_messages = _get_resource_tracking_messages(mock_core_rabbitmq) + resource_tracking_messages = _get_resource_tracking_messages( + mock_post_rabbit_message + ) # NOTE: only 1 stop event arrives here since the stopping of the containers # was successful assert len(resource_tracking_messages) == 1 diff --git a/services/dynamic-sidecar/tests/unit/test_models_shared_store.py b/services/dynamic-sidecar/tests/unit/test_models_shared_store.py index 2c2b474a0290..7ecf24a2d33f 100644 --- a/services/dynamic-sidecar/tests/unit/test_models_shared_store.py +++ b/services/dynamic-sidecar/tests/unit/test_models_shared_store.py @@ -5,6 +5,7 @@ from copy 
import deepcopy from pathlib import Path from typing import Any +from unittest.mock import AsyncMock import arrow import pytest @@ -23,6 +24,7 @@ @pytest.fixture def trigger_setup_shutdown_events( + mock_core_rabbitmq: dict[str, AsyncMock], shared_store_dir: Path, app: FastAPI, test_client: TestClient, diff --git a/services/storage/src/simcore_service_storage/core/application.py b/services/storage/src/simcore_service_storage/core/application.py index ebe84c5643ff..305e13bb3eae 100644 --- a/services/storage/src/simcore_service_storage/core/application.py +++ b/services/storage/src/simcore_service_storage/core/application.py @@ -38,7 +38,6 @@ from ..exceptions.handlers import set_exception_handlers from ..modules.celery import setup_task_manager from ..modules.db import setup_db -from ..modules.long_running_tasks import setup_rest_api_long_running_tasks_for_uploads from ..modules.rabbitmq import setup as setup_rabbitmq from ..modules.redis import setup as setup_redis from ..modules.s3 import setup_s3 @@ -77,7 +76,7 @@ def create_app(settings: ApplicationSettings) -> FastAPI: # noqa: C901 setup_task_manager(app, celery_settings=settings.STORAGE_CELERY) setup_rpc_routes(app) - setup_rest_api_long_running_tasks_for_uploads(app) + setup_rest_api_routes(app, API_VTAG) set_exception_handlers(app) diff --git a/services/storage/src/simcore_service_storage/core/settings.py b/services/storage/src/simcore_service_storage/core/settings.py index a3725ac48576..9d49f4660ba8 100644 --- a/services/storage/src/simcore_service_storage/core/settings.py +++ b/services/storage/src/simcore_service_storage/core/settings.py @@ -75,10 +75,7 @@ class ApplicationSettings(BaseApplicationSettings, MixinLoggingSettings): ] STORAGE_RABBITMQ: Annotated[ - RabbitSettings | None, - Field( - json_schema_extra={"auto_default_from_env": True}, - ), + RabbitSettings, Field(json_schema_extra={"auto_default_from_env": True}) ] STORAGE_S3_CLIENT_MAX_TRANSFER_CONCURRENCY: Annotated[ diff --git a/services/storage/src/simcore_service_storage/modules/long_running_tasks.py b/services/storage/src/simcore_service_storage/modules/long_running_tasks.py deleted file mode 100644 index 834ec2dcbb76..000000000000 --- a/services/storage/src/simcore_service_storage/modules/long_running_tasks.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Final - -from fastapi import FastAPI -from servicelib.fastapi.long_running_tasks._server import setup -from servicelib.long_running_tasks.task import RedisNamespace - -from .._meta import API_VTAG -from ..core.settings import get_application_settings - -_LONG_RUNNING_TASKS_NAMESPACE: Final[RedisNamespace] = "storage" - - -def setup_rest_api_long_running_tasks_for_uploads(app: FastAPI) -> None: - setup( - app, - router_prefix=f"/{API_VTAG}/futures", - redis_settings=get_application_settings(app).STORAGE_REDIS, - redis_namespace=_LONG_RUNNING_TASKS_NAMESPACE, - ) diff --git a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml index 6ba4afa38b03..2e40aa33250c 100644 --- a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml +++ b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml @@ -3851,7 +3851,7 @@ paths: get: tags: - long-running-tasks - summary: List Tasks + summary: Get Async Jobs description: Lists all long running tasks operationId: get_async_jobs responses: @@ -3889,7 +3889,7 @@ paths: get: tags: - long-running-tasks - summary: Get Task Status + summary: Get Async Job Status 
description: Retrieves the status of a task operationId: get_async_job_status parameters: @@ -3933,8 +3933,8 @@ paths: delete: tags: - long-running-tasks - summary: Cancel And Delete Task - description: Cancels and deletes a task + summary: Cancel Async Job + description: Cancels and removes a task operationId: cancel_async_job parameters: - name: task_id @@ -3974,7 +3974,7 @@ paths: get: tags: - long-running-tasks - summary: Get Task Result + summary: Get Async Job Result description: Retrieves the result of a task operationId: get_async_job_result parameters: @@ -4053,9 +4053,9 @@ paths: delete: tags: - long-running-tasks-legacy - summary: Cancel And Delete Task - description: Cancels and deletes a task - operationId: cancel_and_delete_task + summary: Remove Task + description: Cancels and removes a task + operationId: remove_task parameters: - name: task_id in: path diff --git a/services/web/server/src/simcore_service_webserver/application.py b/services/web/server/src/simcore_service_webserver/application.py index 32e96663433b..1d6721823fdd 100644 --- a/services/web/server/src/simcore_service_webserver/application.py +++ b/services/web/server/src/simcore_service_webserver/application.py @@ -40,7 +40,7 @@ from .licenses.plugin import setup_licenses from .login.plugin import setup_login from .login_auth.plugin import setup_login_auth -from .long_running_tasks import setup_long_running_tasks +from .long_running_tasks.plugin import setup_long_running_tasks from .notifications.plugin import setup_notifications from .payments.plugin import setup_payments from .products.plugin import setup_products diff --git a/services/web/server/src/simcore_service_webserver/application_settings.py b/services/web/server/src/simcore_service_webserver/application_settings.py index 919b3947f9b0..8b0174f867a9 100644 --- a/services/web/server/src/simcore_service_webserver/application_settings.py +++ b/services/web/server/src/simcore_service_webserver/application_settings.py @@ -40,6 +40,7 @@ from .invitations.settings import InvitationsSettings from .licenses.settings import LicensesSettings from .login.settings import LoginSettings +from .long_running_tasks.settings import LongRunningTasksSettings from .payments.settings import PaymentsSettings from .projects.settings import ProjectsSettings from .resource_manager.settings import ResourceManagerSettings @@ -266,6 +267,14 @@ class ApplicationSettings(BaseApplicationSettings, MixinLoggingSettings): ), ] + WEBSERVER_LONG_RUNNING_TASKS: Annotated[ + LongRunningTasksSettings | None, + Field( + json_schema_extra={"auto_default_from_env": True}, + description="long running tasks plugin", + ), + ] + WEBSERVER_PAYMENTS: Annotated[ PaymentsSettings | None, Field( @@ -579,6 +588,9 @@ def to_client_statics(self) -> dict[str, Any]: "WEBSERVER_TRASH": { "TRASH_RETENTION_DAYS", }, + "WEBSERVER_LONG_RUNNING_TASKS": { + "LONG_RUNNING_TASKS_NAMESPACE_SUFFIX", + }, }, exclude_none=True, ) diff --git a/packages/service-library/src/servicelib/long_running_tasks/_store/__init__.py b/services/web/server/src/simcore_service_webserver/long_running_tasks/__init__.py similarity index 100% rename from packages/service-library/src/servicelib/long_running_tasks/_store/__init__.py rename to services/web/server/src/simcore_service_webserver/long_running_tasks/__init__.py diff --git a/services/web/server/src/simcore_service_webserver/long_running_tasks.py b/services/web/server/src/simcore_service_webserver/long_running_tasks/plugin.py similarity index 57% rename from 
services/web/server/src/simcore_service_webserver/long_running_tasks.py rename to services/web/server/src/simcore_service_webserver/long_running_tasks/plugin.py index a97c82a5852d..dee65b67aa4c 100644 --- a/services/web/server/src/simcore_service_webserver/long_running_tasks.py +++ b/services/web/server/src/simcore_service_webserver/long_running_tasks/plugin.py @@ -1,25 +1,27 @@ import logging from functools import wraps -from typing import Final from aiohttp import web from models_library.utils.fastapi_encoders import jsonable_encoder -from servicelib.aiohttp.application_setup import ensure_single_setup +from servicelib.aiohttp.application_setup import ModuleCategory, app_module_setup from servicelib.aiohttp.long_running_tasks._constants import ( RQT_LONG_RUNNING_TASKS_CONTEXT_KEY, ) from servicelib.aiohttp.long_running_tasks.server import setup from servicelib.aiohttp.typing_extension import Handler -from servicelib.long_running_tasks.task import RedisNamespace -from . import redis -from ._meta import API_VTAG -from .login.decorators import login_required -from .models import AuthenticatedRequestContext +from .. import rabbitmq_settings, redis +from .._meta import API_VTAG, APP_NAME +from ..login.decorators import login_required +from ..models import AuthenticatedRequestContext +from ..projects.plugin import register_projects_long_running_tasks +from . import settings as long_running_tasks_settings _logger = logging.getLogger(__name__) -_LONG_RUNNING_TASKS_NAMESPACE: Final[RedisNamespace] = "webserver-legacy" + +def _get_lrt_namespace(suffix: str) -> str: + return f"{APP_NAME}-{suffix}" def webserver_request_context_decorator(handler: Handler): @@ -35,12 +37,23 @@ async def _test_task_context_decorator( return _test_task_context_decorator -@ensure_single_setup(__name__, logger=_logger) +@app_module_setup( + __name__, + ModuleCategory.ADDON, + settings_name="WEBSERVER_LONG_RUNNING_TASKS", + logger=_logger, +) def setup_long_running_tasks(app: web.Application) -> None: + # register all long-running tasks from different modules + register_projects_long_running_tasks(app) + + settings = long_running_tasks_settings.get_plugin_settings(app) + setup( app, redis_settings=redis.get_plugin_settings(app), - redis_namespace=_LONG_RUNNING_TASKS_NAMESPACE, + rabbit_settings=rabbitmq_settings.get_plugin_settings(app), + lrt_namespace=_get_lrt_namespace(settings.LONG_RUNNING_TASKS_NAMESPACE_SUFFIX), router_prefix=f"/{API_VTAG}/tasks-legacy", handler_check_decorator=login_required, task_request_context_decorator=webserver_request_context_decorator, diff --git a/services/web/server/src/simcore_service_webserver/long_running_tasks/settings.py b/services/web/server/src/simcore_service_webserver/long_running_tasks/settings.py new file mode 100644 index 000000000000..ac3feb588005 --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/long_running_tasks/settings.py @@ -0,0 +1,26 @@ +from typing import Annotated + +from aiohttp import web +from pydantic import Field +from settings_library.base import BaseCustomSettings + +from ..constants import APP_SETTINGS_KEY + + +class LongRunningTasksSettings(BaseCustomSettings): + LONG_RUNNING_TASKS_NAMESPACE_SUFFIX: Annotated[ + str, + Field( + description=( + "suffix to distinguish between the various services based on this image " + "inside the long_running_tasks framework" + ), + ), + ] + + +def get_plugin_settings(app: web.Application) -> LongRunningTasksSettings: + settings = app[APP_SETTINGS_KEY].WEBSERVER_LONG_RUNNING_TASKS + assert settings, 
"setup_settings not called?" # nosec + assert isinstance(settings, LongRunningTasksSettings) # nosec + return settings diff --git a/services/web/server/src/simcore_service_webserver/projects/_controller/_rest_utils.py b/services/web/server/src/simcore_service_webserver/projects/_controller/_rest_utils.py index beab5959668f..077761083a23 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_controller/_rest_utils.py +++ b/services/web/server/src/simcore_service_webserver/projects/_controller/_rest_utils.py @@ -4,6 +4,7 @@ from models_library.rest_pagination_utils import paginate_data from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON from servicelib.rest_constants import RESPONSE_MODEL_POLICY +from yarl import URL from .. import _permalink_service from .._crud_api_read import _paralell_update @@ -11,13 +12,17 @@ async def aggregate_data_to_projects_from_request( - request: web.Request, + app: web.Application, + url: URL, + headers: dict[str, str], projects: list[ProjectDict], ) -> list[ProjectDict]: update_permalink_per_project = [ # permalink - _permalink_service.aggregate_permalink_in_project(request, project=prj) + _permalink_service.aggregate_permalink_in_project( + app, url, headers, project=prj + ) for prj in projects ] diff --git a/services/web/server/src/simcore_service_webserver/projects/_controller/nodes_rest.py b/services/web/server/src/simcore_service_webserver/projects/_controller/nodes_rest.py index bc28c3a89575..8208df41a0ed 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_controller/nodes_rest.py +++ b/services/web/server/src/simcore_service_webserver/projects/_controller/nodes_rest.py @@ -331,7 +331,10 @@ async def _stop_dynamic_service_task( return web.json_response(status=status.HTTP_204_NO_CONTENT) -TaskRegistry.register(_stop_dynamic_service_task) +def register_stop_dynamic_service_task(app: web.Application) -> None: + TaskRegistry.register( + _stop_dynamic_service_task, allowed_errors=(web.HTTPNotFound,), app=app + ) @routes.post( @@ -361,7 +364,6 @@ async def stop_node(request: web.Request) -> web.Response: _stop_dynamic_service_task.__name__, task_context=jsonable_encoder(req_ctx), # task arguments from here on --- - app=request.app, dynamic_service_stop=DynamicServiceStop( user_id=req_ctx.user_id, project_id=path_params.project_id, diff --git a/services/web/server/src/simcore_service_webserver/projects/_controller/projects_rest.py b/services/web/server/src/simcore_service_webserver/projects/_controller/projects_rest.py index 538644fc2a20..4168ec608def 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_controller/projects_rest.py +++ b/services/web/server/src/simcore_service_webserver/projects/_controller/projects_rest.py @@ -101,7 +101,8 @@ async def create_project(request: web.Request): fire_and_forget=True, task_context=jsonable_encoder(req_ctx), # arguments - request=request, + request_url=request.url, + request_headers=dict(request.headers), new_project_was_hidden_before_data_was_copied=query_params.hidden, from_study=query_params.from_study, as_template=query_params.as_template, @@ -158,7 +159,7 @@ async def list_projects(request: web.Request): ) projects = await _rest_utils.aggregate_data_to_projects_from_request( - request, projects + request.app, request.url, dict(request.headers), projects ) return _rest_utils.create_page_response( @@ -197,7 +198,7 @@ async def list_projects_full_search(request: web.Request): ) projects = await 
_rest_utils.aggregate_data_to_projects_from_request( - request, projects + request.app, request.url, dict(request.headers), projects ) return _rest_utils.create_page_response( @@ -247,7 +248,9 @@ async def get_active_project(request: web.Request) -> web.Response: ) # updates project's permalink field - await update_or_pop_permalink_in_project(request, project) + await update_or_pop_permalink_in_project( + request.app, request.url, dict(request.headers), project + ) data = ProjectGet.from_domain_model(project).data(exclude_unset=True) @@ -280,7 +283,9 @@ async def get_project(request: web.Request): ) # Adds permalink - await update_or_pop_permalink_in_project(request, project) + await update_or_pop_permalink_in_project( + request.app, request.url, dict(request.headers), project + ) data = ProjectGet.from_domain_model(project).data(exclude_unset=True) return envelope_json_response(data) @@ -419,7 +424,8 @@ async def clone_project(request: web.Request): fire_and_forget=True, task_context=jsonable_encoder(req_ctx), # arguments - request=request, + request_url=request.url, + request_headers=dict(request.headers), new_project_was_hidden_before_data_was_copied=False, from_study=path_params.project_id, as_template=False, diff --git a/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py b/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py index dbe61e993159..310ecb91022e 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py +++ b/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py @@ -27,6 +27,7 @@ ProjectNodeCreate, ) from simcore_postgres_database.webserver_models import ProjectType as ProjectTypeDB +from yarl import URL from ..application_settings import get_application_settings from ..catalog import catalog_service @@ -249,7 +250,9 @@ async def _compose_project_data( async def create_project( # pylint: disable=too-many-arguments,too-many-branches,too-many-statements # noqa: C901, PLR0913 progress: TaskProgress, *, - request: web.Request, + app: web.Application, + request_url: URL, + request_headers: dict[str, str], new_project_was_hidden_before_data_was_copied: bool, from_study: ProjectID | None, as_template: bool, @@ -281,7 +284,6 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche web.HTTPUnauthorized: """ - assert request.app # nosec _logger.info( "create_project for '%s' with %s %s %s", f"{user_id=}", @@ -290,7 +292,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche f"{from_study=}", ) - _projects_repository_legacy = ProjectDBAPI.get_from_app_context(request.app) + _projects_repository_legacy = ProjectDBAPI.get_from_app_context(app) new_project: ProjectDict = {} copy_file_coro = None @@ -303,7 +305,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche if predefined_project: if workspace_id := predefined_project.get("workspaceId", None): await check_user_workspace_access( - request.app, + app, user_id=user_id, workspace_id=workspace_id, product_name=product_name, @@ -312,7 +314,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche if folder_id := predefined_project.get("folderId", None): # Check user has access to folder await folders_folders_repository.get_for_user_or_workspace( - request.app, + app, folder_id=folder_id, product_name=product_name, user_id=user_id if workspace_id is None else None, @@ -328,7 +330,7 @@ async def create_project( # 
pylint: disable=too-many-arguments,too-many-branche project_node_coro, copy_file_coro, ) = await _prepare_project_copy( - request.app, + app, user_id=user_id, product_name=product_name, src_project_uuid=from_study, @@ -342,7 +344,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche # 1.2 does project belong to some folder? workspace_id = new_project["workspaceId"] prj_to_folder_db = await _folders_repository.get_project_to_folder( - request.app, + app, project_id=from_study, private_workspace_user_id_or_none=( user_id if workspace_id is None else None @@ -361,7 +363,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche if predefined_project: # 2. overrides with optional body and re-validate new_project, project_nodes = await _compose_project_data( - request.app, + app, user_id=user_id, new_project=new_project, predefined_project=predefined_project, @@ -378,7 +380,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche ) # add parent linking if needed await set_project_ancestors( - request.app, + app, user_id=user_id, project_uuid=new_project["uuid"], parent_project_uuid=parent_project_uuid, @@ -389,7 +391,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche # 3.2 move project to proper folder if folder_id: await _folders_repository.insert_project_to_folder( - request.app, + app, project_id=new_project["uuid"], folder_id=folder_id, private_workspace_user_id_or_none=( @@ -405,20 +407,20 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche # 5. unhide the project if needed since it is now complete if not new_project_was_hidden_before_data_was_copied: await _projects_repository.patch_project( - request.app, + app, project_uuid=new_project["uuid"], new_partial_project_data={"hidden": False}, ) # update the network information in director-v2 await dynamic_scheduler_service.update_projects_networks( - request.app, project_id=ProjectID(new_project["uuid"]) + app, project_id=ProjectID(new_project["uuid"]) ) await progress.update() # This is a new project and every new graph needs to be reflected in the pipeline tables await director_v2_service.create_or_update_pipeline( - request.app, + app, user_id, new_project["uuid"], product_name, @@ -430,12 +432,14 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche ) # Appends state new_project = await _projects_service.add_project_states_for_user( - user_id=user_id, project=new_project, app=request.app + user_id=user_id, project=new_project, app=app ) await progress.update() # Adds permalink - await update_or_pop_permalink_in_project(request, new_project) + await update_or_pop_permalink_in_project( + app, request_url, request_headers, new_project + ) # Adds folderId user_specific_project_data_db = ( @@ -451,7 +455,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche # Overwrite project access rights if workspace_id: workspace: UserWorkspaceWithAccessRights = await get_user_workspace( - request.app, + app, user_id=user_id, workspace_id=workspace_id, product_name=product_name, @@ -494,7 +498,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche except (ParentProjectNotFoundError, ParentNodeNotFoundError) as exc: if project_uuid := new_project.get("uuid"): await _projects_service.submit_delete_project_task( - app=request.app, + app=app, project_uuid=project_uuid, user_id=user_id, simcore_user_agent=simcore_user_agent, 
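NOTE (illustrative sketch, not part of the patch): the hunks above change two things at
once. Handlers now hand long-running tasks plain data (`request_url=request.url`,
`request_headers=dict(request.headers)`) instead of the whole `web.Request`, and tasks
are registered per application at plugin setup rather than at import time. A minimal
version of that pattern, assuming the `TaskRegistry`/`TaskProgress` import paths below
(they are not shown in this diff) and using the hypothetical names
`my_task`/`register_my_task`:

    from aiohttp import web
    from servicelib.long_running_tasks.models import TaskProgress  # assumed path
    from servicelib.long_running_tasks.task import TaskRegistry  # assumed path
    from yarl import URL


    async def my_task(
        progress: TaskProgress,
        *,
        app: web.Application,  # explicit app instead of request.app
        request_url: URL,  # plain, serializable request data
        request_headers: dict[str, str],
    ) -> None:
        # report progress the same way create_project does in this diff
        await progress.update()


    def register_my_task(app: web.Application) -> None:
        # called once at plugin setup, cf. register_projects_long_running_tasks()
        TaskRegistry.register(my_task, allowed_errors=(web.HTTPNotFound,), app=app)

Declaring `allowed_errors` up front lets the framework treat those exceptions as
expected task outcomes rather than crashes, mirroring the `register_create_project_task`
hunk below.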
@@ -508,7 +512,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche
             )
         if project_uuid := new_project.get("uuid"):
             await _projects_service.submit_delete_project_task(
-                app=request.app,
+                app=app,
                 project_uuid=project_uuid,
                 user_id=user_id,
@@ -516,4 +520,14 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche
     raise


-TaskRegistry.register(create_project)
+def register_create_project_task(app: web.Application) -> None:
+    TaskRegistry.register(
+        create_project,
+        allowed_errors=(
+            web.HTTPUnprocessableEntity,
+            web.HTTPBadRequest,
+            web.HTTPNotFound,
+            web.HTTPForbidden,
+        ),
+        app=app,
+    )
diff --git a/services/web/server/src/simcore_service_webserver/projects/_permalink_service.py b/services/web/server/src/simcore_service_webserver/projects/_permalink_service.py
index e6fa6e61a8b2..d3c52a985012 100644
--- a/services/web/server/src/simcore_service_webserver/projects/_permalink_service.py
+++ b/services/web/server/src/simcore_service_webserver/projects/_permalink_service.py
@@ -5,6 +5,7 @@
 from aiohttp import web
 from models_library.api_schemas_webserver.permalinks import ProjectPermalink
 from models_library.projects import ProjectID
+from yarl import URL

 from .exceptions import PermalinkFactoryError, PermalinkNotAllowedError
 from .models import ProjectDict
@@ -15,9 +16,12 @@ class CreateLinkCoroutine(Protocol):
     async def __call__(
-        self, request: web.Request, project_uuid: ProjectID
-    ) -> ProjectPermalink:
-        ...
+        self,
+        app: web.Application,
+        request_url: URL,
+        request_headers: dict[str, str],
+        project_uuid: ProjectID,
+    ) -> ProjectPermalink: ...


 def register_factory(app: web.Application, factory_coro: CreateLinkCoroutine):
@@ -39,13 +43,16 @@ def _get_factory(app: web.Application) -> CreateLinkCoroutine:


 async def _create_permalink(
-    request: web.Request, project_uuid: ProjectID
+    app: web.Application,
+    request_url: URL,
+    request_headers: dict[str, str],
+    project_uuid: ProjectID,
 ) -> ProjectPermalink:
-    create_coro: CreateLinkCoroutine = _get_factory(request.app)
+    create_coro: CreateLinkCoroutine = _get_factory(app)

     try:
         permalink: ProjectPermalink = await asyncio.wait_for(
-            create_coro(request=request, project_uuid=project_uuid),
+            create_coro(app, request_url, request_headers, project_uuid),
             timeout=_PERMALINK_CREATE_TIMEOUT_S,
         )
         return permalink
@@ -55,7 +62,10 @@ async def _create_permalink(


 async def update_or_pop_permalink_in_project(
-    request: web.Request, project: ProjectDict
+    app: web.Application,
+    request_url: URL,
+    request_headers: dict[str, str],
+    project: ProjectDict,
 ) -> ProjectPermalink | None:
     """Updates permalink entry in project

     If it fails, it pops it from project (so it is not set in the pydantic model. SEE ProjectGet.permalink)
     """
     try:
-        permalink = await _create_permalink(request, project_uuid=project["uuid"])
+        permalink = await _create_permalink(
+            app, request_url, request_headers, project_uuid=project["uuid"]
+        )

         assert permalink  # nosec
         project["permalink"] = permalink
@@ -78,12 +90,12 @@ async def update_or_pop_permalink_in_project(


 async def aggregate_permalink_in_project(
-    request: web.Request, project: ProjectDict
+    app: web.Application, url: URL, headers: dict[str, str], project: ProjectDict
 ) -> ProjectDict:
     """
     Adapter to use in parallel aggregation of fields in a project dataset
     """
-    await update_or_pop_permalink_in_project(request, project)
+    await update_or_pop_permalink_in_project(app, url, headers, project)

     return project
diff --git a/services/web/server/src/simcore_service_webserver/projects/plugin.py b/services/web/server/src/simcore_service_webserver/projects/plugin.py
index 5028739d881b..e714ed350d73 100644
--- a/services/web/server/src/simcore_service_webserver/projects/plugin.py
+++ b/services/web/server/src/simcore_service_webserver/projects/plugin.py
@@ -29,12 +29,19 @@
     wallets_rest,
     workspaces_rest,
 )
+from ._controller.nodes_rest import register_stop_dynamic_service_task
+from ._crud_api_create import register_create_project_task
 from ._projects_repository_legacy import setup_projects_db
 from ._security_service import setup_projects_access

 logger = logging.getLogger(__name__)


+def register_projects_long_running_tasks(app: web.Application) -> None:
+    register_create_project_task(app)
+    register_stop_dynamic_service_task(app)
+
+
 @app_module_setup(
     "simcore_service_webserver.projects",
     ModuleCategory.ADDON,
diff --git a/services/web/server/src/simcore_service_webserver/studies_dispatcher/_projects_permalinks.py b/services/web/server/src/simcore_service_webserver/studies_dispatcher/_projects_permalinks.py
index 92206532c6be..f63fc3cf1c0f 100644
--- a/services/web/server/src/simcore_service_webserver/studies_dispatcher/_projects_permalinks.py
+++ b/services/web/server/src/simcore_service_webserver/studies_dispatcher/_projects_permalinks.py
@@ -9,6 +9,7 @@
 from typing_extensions import ( # https://docs.pydantic.dev/latest/api/standard_library_types/#typeddict
     TypedDict,
 )
+from yarl import URL

 from ..db.plugin import get_database_engine_legacy
 from ..projects.exceptions import PermalinkNotAllowedError, ProjectNotFoundError
@@ -33,8 +34,10 @@ class _GroupAccessRightsDict(TypedDict):


 def create_permalink_for_study(
-    request: web.Request,
+    app: web.Application,
     *,
+    request_url: URL,
+    request_headers: dict[str, str],
     project_uuid: ProjectID | ProjectIDStr,
     project_type: ProjectType,
     project_access_rights: dict[_GroupID, _GroupAccessRightsDict],
@@ -65,7 +68,7 @@ def create_permalink_for_study(
         raise PermalinkNotAllowedError(msg)

     # create
-    url_for = create_url_for_function(request)
+    url_for = create_url_for_function(app, request_url, request_headers)
     permalink = TypeAdapter(HttpUrl).validate_python(
         url_for(route_name="get_redirection_to_study_page", id=f"{project_uuid}"),
     )
@@ -77,14 +80,17 @@ async def permalink_factory(
-    request: web.Request, project_uuid: ProjectID
+    app: web.Application,
+    request_url: URL,
+    request_headers: dict[str, str],
+    project_uuid: ProjectID,
 ) -> ProjectPermalink:
     """
     - Assumes project_id is up-to-date in the database
     """
     # NOTE: next iterations will move this as part of the project repository pattern
-    engine = get_database_engine_legacy(request.app)
+    engine = get_database_engine_legacy(app)
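NOTE (illustrative sketch, not part of the patch): with the reworked
`CreateLinkCoroutine` protocol above, a permalink factory no longer receives the
`web.Request`; it gets the application plus the request's URL and headers as plain
values. A hedged example of a custom factory and its registration, where
`my_permalink_factory`/`setup_my_permalinks` are hypothetical names and the
`ProjectPermalink` field names are assumed from its schema:

    from aiohttp import web
    from models_library.api_schemas_webserver.permalinks import ProjectPermalink
    from models_library.projects import ProjectID
    from yarl import URL

    from simcore_service_webserver.projects._permalink_service import register_factory


    async def my_permalink_factory(
        app: web.Application,
        request_url: URL,
        request_headers: dict[str, str],
        project_uuid: ProjectID,
    ) -> ProjectPermalink:
        # build an absolute link from the request origin; a production factory
        # would honor X-Forwarded-Proto the way create_url_for_function() does
        return ProjectPermalink(
            url=f"{request_url.origin()}/study/{project_uuid}",  # assumed field
            is_public=True,  # assumed field
        )


    def setup_my_permalinks(app: web.Application) -> None:
        # analogous to how studies_dispatcher registers its permalink_factory
        register_factory(app, my_permalink_factory)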
async with engine.acquire() as conn: access_rights_subquery = ( sa.select( @@ -121,7 +127,9 @@ async def permalink_factory( raise ProjectNotFoundError(project_uuid=project_uuid) return create_permalink_for_study( - request, + app, + request_url=request_url, + request_headers=request_headers, project_uuid=row.uuid, project_type=row.type, project_access_rights=row.access_rights, diff --git a/services/web/server/src/simcore_service_webserver/tasks/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_rest.py index 375165e7c775..45c6457bc582 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_rest.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_rest.py @@ -34,7 +34,7 @@ from .._meta import API_VTAG from ..constants import ASYNC_JOB_CLIENT_NAME from ..login.decorators import login_required -from ..long_running_tasks import webserver_request_context_decorator +from ..long_running_tasks.plugin import webserver_request_context_decorator from ..models import AuthenticatedRequestContext from ..rabbitmq import get_rabbitmq_rpc_client from ..security.decorators import permission_required @@ -59,7 +59,8 @@ async def get_async_jobs(request: web.Request) -> web.Response: inprocess_long_running_manager = get_long_running_manager(request.app) inprocess_tracked_tasks = await lrt_api.list_tasks( - inprocess_long_running_manager.tasks_manager, + inprocess_long_running_manager.rpc_client, + inprocess_long_running_manager.lrt_namespace, inprocess_long_running_manager.get_task_context(request), ) @@ -92,7 +93,7 @@ async def get_async_jobs(request: web.Request) -> web.Response: TaskGet( task_id=f"{task.task_id}", status_href=f"{request.app.router['get_task_status'].url_for(task_id=task.task_id)}", - abort_href=f"{request.app.router['cancel_and_delete_task'].url_for(task_id=task.task_id)}", + abort_href=f"{request.app.router['remove_task'].url_for(task_id=task.task_id)}", result_href=f"{request.app.router['get_task_result'].url_for(task_id=task.task_id)}", ) for task in inprocess_tracked_tasks diff --git a/services/web/server/src/simcore_service_webserver/utils_aiohttp.py b/services/web/server/src/simcore_service_webserver/utils_aiohttp.py index b70a6c6897aa..e0b753e8aee9 100644 --- a/services/web/server/src/simcore_service_webserver/utils_aiohttp.py +++ b/services/web/server/src/simcore_service_webserver/utils_aiohttp.py @@ -35,8 +35,9 @@ def get_routes_view(routes: RouteTableDef) -> str: return fh.getvalue() -def create_url_for_function(request: web.Request) -> Callable: - app = request.app +def create_url_for_function( + app: web.Application, request_url: URL, request_headers: dict[str, str] +) -> Callable: def _url_for(route_name: str, **params: dict[str, Any]) -> str: """Reverse URL constructing using named resources""" @@ -44,16 +45,16 @@ def _url_for(route_name: str, **params: dict[str, Any]) -> str: rel_url: URL = app.router[route_name].url_for( **{k: f"{v}" for k, v in params.items()} ) - url: URL = ( - request.url.origin() + _url: URL = ( + request_url.origin() .with_scheme( # Custom header by traefik. 
See labels in docker-compose as: # - traefik.http.middlewares.${SWARM_STACK_NAME_NO_HYPHEN}_sslheader.headers.customrequestheaders.X-Forwarded-Proto=http - request.headers.get(X_FORWARDED_PROTO, request.url.scheme) + request_headers.get(X_FORWARDED_PROTO, request_url.scheme) ) .with_path(str(rel_url)) ) - return f"{url}" + return f"{_url}" except KeyError as err: msg = f"Cannot find URL because there is no resource registered as {route_name=}Check name spelling or whether the router was not registered" diff --git a/services/web/server/tests/unit/isolated/test_studies_dispatcher_projects_permalinks.py b/services/web/server/tests/unit/isolated/test_studies_dispatcher_projects_permalinks.py index def692faeccd..384041df0bd9 100644 --- a/services/web/server/tests/unit/isolated/test_studies_dispatcher_projects_permalinks.py +++ b/services/web/server/tests/unit/isolated/test_studies_dispatcher_projects_permalinks.py @@ -91,7 +91,9 @@ def test_create_permalink(fake_get_project_request: web.Request, is_public: bool project_uuid: str = fake_get_project_request.match_info["project_uuid"] permalink = create_permalink_for_study( - fake_get_project_request, + fake_get_project_request.app, + request_url=fake_get_project_request.url, + request_headers=dict(fake_get_project_request.headers), project_uuid=project_uuid, project_type=ProjectType.TEMPLATE, project_access_rights={"1": {"read": True, "write": False, "delete": False}}, @@ -119,7 +121,9 @@ def test_permalink_only_for_template_projects( ): with pytest.raises(PermalinkNotAllowedError): create_permalink_for_study( - fake_get_project_request, + fake_get_project_request.app, + request_url=fake_get_project_request.url, + request_headers=dict(fake_get_project_request.headers), **{**valid_project_kwargs, "project_type": ProjectType.STANDARD} ) @@ -129,7 +133,9 @@ def test_permalink_only_when_read_access_to_everyone( ): with pytest.raises(PermalinkNotAllowedError): create_permalink_for_study( - fake_get_project_request, + fake_get_project_request.app, + request_url=fake_get_project_request.url, + request_headers=dict(fake_get_project_request.headers), **{ **valid_project_kwargs, "project_access_rights": { @@ -140,7 +146,9 @@ def test_permalink_only_when_read_access_to_everyone( with pytest.raises(PermalinkNotAllowedError): create_permalink_for_study( - fake_get_project_request, + fake_get_project_request.app, + request_url=fake_get_project_request.url, + request_headers=dict(fake_get_project_request.headers), **{ **valid_project_kwargs, "project_access_rights": { diff --git a/services/web/server/tests/unit/with_dbs/01/test_api_keys.py b/services/web/server/tests/unit/with_dbs/01/test_api_keys.py index 38966e1fa827..902942840f54 100644 --- a/services/web/server/tests/unit/with_dbs/01/test_api_keys.py +++ b/services/web/server/tests/unit/with_dbs/01/test_api_keys.py @@ -248,7 +248,7 @@ async def test_create_api_key_with_expiration( "/v0/auth/api-keys", json={ "displayName": expected_api_key, - "expiration": expiration_interval.seconds, + "expiration": expiration_interval.total_seconds(), }, ) @@ -264,7 +264,9 @@ async def test_create_api_key_with_expiration( assert [d["displayName"] for d in data] == [expected_api_key] # wait for api-key for it to expire and force-run scheduled task - await asyncio.sleep(EXPIRATION_WAIT_FACTOR * expiration_interval.seconds) + await asyncio.sleep( + EXPIRATION_WAIT_FACTOR * expiration_interval.total_seconds() + ) deleted = await api_keys_service.prune_expired_api_keys(client.app) assert deleted == [expected_api_key] diff 
--git a/services/web/server/tests/unit/with_dbs/01/test_long_running_tasks.py b/services/web/server/tests/unit/with_dbs/01/test_long_running_tasks.py index 4a1f654e3787..7a06060369e2 100644 --- a/services/web/server/tests/unit/with_dbs/01/test_long_running_tasks.py +++ b/services/web/server/tests/unit/with_dbs/01/test_long_running_tasks.py @@ -27,7 +27,7 @@ ("GET", "list_tasks", {}), ("GET", "get_task_status", {"task_id": "some_fake_task_id"}), ("GET", "get_task_result", {"task_id": "some_fake_task_id"}), - ("DELETE", "cancel_and_delete_task", {"task_id": "some_fake_task_id"}), + ("DELETE", "remove_task", {"task_id": "some_fake_task_id"}), ], ) async def test_long_running_tasks_access_restricted_to_logged_users( diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_cancellations.py b/services/web/server/tests/unit/with_dbs/02/test_projects_cancellations.py index e9c28fa7825d..3e535cab5b5d 100644 --- a/services/web/server/tests/unit/with_dbs/02/test_projects_cancellations.py +++ b/services/web/server/tests/unit/with_dbs/02/test_projects_cancellations.py @@ -27,6 +27,7 @@ from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import ( AsyncJobComposedResult, ) +from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from simcore_postgres_database.models.users import UserRole from simcore_service_webserver._meta import api_version_prefix @@ -36,19 +37,21 @@ from tenacity.stop import stop_after_delay from tenacity.wait import wait_fixed +pytest_simcore_core_services_selection = [ + "rabbit", +] + API_PREFIX = "/" + api_version_prefix @pytest.fixture def app_environment( use_in_memory_redis: RedisSettings, + rabbit_settings: RabbitSettings, app_environment: EnvVarsDict, monkeypatch: pytest.MonkeyPatch, ) -> EnvVarsDict: - envs_plugins = setenvs_from_dict( - monkeypatch, - {}, - ) + envs_plugins = setenvs_from_dict(monkeypatch, {}) return app_environment | envs_plugins diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_crud_handlers.py b/services/web/server/tests/unit/with_dbs/02/test_projects_crud_handlers.py index 7c7c4a6ed377..2b457d3579a4 100644 --- a/services/web/server/tests/unit/with_dbs/02/test_projects_crud_handlers.py +++ b/services/web/server/tests/unit/with_dbs/02/test_projects_crud_handlers.py @@ -33,6 +33,7 @@ from pytest_simcore.helpers.webserver_users import UserInfoDict from servicelib.aiohttp import status from servicelib.rest_constants import X_PRODUCT_NAME_HEADER +from settings_library.rabbit import RabbitSettings from simcore_postgres_database.models.products import products from simcore_postgres_database.models.projects_to_products import projects_to_products from simcore_service_webserver._meta import api_version_prefix @@ -48,6 +49,10 @@ from simcore_service_webserver.utils import to_datetime from yarl import URL +pytest_simcore_core_services_selection = [ + "rabbit", +] + API_PREFIX = "/" + api_version_prefix @@ -190,6 +195,7 @@ async def _assert_get_same_project( ], ) async def test_list_projects( + rabbit_settings: RabbitSettings, client: TestClient, mocked_dynamic_services_interface: dict[str, mock.MagicMock], logged_user: dict[str, Any], diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_metadata_handlers.py b/services/web/server/tests/unit/with_dbs/02/test_projects_metadata_handlers.py index 9aa4f2b8161d..ee94a5640aab 100644 --- a/services/web/server/tests/unit/with_dbs/02/test_projects_metadata_handlers.py +++ 
b/services/web/server/tests/unit/with_dbs/02/test_projects_metadata_handlers.py @@ -28,12 +28,17 @@ ) from pytest_simcore.helpers.webserver_users import UserInfoDict from servicelib.aiohttp import status +from settings_library.rabbit import RabbitSettings from simcore_postgres_database.utils_projects_metadata import ( get as get_db_project_metadata, ) from simcore_service_webserver.projects import _crud_api_delete from simcore_service_webserver.projects.models import ProjectDict +pytest_simcore_core_services_selection = [ + "rabbit", +] + @pytest.mark.acceptance_test( "For https://github.com/ITISFoundation/osparc-simcore/issues/4313" @@ -113,6 +118,7 @@ async def _wait_until_deleted(): @pytest.mark.parametrize(*standard_user_role_response()) async def test_new_project_with_parent_project_node( + rabbit_settings: RabbitSettings, mock_dynamic_scheduler: None, # for deletion mocked_dynamic_services_interface: dict[str, MagicMock], @@ -269,6 +275,7 @@ async def test_new_project_with_invalid_parent_project_node( @pytest.mark.parametrize(*standard_user_role_response()) async def test_set_project_parent_backward_compatibility( + rabbit_settings: RabbitSettings, mock_dynamic_scheduler: None, # for deletion mocked_dynamic_services_interface: dict[str, MagicMock], diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_nodes_handler.py b/services/web/server/tests/unit/with_dbs/02/test_projects_nodes_handler.py index 0db1b82b3781..ce031477e935 100644 --- a/services/web/server/tests/unit/with_dbs/02/test_projects_nodes_handler.py +++ b/services/web/server/tests/unit/with_dbs/02/test_projects_nodes_handler.py @@ -51,6 +51,7 @@ ) from servicelib.aiohttp import status from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE +from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from simcore_postgres_database.models.projects import projects as projects_db_model from simcore_service_webserver.db.models import UserRole @@ -68,6 +69,10 @@ wait_fixed, ) +pytest_simcore_core_services_selection = [ + "rabbit", +] + @pytest.mark.parametrize( "user_role,expected", @@ -1060,6 +1065,7 @@ async def test_start_node_raises_if_called_with_wrong_data( @pytest.mark.parametrize(*standard_role_response(), ids=str) async def test_stop_node( + rabbit_settings: RabbitSettings, use_in_memory_redis: RedisSettings, client: TestClient, user_project_with_num_dynamic_services: Callable[[int], Awaitable[ProjectDict]], diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_states_handlers.py b/services/web/server/tests/unit/with_dbs/02/test_projects_states_handlers.py index 5e79fe538500..6be61d40b09f 100644 --- a/services/web/server/tests/unit/with_dbs/02/test_projects_states_handlers.py +++ b/services/web/server/tests/unit/with_dbs/02/test_projects_states_handlers.py @@ -63,6 +63,7 @@ from pytest_simcore.helpers.webserver_users import UserInfoDict from servicelib.aiohttp import status from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE +from settings_library.rabbit import RabbitSettings from simcore_postgres_database.models.products import products from simcore_postgres_database.models.wallets import wallets from simcore_service_webserver._meta import API_VTAG @@ -81,6 +82,11 @@ wait_fixed, ) +pytest_simcore_core_services_selection = [ + "rabbit", +] + + RESOURCE_NAME = "projects" API_PREFIX = f"/{API_VTAG}" @@ -282,6 +288,7 @@ async def _delete_project(client: TestClient, project: dict) -> ClientResponse: 
@pytest.mark.parametrize(*standard_role_response()) async def test_share_project_user_roles( + rabbit_service: RabbitSettings, mock_dynamic_scheduler: None, client: TestClient, logged_user: dict, diff --git a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/conftest.py b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/conftest.py index 1b4b8e20ff28..93ec781cab84 100644 --- a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/conftest.py +++ b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/conftest.py @@ -13,7 +13,9 @@ @pytest.fixture -def app_environment(app_environment: EnvVarsDict, monkeypatch: pytest.MonkeyPatch): +def app_environment( + app_environment: EnvVarsDict, monkeypatch: pytest.MonkeyPatch +) -> EnvVarsDict: envs_plugins = setenvs_from_dict( monkeypatch, { diff --git a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_handlers.py b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_handlers.py index 2acc3b965b1e..e48f716bd369 100644 --- a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_handlers.py +++ b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_handlers.py @@ -16,17 +16,22 @@ from aiohttp import ClientResponse, ClientSession from aiohttp.test_utils import TestClient, TestServer from aioresponses import aioresponses +from common_library.json_serialization import json_dumps +from common_library.serialization import model_dump_with_secrets from common_library.users_enums import UserRole from models_library.projects_state import ProjectShareState, ProjectStatus from pydantic import BaseModel, ByteSize, TypeAdapter from pytest_mock import MockerFixture from pytest_simcore.helpers.assert_checks import assert_status +from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict +from pytest_simcore.helpers.typing_env import EnvVarsDict from pytest_simcore.helpers.webserver_users import UserInfoDict from pytest_simcore.pydantic_models import ( assert_validation_model, walk_model_examples_in_package, ) from servicelib.aiohttp import status +from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from settings_library.utils_session import DEFAULT_SESSION_COOKIE_NAME from simcore_service_webserver.studies_dispatcher._core import ViewerInfo @@ -34,6 +39,10 @@ from sqlalchemy.sql import text from yarl import URL +pytest_simcore_core_services_selection = [ + "rabbit", +] + # # FIXTURES OVERRIDES # @@ -77,7 +86,25 @@ def postgres_db(postgres_db: sa.engine.Engine) -> sa.engine.Engine: @pytest.fixture -def web_server(redis_service: RedisSettings, web_server: TestServer) -> TestServer: +def app_environment( + app_environment: EnvVarsDict, + monkeypatch: pytest.MonkeyPatch, + rabbit_service: RabbitSettings, +) -> EnvVarsDict: + return setenvs_from_dict( + monkeypatch, + { + "WEBSERVER_RABBITMQ": json_dumps( + model_dump_with_secrets(rabbit_service, show_secrets=True) + ) + }, + ) + + +@pytest.fixture +def web_server( + redis_service: RedisSettings, rabbit_service: RabbitSettings, web_server: TestServer +) -> TestServer: # # Extends web_server to start redis_service # diff --git a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_projects.py b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_projects.py index 75e09bf72fc8..a268b839f478 100644 --- 
a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_projects.py +++ b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_projects.py @@ -10,13 +10,18 @@ import pytest from aiohttp.test_utils import TestClient +from common_library.json_serialization import json_dumps +from common_library.serialization import model_dump_with_secrets from faker import Faker from models_library.projects import Project, ProjectID from models_library.projects_nodes_io import NodeID from pytest_mock import MockerFixture +from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict +from pytest_simcore.helpers.typing_env import EnvVarsDict from pytest_simcore.helpers.webserver_fake_services_data import list_fake_file_consumers from pytest_simcore.helpers.webserver_login import NewUser from pytest_simcore.helpers.webserver_projects import delete_all_projects +from settings_library.rabbit import RabbitSettings from simcore_service_webserver.groups.api import auto_add_user_to_groups from simcore_service_webserver.projects._projects_service import get_project_for_user from simcore_service_webserver.studies_dispatcher._models import ServiceInfo @@ -29,9 +34,30 @@ ) from simcore_service_webserver.users.users_service import get_user +pytest_simcore_core_services_selection = [ + "rabbit", +] + + FAKE_FILE_VIEWS = list_fake_file_consumers() +@pytest.fixture +def app_environment( + app_environment: EnvVarsDict, + monkeypatch: pytest.MonkeyPatch, + rabbit_service: RabbitSettings, +) -> EnvVarsDict: + return setenvs_from_dict( + monkeypatch, + { + "WEBSERVER_RABBITMQ": json_dumps( + model_dump_with_secrets(rabbit_service, show_secrets=True) + ) + }, + ) + + @pytest.fixture async def user(client: TestClient) -> AsyncIterator[UserInfo]: async with NewUser(app=client.app) as user_db: diff --git a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_studies_access.py b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_studies_access.py index d722c6c9dc94..5321874a1686 100644 --- a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_studies_access.py +++ b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_studies_access.py @@ -18,6 +18,8 @@ import redis.asyncio as aioredis from aiohttp import ClientResponse, ClientSession, web from aiohttp.test_utils import TestClient, TestServer +from common_library.json_serialization import json_dumps +from common_library.serialization import model_dump_with_secrets from common_library.users_enums import UserRole from faker import Faker from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobStatus @@ -30,6 +32,8 @@ from pytest_mock import MockerFixture from pytest_simcore.aioresponses_mocker import AioResponsesMock from pytest_simcore.helpers.assert_checks import assert_status +from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict +from pytest_simcore.helpers.typing_env import EnvVarsDict from pytest_simcore.helpers.webserver_parametrizations import MockedStorageSubsystem from pytest_simcore.helpers.webserver_projects import NewProject, delete_all_projects from pytest_simcore.helpers.webserver_users import UserInfoDict @@ -39,6 +43,7 @@ AsyncJobComposedResult, ) from servicelib.rest_responses import unwrap_envelope +from settings_library.rabbit import RabbitSettings from settings_library.utils_session import DEFAULT_SESSION_COOKIE_NAME from 
simcore_service_webserver.projects._projects_service import ( submit_delete_project_task, @@ -51,6 +56,10 @@ ) from tenacity import retry, stop_after_attempt, wait_fixed +pytest_simcore_core_services_selection = [ + "rabbit", +] + async def _get_user_projects(client) -> list[ProjectDict]: url = client.app.router["list_projects"].url_for() @@ -89,6 +98,22 @@ def _is_user_authenticated(session: ClientSession) -> bool: return DEFAULT_SESSION_COOKIE_NAME in [c.key for c in session.cookie_jar] +@pytest.fixture +def app_environment( + app_environment: EnvVarsDict, + monkeypatch: pytest.MonkeyPatch, + rabbit_service: RabbitSettings, +) -> EnvVarsDict: + return setenvs_from_dict( + monkeypatch, + { + "WEBSERVER_RABBITMQ": json_dumps( + model_dump_with_secrets(rabbit_service, show_secrets=True) + ) + }, + ) + + @pytest.fixture async def published_project( client: TestClient,