diff --git a/packages/celery-library/src/celery_library/rpc/__init__.py b/packages/celery-library/src/celery_library/rpc/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py b/packages/celery-library/src/celery_library/rpc/_async_jobs.py similarity index 80% rename from services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py rename to packages/celery-library/src/celery_library/rpc/_async_jobs.py index 8628413e83c..4972142a457 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py +++ b/packages/celery-library/src/celery_library/rpc/_async_jobs.py @@ -3,11 +3,6 @@ import logging from celery.exceptions import CeleryError # type: ignore[import-untyped] -from celery_library.errors import ( - TransferrableCeleryError, - decode_celery_transferrable_error, -) -from fastapi import FastAPI from models_library.api_schemas_rpc_async_jobs.async_jobs import ( AsyncJobGet, AsyncJobId, @@ -22,21 +17,27 @@ JobSchedulerError, ) from servicelib.celery.models import TaskState +from servicelib.celery.task_manager import TaskManager from servicelib.logging_utils import log_catch from servicelib.rabbitmq import RPCRouter -from ...modules.celery import get_task_manager_from_app +from ..errors import ( + TransferrableCeleryError, + decode_celery_transferrable_error, +) _logger = logging.getLogger(__name__) router = RPCRouter() @router.expose(reraise_if_error_type=(JobSchedulerError,)) -async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData): - assert app # nosec +async def cancel( + task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData +): + assert task_manager # nosec assert job_id_data # nosec try: - await get_task_manager_from_app(app).cancel_task( + await task_manager.cancel_task( task_context=job_id_data.model_dump(), task_uuid=job_id, ) @@ -46,13 +47,13 @@ async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData @router.expose(reraise_if_error_type=(JobSchedulerError,)) async def status( - app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData + task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData ) -> AsyncJobStatus: - assert app # nosec + assert task_manager # nosec assert job_id_data # nosec try: - task_status = await get_task_manager_from_app(app).get_task_status( + task_status = await task_manager.get_task_status( task_context=job_id_data.model_dump(), task_uuid=job_id, ) @@ -75,20 +76,20 @@ async def status( ) ) async def result( - app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData + task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData ) -> AsyncJobResult: - assert app # nosec + assert task_manager # nosec assert job_id # nosec assert job_id_data # nosec try: - _status = await get_task_manager_from_app(app).get_task_status( + _status = await task_manager.get_task_status( task_context=job_id_data.model_dump(), task_uuid=job_id, ) if not _status.is_done: raise JobNotDoneError(job_id=job_id) - _result = await get_task_manager_from_app(app).get_task_result( + _result = await task_manager.get_task_result( task_context=job_id_data.model_dump(), task_uuid=job_id, ) @@ -122,12 +123,12 @@ async def result( @router.expose(reraise_if_error_type=(JobSchedulerError,)) async def list_jobs( - app: FastAPI, filter_: str, job_id_data: AsyncJobNameData + task_manager: TaskManager, filter_: str, job_id_data: AsyncJobNameData ) -> list[AsyncJobGet]: _ = filter_ - assert app # nosec + assert task_manager # nosec try: - tasks = await get_task_manager_from_app(app).list_tasks( + tasks = await task_manager.list_tasks( task_context=job_id_data.model_dump(), ) except CeleryError as exc: diff --git a/packages/celery-library/tests/conftest.py b/packages/celery-library/tests/conftest.py index 4d37e3a4e6d..2553d9df6d7 100644 --- a/packages/celery-library/tests/conftest.py +++ b/packages/celery-library/tests/conftest.py @@ -15,6 +15,7 @@ from celery_library.common import create_task_manager from celery_library.signals import on_worker_init, on_worker_shutdown from celery_library.task_manager import CeleryTaskManager +from celery_library.types import register_celery_types from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.celery.app_server import BaseAppServer @@ -25,6 +26,7 @@ "pytest_simcore.docker_compose", "pytest_simcore.docker_swarm", "pytest_simcore.environment_configs", + "pytest_simcore.rabbit_service", "pytest_simcore.redis_service", "pytest_simcore.repository_paths", ] @@ -123,6 +125,8 @@ async def celery_task_manager( celery_settings: CelerySettings, with_celery_worker: TestWorkController, ) -> CeleryTaskManager: + register_celery_types() + return await create_task_manager( celery_app, celery_settings, diff --git a/services/storage/tests/unit/test_async_jobs.py b/packages/celery-library/tests/unit/test_async_jobs.py similarity index 71% rename from services/storage/tests/unit/test_async_jobs.py rename to packages/celery-library/tests/unit/test_async_jobs.py index 26140eb037c..02c8362c1aa 100644 --- a/services/storage/tests/unit/test_async_jobs.py +++ b/packages/celery-library/tests/unit/test_async_jobs.py @@ -3,16 +3,18 @@ import asyncio import pickle -from collections.abc import Callable +from collections.abc import Awaitable, Callable from datetime import timedelta from enum import Enum -from typing import Any +from typing import Any, Final import pytest from celery import Celery, Task from celery.contrib.testing.worker import TestWorkController +from celery_library.rpc import _async_jobs from celery_library.task import register_task -from fastapi import FastAPI +from common_library.errors_classes import OsparcErrorMixin +from faker import Faker from models_library.api_schemas_rpc_async_jobs.async_jobs import ( AsyncJobGet, AsyncJobNameData, @@ -21,15 +23,14 @@ JobAbortedError, JobError, ) -from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE -from models_library.api_schemas_storage.export_data_async_jobs import AccessRightError from models_library.products import ProductName +from models_library.rabbitmq_basic_types import RPCNamespace from models_library.users import UserID +from pydantic import TypeAdapter from servicelib.celery.models import TaskID, TaskMetadata +from servicelib.celery.task_manager import TaskManager from servicelib.rabbitmq import RabbitMQRPCClient, RPCRouter from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs -from simcore_service_storage.api.rpc.routes import get_rabbitmq_rpc_server -from simcore_service_storage.modules.celery import get_task_manager_from_app from tenacity import ( AsyncRetrying, retry_if_exception_type, @@ -39,20 +40,49 @@ pytest_simcore_core_services_selection = [ "rabbit", - "postgres", + "redis", ] +class AccessRightError(OsparcErrorMixin, RuntimeError): + msg_template: str = ( + "User {user_id} does not have access to file {file_id} with location {location_id}" + ) + + +@pytest.fixture +async def async_jobs_rabbitmq_rpc_client( + rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], +) -> RabbitMQRPCClient: + rpc_client = await rabbitmq_rpc_client("pytest_async_jobs_rpc_client") + assert rpc_client + return rpc_client + + +@pytest.fixture +def user_id(faker: Faker) -> UserID: + return faker.pyint(min_value=1) + + +@pytest.fixture +def product_name(faker: Faker) -> ProductName: + return faker.word() + + ###### RPC Interface ###### router = RPCRouter() +ASYNC_JOBS_RPC_NAMESPACE: Final[RPCNamespace] = TypeAdapter( + RPCNamespace +).validate_python("async_jobs") + @router.expose() async def rpc_sync_job( - app: FastAPI, *, job_id_data: AsyncJobNameData, **kwargs: Any + task_manager: TaskManager, *, job_id_data: AsyncJobNameData, **kwargs: Any ) -> AsyncJobGet: task_name = sync_job.__name__ - task_uuid = await get_task_manager_from_app(app).submit_task( + task_uuid = await task_manager.submit_task( TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs ) @@ -61,10 +91,10 @@ async def rpc_sync_job( @router.expose() async def rpc_async_job( - app: FastAPI, *, job_id_data: AsyncJobNameData, **kwargs: Any + task_manager: TaskManager, *, job_id_data: AsyncJobNameData, **kwargs: Any ) -> AsyncJobGet: task_name = async_job.__name__ - task_uuid = await get_task_manager_from_app(app).submit_task( + task_uuid = await task_manager.submit_task( TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs ) @@ -108,9 +138,15 @@ async def async_job(task: Task, task_id: TaskID, action: Action, payload: Any) - @pytest.fixture -async def register_rpc_routes(initialized_app: FastAPI) -> None: - rpc_server = get_rabbitmq_rpc_server(initialized_app) - await rpc_server.register_router(router, STORAGE_RPC_NAMESPACE, initialized_app) +async def register_rpc_routes( + async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, celery_task_manager: TaskManager +) -> None: + await async_jobs_rabbitmq_rpc_client.register_router( + _async_jobs.router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=celery_task_manager + ) + await async_jobs_rabbitmq_rpc_client.register_router( + router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=celery_task_manager + ) async def _start_task_via_rpc( @@ -124,7 +160,7 @@ async def _start_task_via_rpc( job_id_data = AsyncJobNameData(user_id=user_id, product_name=product_name) async_job_get = await async_jobs.submit( rabbitmq_rpc_client=client, - rpc_namespace=STORAGE_RPC_NAMESPACE, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, method_name=rpc_task_name, job_id_data=job_id_data, **kwargs, @@ -170,7 +206,7 @@ async def _wait_for_job( with attempt: result = await async_jobs.status( rpc_client, - rpc_namespace=STORAGE_RPC_NAMESPACE, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, job_id_data=job_id_data, ) @@ -198,17 +234,16 @@ async def _wait_for_job( ], ) async def test_async_jobs_workflow( - initialized_app: FastAPI, register_rpc_routes: None, - storage_rabbitmq_rpc_client: RabbitMQRPCClient, - with_storage_celery_worker: TestWorkController, + async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, + with_celery_worker: TestWorkController, user_id: UserID, product_name: ProductName, exposed_rpc_start: str, payload: Any, ): async_job_get, job_id_data = await _start_task_via_rpc( - storage_rabbitmq_rpc_client, + async_jobs_rabbitmq_rpc_client, rpc_task_name=exposed_rpc_start, user_id=user_id, product_name=product_name, @@ -217,22 +252,22 @@ async def test_async_jobs_workflow( ) jobs = await async_jobs.list_jobs( - storage_rabbitmq_rpc_client, - rpc_namespace=STORAGE_RPC_NAMESPACE, + async_jobs_rabbitmq_rpc_client, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, filter_="", # currently not used job_id_data=job_id_data, ) assert len(jobs) > 0 await _wait_for_job( - storage_rabbitmq_rpc_client, + async_jobs_rabbitmq_rpc_client, async_job_get=async_job_get, job_id_data=job_id_data, ) async_job_result = await async_jobs.result( - storage_rabbitmq_rpc_client, - rpc_namespace=STORAGE_RPC_NAMESPACE, + async_jobs_rabbitmq_rpc_client, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, job_id_data=job_id_data, ) @@ -246,16 +281,15 @@ async def test_async_jobs_workflow( ], ) async def test_async_jobs_cancel( - initialized_app: FastAPI, register_rpc_routes: None, - storage_rabbitmq_rpc_client: RabbitMQRPCClient, - with_storage_celery_worker: TestWorkController, + async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, + with_celery_worker: TestWorkController, user_id: UserID, product_name: ProductName, exposed_rpc_start: str, ): async_job_get, job_id_data = await _start_task_via_rpc( - storage_rabbitmq_rpc_client, + async_jobs_rabbitmq_rpc_client, rpc_task_name=exposed_rpc_start, user_id=user_id, product_name=product_name, @@ -264,21 +298,21 @@ async def test_async_jobs_cancel( ) await async_jobs.cancel( - storage_rabbitmq_rpc_client, - rpc_namespace=STORAGE_RPC_NAMESPACE, + async_jobs_rabbitmq_rpc_client, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, job_id_data=job_id_data, ) await _wait_for_job( - storage_rabbitmq_rpc_client, + async_jobs_rabbitmq_rpc_client, async_job_get=async_job_get, job_id_data=job_id_data, ) jobs = await async_jobs.list_jobs( - storage_rabbitmq_rpc_client, - rpc_namespace=STORAGE_RPC_NAMESPACE, + async_jobs_rabbitmq_rpc_client, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, filter_="", # currently not used job_id_data=job_id_data, ) @@ -286,8 +320,8 @@ async def test_async_jobs_cancel( with pytest.raises(JobAbortedError): await async_jobs.result( - storage_rabbitmq_rpc_client, - rpc_namespace=STORAGE_RPC_NAMESPACE, + async_jobs_rabbitmq_rpc_client, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, job_id_data=job_id_data, ) @@ -311,17 +345,16 @@ async def test_async_jobs_cancel( ], ) async def test_async_jobs_raises( - initialized_app: FastAPI, register_rpc_routes: None, - storage_rabbitmq_rpc_client: RabbitMQRPCClient, - with_storage_celery_worker: TestWorkController, + async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, + with_celery_worker: TestWorkController, user_id: UserID, product_name: ProductName, exposed_rpc_start: str, error: Exception, ): async_job_get, job_id_data = await _start_task_via_rpc( - storage_rabbitmq_rpc_client, + async_jobs_rabbitmq_rpc_client, rpc_task_name=exposed_rpc_start, user_id=user_id, product_name=product_name, @@ -330,7 +363,7 @@ async def test_async_jobs_raises( ) await _wait_for_job( - storage_rabbitmq_rpc_client, + async_jobs_rabbitmq_rpc_client, async_job_get=async_job_get, job_id_data=job_id_data, stop_after=timedelta(minutes=1), @@ -338,8 +371,8 @@ async def test_async_jobs_raises( with pytest.raises(JobError) as exc: await async_jobs.result( - storage_rabbitmq_rpc_client, - rpc_namespace=STORAGE_RPC_NAMESPACE, + async_jobs_rabbitmq_rpc_client, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, job_id_data=job_id_data, ) diff --git a/services/storage/src/simcore_service_storage/api/rpc/_paths.py b/services/storage/src/simcore_service_storage/api/rpc/_paths.py index 9c704d46820..f4b0eae297d 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_paths.py +++ b/services/storage/src/simcore_service_storage/api/rpc/_paths.py @@ -1,16 +1,15 @@ import logging from pathlib import Path -from fastapi import FastAPI from models_library.api_schemas_rpc_async_jobs.async_jobs import ( AsyncJobGet, AsyncJobNameData, ) from models_library.projects_nodes_io import LocationID from servicelib.celery.models import TaskMetadata +from servicelib.celery.task_manager import TaskManager from servicelib.rabbitmq import RPCRouter -from ...modules.celery import get_task_manager_from_app from .._worker_tasks._paths import compute_path_size as remote_compute_path_size from .._worker_tasks._paths import delete_paths as remote_delete_paths @@ -20,13 +19,13 @@ @router.expose(reraise_if_error_type=None) async def compute_path_size( - app: FastAPI, + task_manager: TaskManager, job_id_data: AsyncJobNameData, location_id: LocationID, path: Path, ) -> AsyncJobGet: task_name = remote_compute_path_size.__name__ - task_uuid = await get_task_manager_from_app(app).submit_task( + task_uuid = await task_manager.submit_task( task_metadata=TaskMetadata( name=task_name, ), @@ -41,13 +40,13 @@ async def compute_path_size( @router.expose(reraise_if_error_type=None) async def delete_paths( - app: FastAPI, + task_manager: TaskManager, job_id_data: AsyncJobNameData, location_id: LocationID, paths: set[Path], ) -> AsyncJobGet: task_name = remote_delete_paths.__name__ - task_uuid = await get_task_manager_from_app(app).submit_task( + task_uuid = await task_manager.submit_task( task_metadata=TaskMetadata( name=task_name, ), diff --git a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py index 6b0b27f87a5..bd144179cd2 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py +++ b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py @@ -1,4 +1,3 @@ -from fastapi import FastAPI from models_library.api_schemas_rpc_async_jobs.async_jobs import ( AsyncJobGet, AsyncJobNameData, @@ -6,9 +5,9 @@ from models_library.api_schemas_storage.storage_schemas import FoldersBody from models_library.api_schemas_webserver.storage import PathToExport from servicelib.celery.models import TaskMetadata, TasksQueue +from servicelib.celery.task_manager import TaskManager from servicelib.rabbitmq import RPCRouter -from ...modules.celery import get_task_manager_from_app from .._worker_tasks._simcore_s3 import deep_copy_files_from_project, export_data router = RPCRouter() @@ -16,12 +15,12 @@ @router.expose(reraise_if_error_type=None) async def copy_folders_from_project( - app: FastAPI, + task_manager: TaskManager, job_id_data: AsyncJobNameData, body: FoldersBody, ) -> AsyncJobGet: task_name = deep_copy_files_from_project.__name__ - task_uuid = await get_task_manager_from_app(app).submit_task( + task_uuid = await task_manager.submit_task( task_metadata=TaskMetadata( name=task_name, ), @@ -35,10 +34,12 @@ async def copy_folders_from_project( @router.expose() async def start_export_data( - app: FastAPI, job_id_data: AsyncJobNameData, paths_to_export: list[PathToExport] + task_manager: TaskManager, + job_id_data: AsyncJobNameData, + paths_to_export: list[PathToExport], ) -> AsyncJobGet: task_name = export_data.__name__ - task_uuid = await get_task_manager_from_app(app).submit_task( + task_uuid = await task_manager.submit_task( task_metadata=TaskMetadata( name=task_name, ephemeral=False, diff --git a/services/storage/src/simcore_service_storage/api/rpc/routes.py b/services/storage/src/simcore_service_storage/api/rpc/routes.py index db6469ed380..ebf1ba60411 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/routes.py +++ b/services/storage/src/simcore_service_storage/api/rpc/routes.py @@ -1,12 +1,14 @@ import logging +from celery_library.rpc import _async_jobs from fastapi import FastAPI from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE from servicelib.logging_utils import log_context from servicelib.rabbitmq import RPCRouter +from simcore_service_storage.modules.celery import get_task_manager_from_app from ...modules.rabbitmq import get_rabbitmq_rpc_server -from . import _async_jobs, _paths, _simcore_s3 +from . import _paths, _simcore_s3 _logger = logging.getLogger(__name__) @@ -18,7 +20,7 @@ ] -def setup_rpc_api_routes(app: FastAPI) -> None: +def setup_rpc_routes(app: FastAPI) -> None: async def startup() -> None: with log_context( _logger, @@ -26,7 +28,10 @@ async def startup() -> None: msg="Storage startup RPC API Routes", ): rpc_server = get_rabbitmq_rpc_server(app) + task_manager = get_task_manager_from_app(app) for router in ROUTERS: - await rpc_server.register_router(router, STORAGE_RPC_NAMESPACE, app) + await rpc_server.register_router( + router, STORAGE_RPC_NAMESPACE, task_manager=task_manager + ) app.add_event_handler("startup", startup) diff --git a/services/storage/src/simcore_service_storage/core/application.py b/services/storage/src/simcore_service_storage/core/application.py index 987878d32cf..2a1fe8246fb 100644 --- a/services/storage/src/simcore_service_storage/core/application.py +++ b/services/storage/src/simcore_service_storage/core/application.py @@ -32,7 +32,7 @@ APP_WORKER_STARTED_BANNER_MSG, ) from ..api.rest.routes import setup_rest_api_routes -from ..api.rpc.routes import setup_rpc_api_routes +from ..api.rpc.routes import setup_rpc_routes from ..dsm import setup_dsm from ..dsm_cleaner import setup_dsm_cleaner from ..exceptions.handlers import set_exception_handlers @@ -90,12 +90,12 @@ def create_app(settings: ApplicationSettings) -> FastAPI: # noqa: C901 setup_s3(app) setup_client_session(app) - if not settings.STORAGE_WORKER_MODE: + if settings.STORAGE_CELERY and not settings.STORAGE_WORKER_MODE: setup_rabbitmq(app) - setup_rpc_api_routes(app) - assert settings.STORAGE_CELERY # nosec setup_task_manager(app, celery_settings=settings.STORAGE_CELERY) + + setup_rpc_routes(app) setup_rest_api_long_running_tasks_for_uploads(app) setup_rest_api_routes(app, API_VTAG) set_exception_handlers(app) diff --git a/services/storage/src/simcore_service_storage/modules/celery/__init__.py b/services/storage/src/simcore_service_storage/modules/celery/__init__.py index 45795d329a3..684262d3f9b 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/__init__.py +++ b/services/storage/src/simcore_service_storage/modules/celery/__init__.py @@ -13,7 +13,7 @@ def setup_task_manager(app: FastAPI, celery_settings: CelerySettings) -> None: async def on_startup() -> None: - app.state.celery_client = await create_task_manager( + app.state.task_manager = await create_task_manager( create_app(celery_settings), celery_settings ) @@ -24,7 +24,7 @@ async def on_startup() -> None: def get_task_manager_from_app(app: FastAPI) -> CeleryTaskManager: - assert hasattr(app.state, "celery_client") # nosec - celery_client = app.state.celery_client - assert isinstance(celery_client, CeleryTaskManager) # nosec - return celery_client + assert hasattr(app.state, "task_manager") # nosec + task_manager = app.state.task_manager + assert isinstance(task_manager, CeleryTaskManager) # nosec + return task_manager