diff --git a/packages/celery-library/src/celery_library/backends/redis.py b/packages/celery-library/src/celery_library/backends/redis.py index 9878cd5e063f..d9d7c37a1685 100644 --- a/packages/celery-library/src/celery_library/backends/redis.py +++ b/packages/celery-library/src/celery_library/backends/redis.py @@ -1,7 +1,7 @@ import contextlib import logging from datetime import timedelta -from typing import Final +from typing import TYPE_CHECKING, Final from models_library.progress_bar import ProgressReport from pydantic import ValidationError @@ -9,16 +9,14 @@ Task, TaskFilter, TaskID, + TaskInfoStore, TaskMetadata, - TaskUUID, + Wildcard, ) from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types -from ..utils import build_task_id_prefix - _CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-" _CELERY_TASK_ID_KEY_ENCODING = "utf-8" -_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":" _CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000 _CELERY_TASK_METADATA_KEY: Final[str] = "metadata" _CELERY_TASK_PROGRESS_KEY: Final[str] = "progress" @@ -88,17 +86,14 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: return None async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: - search_key = ( - _CELERY_TASK_INFO_PREFIX - + build_task_id_prefix(task_filter) - + _CELERY_TASK_ID_KEY_SEPARATOR + search_key = _CELERY_TASK_INFO_PREFIX + task_filter.create_task_id( + task_uuid=Wildcard() ) - search_key_len = len(search_key) keys: list[str] = [] pipeline = self._redis_client_sdk.redis.pipeline() async for key in self._redis_client_sdk.redis.scan_iter( - match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH + match=search_key, count=_CELERY_TASK_SCAN_COUNT_PER_BATCH ): # fake redis (tests) returns bytes, real redis returns str _key = ( @@ -120,7 +115,7 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: task_metadata = TaskMetadata.model_validate_json(raw_metadata) tasks.append( Task( - uuid=TaskUUID(key[search_key_len:]), + uuid=TaskFilter.get_task_uuid(key), metadata=task_metadata, ) ) @@ -143,3 +138,7 @@ async def task_exists(self, task_id: TaskID) -> bool: n = await self._redis_client_sdk.redis.exists(_build_key(task_id)) assert isinstance(n, int) # nosec return n > 0 + + +if TYPE_CHECKING: + _: type[TaskInfoStore] = RedisTaskInfoStore diff --git a/packages/celery-library/src/celery_library/rpc/_async_jobs.py b/packages/celery-library/src/celery_library/rpc/_async_jobs.py index 6fb45336fdcd..e0b077077f12 100644 --- a/packages/celery-library/src/celery_library/rpc/_async_jobs.py +++ b/packages/celery-library/src/celery_library/rpc/_async_jobs.py @@ -134,9 +134,8 @@ async def result( @router.expose(reraise_if_error_type=(JobSchedulerError,)) async def list_jobs( - task_manager: TaskManager, filter_: str, job_filter: AsyncJobFilter + task_manager: TaskManager, job_filter: AsyncJobFilter ) -> list[AsyncJobGet]: - _ = filter_ assert task_manager # nosec task_filter = TaskFilter.model_validate(job_filter.model_dump()) try: diff --git a/packages/celery-library/src/celery_library/task_manager.py b/packages/celery-library/src/celery_library/task_manager.py index 04e18a291583..182ae53e66d3 100644 --- a/packages/celery-library/src/celery_library/task_manager.py +++ b/packages/celery-library/src/celery_library/task_manager.py @@ -22,7 +22,6 @@ from settings_library.celery import CelerySettings from .errors import TaskNotFoundError -from .utils import build_task_id _logger = logging.getLogger(__name__) @@ -50,7 +49,7 @@ 
async def submit_task( msg=f"Submit {task_metadata.name=}: {task_filter=} {task_params=}", ): task_uuid = uuid4() - task_id = build_task_id(task_filter, task_uuid) + task_id = task_filter.create_task_id(task_uuid=task_uuid) self._celery_app.send_task( task_metadata.name, task_id=task_id, @@ -74,7 +73,7 @@ async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> Non logging.DEBUG, msg=f"task cancellation: {task_filter=} {task_uuid=}", ): - task_id = build_task_id(task_filter, task_uuid) + task_id = task_filter.create_task_id(task_uuid=task_uuid) if not await self.task_exists(task_id): raise TaskNotFoundError(task_id=task_id) @@ -96,7 +95,7 @@ async def get_task_result( logging.DEBUG, msg=f"Get task result: {task_filter=} {task_uuid=}", ): - task_id = build_task_id(task_filter, task_uuid) + task_id = task_filter.create_task_id(task_uuid=task_uuid) if not await self.task_exists(task_id): raise TaskNotFoundError(task_id=task_id) @@ -139,7 +138,7 @@ async def get_task_status( logging.DEBUG, msg=f"Getting task status: {task_filter=} {task_uuid=}", ): - task_id = build_task_id(task_filter, task_uuid) + task_id = task_filter.create_task_id(task_uuid=task_uuid) if not await self.task_exists(task_id): raise TaskNotFoundError(task_id=task_id) diff --git a/packages/celery-library/src/celery_library/utils.py b/packages/celery-library/src/celery_library/utils.py index 79910df1c568..4d8bad73b96d 100644 --- a/packages/celery-library/src/celery_library/utils.py +++ b/packages/celery-library/src/celery_library/utils.py @@ -1,26 +1,8 @@ -from typing import Final - from celery import Celery # type: ignore[import-untyped] from servicelib.celery.app_server import BaseAppServer -from servicelib.celery.models import TaskFilter, TaskID, TaskUUID _APP_SERVER_KEY = "app_server" -_TASK_ID_KEY_DELIMITATOR: Final[str] = ":" - - -def build_task_id_prefix(task_filter: TaskFilter) -> str: - filter_dict = task_filter.model_dump() - return _TASK_ID_KEY_DELIMITATOR.join( - [f"{filter_dict[key]}" for key in sorted(filter_dict)] - ) - - -def build_task_id(task_filter: TaskFilter, task_uuid: TaskUUID) -> TaskID: - return _TASK_ID_KEY_DELIMITATOR.join( - [build_task_id_prefix(task_filter), f"{task_uuid}"] - ) - def get_app_server(app: Celery) -> BaseAppServer: app_server = app.conf[_APP_SERVER_KEY] diff --git a/packages/celery-library/tests/unit/test_async_jobs.py b/packages/celery-library/tests/unit/test_async_jobs.py index cc72bd6b75ed..1622866f2424 100644 --- a/packages/celery-library/tests/unit/test_async_jobs.py +++ b/packages/celery-library/tests/unit/test_async_jobs.py @@ -258,7 +258,6 @@ async def test_async_jobs_workflow( jobs = await async_jobs.list_jobs( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, - filter_="", # currently not used job_filter=job_filter, ) assert len(jobs) > 0 @@ -311,7 +310,6 @@ async def test_async_jobs_cancel( jobs = await async_jobs.list_jobs( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, - filter_="", # currently not used job_filter=job_filter, ) assert async_job_get.job_id not in [job.job_id for job in jobs] diff --git a/packages/celery-library/tests/unit/test_tasks.py b/packages/celery-library/tests/unit/test_tasks.py index a9a5ad5c6a17..80a474ad04d8 100644 --- a/packages/celery-library/tests/unit/test_tasks.py +++ b/packages/celery-library/tests/unit/test_tasks.py @@ -18,16 +18,21 @@ from celery_library.task_manager import CeleryTaskManager from celery_library.utils import get_app_server from 
common_library.errors_classes import OsparcErrorMixin
+from faker import Faker
 from models_library.progress_bar import ProgressReport
 from servicelib.celery.models import (
     TaskFilter,
     TaskID,
     TaskMetadata,
     TaskState,
+    TaskUUID,
+    Wildcard,
 )
 from servicelib.logging_utils import log_context
 from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed
 
+_faker = Faker()
+
 _logger = logging.getLogger(__name__)
 
 pytest_simcore_core_services_selection = ["redis"]
@@ -199,5 +204,52 @@ async def test_listing_task_uuids_contains_submitted_task(
             tasks = await celery_task_manager.list_tasks(task_filter)
             assert any(task.uuid == task_uuid for task in tasks)
 
-        tasks = await celery_task_manager.list_tasks(task_filter)
-        assert any(task.uuid == task_uuid for task in tasks)
+    tasks = await celery_task_manager.list_tasks(task_filter)
+    assert any(task.uuid == task_uuid for task in tasks)
+
+
+async def test_filtering_listing_tasks(
+    celery_task_manager: CeleryTaskManager,
+):
+    class MyFilter(TaskFilter):
+        user_id: int
+        product_name: str | Wildcard
+        client_app: str | Wildcard
+
+    user_id = 42
+    expected_task_uuids: set[TaskUUID] = set()
+
+    for _ in range(5):
+        task_filter = MyFilter(
+            user_id=user_id,
+            product_name=_faker.word(),
+            client_app=_faker.word(),
+        )
+        task_uuid = await celery_task_manager.submit_task(
+            TaskMetadata(
+                name=dreamer_task.__name__,
+            ),
+            task_filter=task_filter,
+        )
+        expected_task_uuids.add(task_uuid)
+
+    for _ in range(3):
+        task_filter = MyFilter(
+            user_id=_faker.pyint(min_value=100, max_value=200),
+            product_name=_faker.word(),
+            client_app=_faker.word(),
+        )
+        await celery_task_manager.submit_task(
+            TaskMetadata(
+                name=dreamer_task.__name__,
+            ),
+            task_filter=task_filter,
+        )
+
+    search_filter = MyFilter(
+        user_id=user_id,
+        product_name=Wildcard(),
+        client_app=Wildcard(),
+    )
+    tasks = await celery_task_manager.list_tasks(search_filter)
+    assert expected_task_uuids == {task.uuid for task in tasks}
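For orientation, this is the task-id encoding that the wildcard test above exercises, sketched with the `TaskFilter` and `Wildcard` classes added to `servicelib/celery/models.py` further down (the concrete filter class and values here are illustrative only):

```python
from uuid import uuid4

from servicelib.celery.models import TaskFilter, Wildcard


class MyFilter(TaskFilter):
    user_id: int
    product_name: str | Wildcard


task_uuid = uuid4()
concrete = MyFilter(user_id=42, product_name="osparc")
# keys are sorted and encoded as "key=value" pairs joined by ":"
assert (
    concrete.create_task_id(task_uuid=task_uuid)
    == f"product_name=osparc:user_id=42:task_uuid={task_uuid}"
)

# a Wildcard stringifies to "*", so the same encoding doubles as a Redis SCAN pattern
pattern = MyFilter(user_id=42, product_name=Wildcard())
assert (
    pattern.create_task_id(task_uuid=Wildcard())
    == "product_name=*:user_id=42:task_uuid=*"
)
```

This is why `RedisTaskInfoStore.list_tasks` (first diff above) can drop the old prefix/separator constants: `create_task_id(task_uuid=Wildcard())` already yields the scan pattern.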
diff --git a/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py b/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py
index 4c37ad32df48..6faa7b1f2b68 100644
--- a/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py
+++ b/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py
@@ -67,7 +67,7 @@ class AsyncJobFilter(AsyncJobFilterBase):
 
     product_name: ProductName
     user_id: UserID
-    client_name: Annotated[
+    client_name: Annotated[  # the name of the app that *submits* the async job; mainly used for filtering
         str,
         StringConstraints(min_length=1, pattern=r"^[^\s]+$"),
     ]
diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py b/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py
index 72ce830784a5..b4b3d771efc1 100644
--- a/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py
+++ b/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py
@@ -13,8 +13,6 @@
     AsyncJobGet,
 )
 from models_library.api_schemas_webserver.storage import PathToExport
-from models_library.products import ProductName
-from models_library.users import UserID
 from pydantic import TypeAdapter, validate_call
 from pytest_mock import MockType
 from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
@@ -27,14 +25,12 @@ async def start_export_data(
         self,
         rabbitmq_rpc_client: RabbitMQRPCClient | MockType,
         *,
-        user_id: UserID,
-        product_name: ProductName,
         paths_to_export: list[PathToExport],
         export_as: Literal["path", "download_link"],
+        job_filter: AsyncJobFilter,
     ) -> tuple[AsyncJobGet, AsyncJobFilter]:
         assert rabbitmq_rpc_client
-        assert user_id
-        assert product_name
+        assert job_filter
         assert paths_to_export
         assert export_as
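Because "*", ":" and "=" are reserved by the task-id encoding, the `TaskFilter` validator introduced in the next diff rejects them in both filter keys and filter values. A minimal sketch of the expected behaviour (key names are illustrative):

```python
import pydantic

from servicelib.celery.models import TaskFilter

TaskFilter.model_validate({"client_app": "some-app"})  # accepted

# ":" separates key=value pairs, "=" binds a key to its value, "*" is the wildcard
for bad_data in ({"client_app": "a:b"}, {"client_app": "a=b"}, {"client*app": "x"}):
    try:
        TaskFilter.model_validate(bad_data)
        raise AssertionError("expected a validation error")
    except pydantic.ValidationError:
        pass  # reserved character rejected, as the parametrized test below also checks
```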
diff --git a/packages/service-library/src/servicelib/celery/models.py b/packages/service-library/src/servicelib/celery/models.py
index db1a07c80eea..1090b3d50823 100644
--- a/packages/service-library/src/servicelib/celery/models.py
+++ b/packages/service-library/src/servicelib/celery/models.py
@@ -1,20 +1,103 @@
 import datetime
 from enum import StrEnum
-from typing import Annotated, Final, Protocol, TypeAlias
+from typing import Annotated, Any, Final, Protocol, Self, TypeAlias, TypeVar
 from uuid import UUID
 
 from models_library.progress_bar import ProgressReport
-from pydantic import BaseModel, ConfigDict, StringConstraints
+from pydantic import BaseModel, ConfigDict, StringConstraints, model_validator
 from pydantic.config import JsonDict
 
+ModelType = TypeVar("ModelType", bound=BaseModel)
+
 TaskID: TypeAlias = str
 TaskName: TypeAlias = Annotated[
     str, StringConstraints(strip_whitespace=True, min_length=1)
 ]
 TaskUUID: TypeAlias = UUID
+_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
+_WILDCARD: Final[str] = "*"
+_FORBIDDEN_CHARS = (_WILDCARD, _TASK_ID_KEY_DELIMITATOR, "=")
+
+
+class Wildcard:
+    def __str__(self) -> str:
+        return _WILDCARD
+
+
+class TaskFilter(BaseModel):
+    """
+    Associates metadata with a celery task. The implementation is flexible and allows "clients" to define their own metadata.
+    The class exposes a filtering mechanism to list tasks using wildcards.
+
+    Example usage:
+        class MyTaskFilter(TaskFilter):
+            user_id: int | Wildcard
+            product_name: str | Wildcard
+            client_name: str
+
+    Listing tasks with the filter `MyTaskFilter(user_id=123, product_name=Wildcard(), client_name="my-app")` returns all tasks
+    with user_id 123 and any product_name which were submitted from my-app.
+
+    If the metadata schema is known, the class also allows deserializing the metadata (`recreate_as_model`), i.e. one can
+    recover the metadata from the task_id: metadata -> task_id -> metadata
+
+    """
+
+    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
+
+    @model_validator(mode="after")
+    def _check_valid_filters(self) -> Self:
+        for key, value in self.model_dump().items():
+            # forbidden keys
+            if any(x in key for x in _FORBIDDEN_CHARS):
+                raise ValueError(f"Invalid filter key: '{key}'")
+            # forbidden values
+            if not isinstance(value, Wildcard) and any(
+                x in f"{value}" for x in _FORBIDDEN_CHARS
+            ):
+                raise ValueError(f"Invalid filter value for key '{key}': '{value}'")
+        return self
+
+    def _build_task_id_prefix(self) -> str:
+        filter_dict = self.model_dump()
+        return _TASK_ID_KEY_DELIMITATOR.join(
+            [f"{key}={filter_dict[key]}" for key in sorted(filter_dict)]
+        )
+
+    def create_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID:
+        return _TASK_ID_KEY_DELIMITATOR.join(
+            [
+                self._build_task_id_prefix(),
+                f"task_uuid={task_uuid}",
+            ]
+        )
 
-class TaskFilter(BaseModel): ...
+    @classmethod
+    def recreate_as_model(cls, task_id: TaskID, schema: type[ModelType]) -> ModelType:
+        filter_dict = cls._recreate_data(task_id)
+        return schema.model_validate(filter_dict)
+
+    @classmethod
+    def _recreate_data(cls, task_id: TaskID) -> dict[str, Any]:
+        """Recreates the filter data from a task_id string
+        WARNING: does not validate types. For that use `recreate_as_model` instead
+        """
+        try:
+            parts = task_id.split(_TASK_ID_KEY_DELIMITATOR)
+            return {
+                key: value
+                for part in parts[:-1]
+                if (key := part.split("=")[0]) and (value := part.split("=")[1])
+            }
+        except (IndexError, ValueError) as err:
+            raise ValueError(f"Invalid task_id format: {task_id}") from err
+
+    @classmethod
+    def get_task_uuid(cls, task_id: TaskID) -> TaskUUID:
+        try:
+            return UUID(task_id.split(_TASK_ID_KEY_DELIMITATOR)[-1].split("=")[1])
+        except (IndexError, ValueError) as err:
+            raise ValueError(f"Invalid task_id format: {task_id}") from err
 
 
 class TaskState(StrEnum):
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py
index 81cead539c9c..0c79097852fb 100644
--- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py
+++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py
@@ -92,13 +92,11 @@ async def list_jobs(
     rabbitmq_rpc_client: RabbitMQRPCClient,
     *,
     rpc_namespace: RPCNamespace,
-    filter_: str,
     job_filter: AsyncJobFilter,
 ) -> list[AsyncJobGet]:
     _result: list[AsyncJobGet] = await rabbitmq_rpc_client.request(
         rpc_namespace,
         TypeAdapter(RPCMethodName).validate_python("list_jobs"),
-        filter_=filter_,
         job_filter=job_filter,
         timeout_s=_DEFAULT_TIMEOUT_S,
     )
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/_utils.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/_utils.py
deleted file mode 100644
index 6330d16cd065..000000000000
--- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/_utils.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from typing import Final
-
-from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
-from models_library.products import ProductName
-from models_library.users import UserID
-
-ASYNC_JOB_CLIENT_NAME: Final[str] = "STORAGE"
-
-
-def get_async_job_filter(user_id: UserID, product_name: ProductName) -> AsyncJobFilter:
-    return AsyncJobFilter(
-        user_id=user_id,
-        product_name=product_name,
-        client_name=ASYNC_JOB_CLIENT_NAME,
-    )
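Putting the pieces of `servicelib/celery/models.py` together: both the typed metadata and the task UUID survive a round trip through the task_id. A sketch, where the client-side schema `JobMeta` is a hypothetical stand-in:

```python
from uuid import uuid4

from pydantic import BaseModel
from servicelib.celery.models import TaskFilter


class JobMeta(BaseModel):  # hypothetical client-side schema
    user_id: int
    client_name: str


meta = JobMeta(user_id=7, client_name="api-server")
task_filter = TaskFilter.model_validate(meta.model_dump())

task_uuid = uuid4()
task_id = task_filter.create_task_id(task_uuid=task_uuid)

# the typed metadata and the task UUID can both be recovered from the task_id
assert TaskFilter.recreate_as_model(task_id, schema=JobMeta) == meta
assert TaskFilter.get_task_uuid(task_id) == task_uuid
```

Note that `_recreate_data` returns string values; `recreate_as_model` relies on pydantic coercion (here `"7"` back to `7`), which is why the docstring warns against using the raw dict directly.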
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py
index 3e9f30c5c5fb..bd205479e5ab 100644
--- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py
+++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py
@@ -5,26 +5,21 @@
     AsyncJobGet,
 )
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
-from models_library.products import ProductName
 from models_library.projects_nodes_io import LocationID
 from models_library.rabbitmq_basic_types import RPCMethodName
-from models_library.users import UserID
 from pydantic import TypeAdapter
 
 from ..._client_rpc import RabbitMQRPCClient
 from ..async_jobs.async_jobs import submit
-from ._utils import get_async_job_filter
 
 
 async def compute_path_size(
     client: RabbitMQRPCClient,
     *,
-    user_id: UserID,
-    product_name: ProductName,
     location_id: LocationID,
     path: Path,
+    job_filter: AsyncJobFilter,
 ) -> tuple[AsyncJobGet, AsyncJobFilter]:
-    job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
         rabbitmq_rpc_client=client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
@@ -39,12 +34,10 @@ async def delete_paths(
     client: RabbitMQRPCClient,
     *,
-    user_id: UserID,
-    product_name: ProductName,
     location_id: LocationID,
     paths: set[Path],
+    job_filter: AsyncJobFilter,
 ) -> tuple[AsyncJobGet, AsyncJobFilter]:
-    job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
         rabbitmq_rpc_client=client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py
index 00cfd9d15353..463b535667f0 100644
--- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py
+++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py
@@ -7,24 +7,16 @@
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
 from models_library.api_schemas_storage.storage_schemas import FoldersBody
 from models_library.api_schemas_webserver.storage import PathToExport
-from models_library.products import ProductName
 from models_library.rabbitmq_basic_types import RPCMethodName
-from models_library.users import UserID
 from pydantic import TypeAdapter
 
 from ... import RabbitMQRPCClient
 from ..async_jobs.async_jobs import submit
-from ._utils import get_async_job_filter
 
 
 async def copy_folders_from_project(
-    client: RabbitMQRPCClient,
-    *,
-    user_id: UserID,
-    product_name: ProductName,
-    body: FoldersBody,
+    client: RabbitMQRPCClient, *, body: FoldersBody, job_filter: AsyncJobFilter
 ) -> tuple[AsyncJobGet, AsyncJobFilter]:
-    job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
         rabbitmq_rpc_client=client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
@@ -40,12 +32,10 @@ async def start_export_data(
     rabbitmq_rpc_client: RabbitMQRPCClient,
     *,
-    user_id: UserID,
-    product_name: ProductName,
     paths_to_export: list[PathToExport],
     export_as: Literal["path", "download_link"],
+    job_filter: AsyncJobFilter,
 ) -> tuple[AsyncJobGet, AsyncJobFilter]:
-    job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
         rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
diff --git a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py
index a8d3ac3b6f7e..9105283f20ee 100644
--- a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py
+++ b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py
@@ -29,7 +29,7 @@
     "rabbit",
 ]
 
-_ASYNC_JOB_CLIENT_NAME: Final[str] = "PYTEST_CLIENT_NAME"
+_ASYNC_JOB_CLIENT_NAME: Final[str] = "pytest_client_name"
 
 
 @pytest.fixture
@@ -102,11 +102,8 @@ async def result(
             }
         )
 
-    async def list_jobs(
-        self, filter_: str, job_filter: AsyncJobFilter
-    ) -> list[AsyncJobGet]:
+    async def list_jobs(self, job_filter: AsyncJobFilter) -> list[AsyncJobGet]:
         assert job_filter
-        assert filter_ is not None
 
         return [
             AsyncJobGet(
@@ -173,7 +170,6 @@ async def test_list_jobs(
         await list_jobs(
             rpc_client,
             rpc_namespace=namespace,
-            filter_="",
             job_filter=job_filter,
         )
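The storage RPC helpers above now take a ready-made `AsyncJobFilter` instead of separate `user_id`/`product_name` arguments, so each caller owns its `client_name`. A hedged sketch of the migrated call shape (`trigger_export`, the client name, and the argument values are made up; importing `RabbitMQRPCClient` from `servicelib.rabbitmq` is assumed to match the package's public export):

```python
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
from models_library.api_schemas_webserver.storage import PathToExport
from servicelib.rabbitmq import RabbitMQRPCClient
from servicelib.rabbitmq.rpc_interfaces.storage.simcore_s3 import start_export_data


async def trigger_export(
    rpc_client: RabbitMQRPCClient,
    user_id: int,
    product_name: str,
    paths: list[PathToExport],
):
    # before this PR: start_export_data(..., user_id=user_id, product_name=product_name)
    return await start_export_data(
        rpc_client,
        paths_to_export=paths,
        export_as="path",
        job_filter=AsyncJobFilter(
            user_id=user_id,
            product_name=product_name,
            client_name="my-client",  # hypothetical; each service passes its own app name
        ),
    )
```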
diff --git a/packages/service-library/tests/test_celery.py b/packages/service-library/tests/test_celery.py
new file mode 100644
index 000000000000..d9ef42f0e2e5
--- /dev/null
+++ b/packages/service-library/tests/test_celery.py
@@ -0,0 +1,78 @@
+# pylint: disable=redefined-outer-name
+# pylint: disable=protected-access
+import pydantic
+import pytest
+from faker import Faker
+from pydantic import BaseModel
+from servicelib.celery.models import TaskFilter, TaskUUID
+
+_faker = Faker()
+
+
+@pytest.fixture
+def task_filter_data() -> dict[str, str | int | bool | None | list[str]]:
+    return {
+        "string": _faker.word(),
+        "int": _faker.random_int(),
+        "bool": _faker.boolean(),
+        "none": None,
+        "uuid": _faker.uuid4(),
+        "list": [_faker.word() for _ in range(3)],
+    }
+
+
+async def test_task_filter_serialization(
+    task_filter_data: dict[str, str | int | bool | None | list[str]],
+):
+    task_filter = TaskFilter.model_validate(task_filter_data)
+    assert task_filter.model_dump() == task_filter_data
+
+
+async def test_task_filter_sorting_key_not_serialized():
+    keys = ["a", "b"]
+    task_filter = TaskFilter.model_validate(
+        {
+            "a": _faker.random_int(),
+            "b": _faker.word(),
+        }
+    )
+    expected_key = ":".join([f"{k}={getattr(task_filter, k)}" for k in sorted(keys)])
+    assert task_filter._build_task_id_prefix() == expected_key
+
+
+async def test_task_filter_task_uuid(
+    task_filter_data: dict[str, str | int | bool | None | list[str]],
+):
+    task_filter = TaskFilter.model_validate(task_filter_data)
+    task_uuid = TaskUUID(_faker.uuid4())
+    task_id = task_filter.create_task_id(task_uuid)
+    assert TaskFilter.get_task_uuid(task_id=task_id) == task_uuid
+
+
+async def test_create_task_filter_from_task_id():
+    class MyModel(BaseModel):
+        an_int: int
+        a_bool: bool
+        a_str: str
+
+    mymodel = MyModel(an_int=1, a_bool=True, a_str="test")
+    task_filter = TaskFilter.model_validate(mymodel.model_dump())
+    task_uuid = TaskUUID(_faker.uuid4())
+    task_id = task_filter.create_task_id(task_uuid)
+    assert TaskFilter.recreate_as_model(task_id=task_id, schema=MyModel) == mymodel
+
+
+@pytest.mark.parametrize(
+    "bad_data",
+    [
+        {"foo": "bar:baz"},
+        {"foo": "bar=baz"},
+        {"foo:bad": "bar"},
+        {"foo=bad": "bar"},
+        {"foo": ":baz"},
+        {"foo": "=baz"},
+    ],
+)
+def test_task_filter_validator_raises_on_forbidden_chars(bad_data):
+    with pytest.raises(pydantic.ValidationError):
+        TaskFilter.model_validate(bad_data)
diff --git a/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py b/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py
index 1fa0ccfb3e4f..86c37a0a7af6 100644
--- a/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py
+++ b/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py
@@ -1,10 +1,6 @@
-from typing import Final
-
 from celery_library.task_manager import CeleryTaskManager
 from fastapi import FastAPI
 
-ASYNC_JOB_CLIENT_NAME: Final[str] = "API_SERVER"
-
 
 def get_task_manager(app: FastAPI) -> CeleryTaskManager:
     assert hasattr(app.state, "task_manager")  # nosec
diff --git a/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py b/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py
index 0421f8ae8ad2..a622e7fd4a22 100644
--- a/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py
+++ b/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py
@@ -29,6 +29,7 @@
 from ..._service_function_jobs import FunctionJobService
 from ..._service_functions import FunctionService
 from ..._service_jobs import JobService
+from ...clients.celery_task_manager import get_task_filter
 from ...exceptions.function_errors import FunctionJobProjectMissingError
 from ...models.domain.functions import PageRegisteredFunctionJobWithorWithoutStatus
 from ...models.pagination import PaginationParams
@@ -55,7 +56,6 @@
     FMSG_CHANGELOG_NEW_IN_VERSION,
     create_route_description,
 )
-from .tasks import _get_task_filter
 
 _logger = getLogger(__name__)
 
@@ -294,7 +294,7 @@ async def function_job_status(
     ):
         if task_id := function_job.job_creation_task_id:
             task_manager = get_task_manager(app)
-            task_filter = _get_task_filter(user_id, product_name)
+            task_filter = get_task_filter(user_id, product_name)
             task_status = await task_manager.get_task_status(
                 task_uuid=TaskUUID(task_id), task_filter=task_filter
             )
diff --git a/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py b/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py
index 14cfcadbc294..89cee7b84eda 100644
--- a/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py
+++ b/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py
@@ -18,13 +18,12 @@
     RegisteredFunctionJob,
     RegisteredFunctionJobCollection,
 )
-from models_library.api_schemas_rpc_async_jobs.async_jobs import
AsyncJobFilter from models_library.functions import FunctionJobCollection, FunctionJobID from models_library.products import ProductName from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID from models_library.users import UserID -from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata, TasksQueue +from servicelib.celery.models import TaskID, TaskMetadata, TasksQueue from servicelib.fastapi.dependencies import get_reverse_url_mapper from servicelib.utils import limited_gather @@ -33,6 +32,7 @@ from ...celery_worker.worker_tasks.functions_tasks import ( run_function as run_function_task, ) +from ...clients.celery_task_manager import get_task_filter from ...exceptions.function_errors import FunctionJobCacheNotFoundError from ...models.pagination import Page, PaginationParams from ...models.schemas.errors import ErrorGet @@ -44,7 +44,7 @@ get_current_user_id, get_product_name, ) -from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME, get_task_manager +from ..dependencies.celery import get_task_manager from ..dependencies.services import get_function_job_service, get_function_service from ..dependencies.webserver_rpc import get_wb_api_rpc_client from ._constants import ( @@ -368,12 +368,9 @@ async def run_function( ) # run function in celery task - job_filter = AsyncJobFilter( - user_id=user_identity.user_id, - product_name=user_identity.product_name, - client_name=ASYNC_JOB_CLIENT_NAME, + task_filter = get_task_filter( + user_id=user_identity.user_id, product_name=user_identity.product_name ) - task_filter = TaskFilter.model_validate(job_filter.model_dump()) task_name = run_function_task.__name__ task_uuid = await task_manager.submit_task( diff --git a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py index 36663efccf2d..440f1db35a1a 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py @@ -10,19 +10,19 @@ TaskStatus, ) from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobId, ) from models_library.products import ProductName from models_library.users import UserID -from servicelib.celery.models import TaskFilter, TaskState, TaskUUID +from servicelib.celery.models import TaskState, TaskUUID from servicelib.fastapi.dependencies import get_app from servicelib.logging_errors import create_troubleshootting_log_kwargs +from ...clients.celery_task_manager import get_task_filter from ...models.schemas.base import ApiServerEnvelope from ...models.schemas.errors import ErrorGet from ..dependencies.authentication import get_current_user_id, get_product_name -from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME, get_task_manager +from ..dependencies.celery import get_task_manager from ._constants import ( FMSG_CHANGELOG_NEW_IN_VERSION, create_route_description, @@ -32,13 +32,6 @@ _logger = logging.getLogger(__name__) -def _get_task_filter(user_id: UserID, product_name: ProductName) -> TaskFilter: - job_filter = AsyncJobFilter( - user_id=user_id, product_name=product_name, client_name=ASYNC_JOB_CLIENT_NAME - ) - return TaskFilter.model_validate(job_filter.model_dump()) - - _DEFAULT_TASK_STATUS_CODES: dict[int | str, dict[str, Any]] = { status.HTTP_500_INTERNAL_SERVER_ERROR: { "description": "Internal server error", @@ -68,7 +61,7 @@ async def list_tasks( task_manager = get_task_manager(app) tasks = await 
task_manager.list_tasks( - task_filter=_get_task_filter(user_id, product_name), + task_filter=get_task_filter(user_id, product_name), ) app_router = app.router @@ -110,7 +103,7 @@ async def get_task_status( task_manager = get_task_manager(app) task_status = await task_manager.get_task_status( - task_filter=_get_task_filter(user_id, product_name), + task_filter=get_task_filter(user_id, product_name), task_uuid=TaskUUID(f"{task_id}"), ) @@ -145,7 +138,7 @@ async def cancel_task( task_manager = get_task_manager(app) await task_manager.cancel_task( - task_filter=_get_task_filter(user_id, product_name), + task_filter=get_task_filter(user_id, product_name), task_uuid=TaskUUID(f"{task_id}"), ) @@ -175,7 +168,7 @@ async def get_task_result( product_name: Annotated[ProductName, Depends(get_product_name)], ): task_manager = get_task_manager(app) - task_filter = _get_task_filter(user_id, product_name) + task_filter = get_task_filter(user_id, product_name) task_status = await task_manager.get_task_status( task_filter=task_filter, diff --git a/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py b/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py index 8f9f002e1d52..87675ac122fd 100644 --- a/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py +++ b/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py @@ -5,16 +5,28 @@ from celery_library.task_manager import CeleryTaskManager from celery_library.types import register_celery_types, register_pydantic_types from fastapi import FastAPI +from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter +from models_library.products import ProductName +from models_library.users import UserID +from servicelib.celery.models import TaskFilter from servicelib.logging_utils import log_context from servicelib.redis import RedisClientSDK from settings_library.celery import CelerySettings from settings_library.redis import RedisDatabase +from .._meta import APP_NAME from ..celery_worker.worker_tasks.tasks import pydantic_types_to_register _logger = logging.getLogger(__name__) +def get_task_filter(user_id: UserID, product_name: ProductName) -> TaskFilter: + job_filter = AsyncJobFilter( + user_id=user_id, product_name=product_name, client_name=APP_NAME + ) + return TaskFilter.model_validate(job_filter.model_dump()) + + def setup_task_manager(app: FastAPI, settings: CelerySettings) -> None: async def on_startup() -> None: with log_context(_logger, logging.INFO, "Setting up Celery"): diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py b/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py index 429a67ef1619..129a83577992 100644 --- a/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py +++ b/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py @@ -86,12 +86,9 @@ async def result( JobSchedulerError: TaskSchedulerError, } ) - async def list_jobs( - self, *, filter_: str, job_filter: AsyncJobFilter - ) -> list[AsyncJobGet]: + async def list_jobs(self, *, job_filter: AsyncJobFilter) -> list[AsyncJobGet]: return await async_jobs.list_jobs( self._rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, - filter_=filter_, job_filter=job_filter, ) diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/storage.py b/services/api-server/src/simcore_service_api_server/services_rpc/storage.py index 
32c8d38e49bd..38b7abaa3b4a 100644 --- a/services/api-server/src/simcore_service_api_server/services_rpc/storage.py +++ b/services/api-server/src/simcore_service_api_server/services_rpc/storage.py @@ -1,18 +1,28 @@ from dataclasses import dataclass from functools import partial -from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobGet +from models_library.api_schemas_rpc_async_jobs.async_jobs import ( + AsyncJobFilter, + AsyncJobGet, +) from models_library.api_schemas_webserver.storage import PathToExport from models_library.products import ProductName from models_library.users import UserID from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient from servicelib.rabbitmq.rpc_interfaces.storage import simcore_s3 as storage_rpc +from .._meta import APP_NAME from ..exceptions.service_errors_utils import service_exception_mapper _exception_mapper = partial(service_exception_mapper, service_name="Storage") +def get_job_filter(user_id: UserID, product_name: ProductName) -> AsyncJobFilter: + return AsyncJobFilter( + user_id=user_id, product_name=product_name, client_name=APP_NAME + ) + + @dataclass(frozen=True, kw_only=True) class StorageService: _rpc_client: RabbitMQRPCClient @@ -26,9 +36,11 @@ async def start_data_export( ) -> AsyncJobGet: async_job_get, _ = await storage_rpc.start_export_data( self._rpc_client, - user_id=self._user_id, - product_name=self._product_name, paths_to_export=paths_to_export, export_as="download_link", + job_filter=get_job_filter( + user_id=self._user_id, + product_name=self._product_name, + ), ) return async_job_get diff --git a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py index c28685ee2f6c..dd54eefd0268 100644 --- a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py +++ b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py @@ -23,7 +23,6 @@ from fastapi import FastAPI, status from httpx import AsyncClient, BasicAuth, HTTPStatusError from models_library.api_schemas_long_running_tasks.tasks import TaskResult, TaskStatus -from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter from models_library.functions import ( FunctionClass, FunctionID, @@ -50,7 +49,6 @@ from simcore_service_api_server._meta import API_VTAG from simcore_service_api_server.api.dependencies.authentication import Identity from simcore_service_api_server.api.dependencies.celery import ( - ASYNC_JOB_CLIENT_NAME, get_task_manager, ) from simcore_service_api_server.celery_worker.worker_tasks.functions_tasks import ( @@ -65,6 +63,7 @@ JobPricingSpecification, NodeID, ) +from simcore_service_api_server.services_rpc.storage import get_job_filter from tenacity import ( AsyncRetrying, retry_if_exception_type, @@ -277,16 +276,13 @@ async def test_celery_error_propagation( app: FastAPI, client: AsyncClient, auth: BasicAuth, + user_identity: Identity, with_api_server_celery_worker: TestWorkController, ): - user_identity = Identity( - user_id=_faker.pyint(), product_name=_faker.word(), email=_faker.email() - ) - job_filter = AsyncJobFilter( + job_filter = get_job_filter( user_id=user_identity.user_id, product_name=user_identity.product_name, - client_name=ASYNC_JOB_CLIENT_NAME, ) task_manager = get_task_manager(app=app) task_uuid = await task_manager.submit_task( diff --git a/services/storage/src/simcore_service_storage/api/rest/_files.py b/services/storage/src/simcore_service_storage/api/rest/_files.py 
index 9981570fc5fa..d7530f514366 100644
--- a/services/storage/src/simcore_service_storage/api/rest/_files.py
+++ b/services/storage/src/simcore_service_storage/api/rest/_files.py
@@ -3,7 +3,6 @@
 from urllib.parse import quote
 
 from fastapi import APIRouter, Depends, Header, Request
-from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
 from models_library.api_schemas_storage.storage_schemas import (
     FileMetaDataGet,
     FileMetaDataGetv010,
@@ -18,6 +17,7 @@
 )
 from models_library.generics import Envelope
 from models_library.projects_nodes_io import LocationID, StorageFileID
+from models_library.users import UserID
 from pydantic import AnyUrl, ByteSize, TypeAdapter
 from servicelib.aiohttp import status
 from servicelib.celery.models import TaskFilter, TaskMetadata, TaskUUID
@@ -25,6 +25,7 @@
 from servicelib.logging_utils import log_context
 from yarl import URL
 
+from ..._meta import APP_NAME
 from ...dsm import get_dsm_provider
 from ...exceptions.errors import FileMetaDataNotFoundError
 from ...models import (
@@ -41,7 +42,13 @@
 from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file
 from .dependencies.celery import get_task_manager
 
-_ASYNC_JOB_CLIENT_NAME: Final[str] = "STORAGE"
+
+def _get_task_filter(*, user_id: UserID) -> TaskFilter:
+    _data = {
+        "user_id": user_id,
+        "client_name": APP_NAME,
+    }
+    return TaskFilter.model_validate(_data)
 
 _logger = logging.getLogger(__name__)
 
@@ -287,18 +294,13 @@ async def complete_upload_file(
     # NOTE: completing a multipart upload on AWS can take up to several minutes
     # if it returns slow we return a 202 - Accepted, the client will have to check later
     # for completeness
-    job_filter = AsyncJobFilter(
-        user_id=query_params.user_id,
-        product_name=_UNDEFINED_PRODUCT_NAME_FOR_WORKER_TASKS,  # NOTE: I would need to change the API here
-        client_name=_ASYNC_JOB_CLIENT_NAME,
-    )
-    task_filter = TaskFilter.model_validate(job_filter.model_dump())
+    task_filter = _get_task_filter(user_id=query_params.user_id)
     task_uuid = await task_manager.submit_task(
         TaskMetadata(
             name=remote_complete_upload_file.__name__,
         ),
         task_filter=task_filter,
-        user_id=job_filter.user_id,
+        user_id=query_params.user_id,
         location_id=location_id,
         file_id=file_id,
         body=body,
@@ -345,12 +347,7 @@ async def is_completed_upload_file(
     # therefore we wait a bit to see if it completes fast and return a 204
     # if it returns slow we return a 202 - Accepted, the client will have to check later
     # for completeness
-    job_filter = AsyncJobFilter(
-        user_id=query_params.user_id,
-        product_name=_UNDEFINED_PRODUCT_NAME_FOR_WORKER_TASKS,  # NOTE: I would need to change the API here
-        client_name=_ASYNC_JOB_CLIENT_NAME,
-    )
-    task_filter = TaskFilter.model_validate(job_filter.model_dump())
+    task_filter = _get_task_filter(user_id=query_params.user_id)
     task_status = await task_manager.get_task_status(
         task_filter=task_filter, task_uuid=TaskUUID(future_id)
     )
diff --git a/services/storage/tests/unit/test_rpc_handlers_paths.py b/services/storage/tests/unit/test_rpc_handlers_paths.py
index 04cbb47692cf..c1acc0719f9b 100644
--- a/services/storage/tests/unit/test_rpc_handlers_paths.py
+++ b/services/storage/tests/unit/test_rpc_handlers_paths.py
@@ -71,10 +71,11 @@ async def _assert_compute_path_size(
 ) -> ByteSize:
     async_job, async_job_name = await compute_path_size(
         storage_rpc_client,
-        product_name=product_name,
-        user_id=user_id,
         location_id=location_id,
         path=path,
+        job_filter=AsyncJobFilter(
+            user_id=user_id, product_name=product_name,
client_name="PYTEST_CLIENT_NAME" + ), ) async for job_composed_result in wait_and_get_result( storage_rpc_client, @@ -107,10 +108,11 @@ async def _assert_delete_paths( ) -> None: async_job, async_job_name = await delete_paths( storage_rpc_client, - product_name=product_name, - user_id=user_id, location_id=location_id, paths=paths, + job_filter=AsyncJobFilter( + user_id=user_id, product_name=product_name, client_name="PYTEST_CLIENT_NAME" + ), ) async for job_composed_result in wait_and_get_result( storage_rpc_client, diff --git a/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py b/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py index 2f76ea6135df..59d5a6d5586f 100644 --- a/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py +++ b/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py @@ -24,7 +24,10 @@ from faker import Faker from fastapi import FastAPI from fastapi.encoders import jsonable_encoder -from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobResult +from models_library.api_schemas_rpc_async_jobs.async_jobs import ( + AsyncJobFilter, + AsyncJobResult, +) from models_library.api_schemas_rpc_async_jobs.exceptions import JobError from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE from models_library.api_schemas_storage.storage_schemas import ( @@ -84,11 +87,14 @@ async def _request_copy_folders( ) as ctx: async_job_get, async_job_name = await copy_folders_from_project( rpc_client, - user_id=user_id, - product_name=product_name, body=FoldersBody( source=source_project, destination=dst_project, nodes_map=nodes_map ), + job_filter=AsyncJobFilter( + user_id=user_id, + product_name=product_name, + client_name="PYTEST_CLIENT_NAME", + ), ) async for async_job_result in wait_and_get_result( @@ -526,10 +532,13 @@ async def _request_start_export_data( ) as ctx: async_job_get, async_job_name = await start_export_data( rpc_client, - user_id=user_id, - product_name=product_name, paths_to_export=paths_to_export, export_as=export_as, + job_filter=AsyncJobFilter( + user_id=user_id, + product_name=product_name, + client_name="PYTEST_CLIENT_NAME", + ), ) async for async_job_result in wait_and_get_result( diff --git a/services/web/server/src/simcore_service_webserver/constants.py b/services/web/server/src/simcore_service_webserver/constants.py index 42e3a4f1102b..4fe86fe3f177 100644 --- a/services/web/server/src/simcore_service_webserver/constants.py +++ b/services/web/server/src/simcore_service_webserver/constants.py @@ -11,6 +11,8 @@ ) from servicelib.request_keys import RQT_USERID_KEY +from ._meta import APP_NAME + # Application storage keys APP_PRODUCTS_KEY: Final[str] = f"{__name__ }.APP_PRODUCTS_KEY" @@ -51,8 +53,6 @@ "Please try again shortly. If the issue persists, contact support.", _version=1 ) -ASYNC_JOB_CLIENT_NAME: Final[str] = "WEBSERVER" - __all__: tuple[str, ...] 
= ( "APP_AIOPG_ENGINE_KEY", diff --git a/services/web/server/src/simcore_service_webserver/storage/_rest.py b/services/web/server/src/simcore_service_webserver/storage/_rest.py index 6836b5f80dcb..1fec3432d332 100644 --- a/services/web/server/src/simcore_service_webserver/storage/_rest.py +++ b/services/web/server/src/simcore_service_webserver/storage/_rest.py @@ -12,7 +12,9 @@ from models_library.api_schemas_long_running_tasks.tasks import ( TaskGet, ) -from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobGet +from models_library.api_schemas_rpc_async_jobs.async_jobs import ( + AsyncJobGet, +) from models_library.api_schemas_storage.storage_schemas import ( FileUploadCompleteResponse, FileUploadCompletionBody, @@ -55,6 +57,7 @@ from ..rabbitmq import get_rabbitmq_rpc_client from ..security.decorators import permission_required from ..tasks._exception_handlers import handle_export_data_exceptions +from ..utils import get_job_filter from .schemas import StorageFileIDStr from .settings import StorageSettings, get_plugin_settings @@ -206,10 +209,12 @@ async def compute_path_size(request: web.Request) -> web.Response: rabbitmq_rpc_client = get_rabbitmq_rpc_client(request.app) async_job, _ = await remote_compute_path_size( rabbitmq_rpc_client, - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, location_id=path_params.location_id, path=path_params.path, + job_filter=get_job_filter( + user_id=req_ctx.user_id, + product_name=req_ctx.product_name, + ), ) return _create_data_response_from_async_job(request, async_job) @@ -229,10 +234,12 @@ async def batch_delete_paths(request: web.Request): rabbitmq_rpc_client = get_rabbitmq_rpc_client(request.app) async_job, _ = await remote_delete_paths( rabbitmq_rpc_client, - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, location_id=path_params.location_id, paths=body.paths, + job_filter=get_job_filter( + user_id=req_ctx.user_id, + product_name=req_ctx.product_name, + ), ) return _create_data_response_from_async_job(request, async_job) @@ -494,10 +501,12 @@ def allow_only_simcore(cls, v: int) -> int: ) async_job_rpc_get, _ = await start_export_data( rabbitmq_rpc_client=rabbitmq_rpc_client, - user_id=_req_ctx.user_id, - product_name=_req_ctx.product_name, paths_to_export=export_data_post.paths, export_as="path", + job_filter=get_job_filter( + user_id=_req_ctx.user_id, + product_name=_req_ctx.product_name, + ), ) _job_id = f"{async_job_rpc_get.job_id}" return create_data_response( diff --git a/services/web/server/src/simcore_service_webserver/storage/api.py b/services/web/server/src/simcore_service_webserver/storage/api.py index 48b5877a0e8f..839c26bc89c0 100644 --- a/services/web/server/src/simcore_service_webserver/storage/api.py +++ b/services/web/server/src/simcore_service_webserver/storage/api.py @@ -7,7 +7,6 @@ from typing import Any, Final from aiohttp import ClientError, ClientSession, ClientTimeout, web -from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE from models_library.api_schemas_storage.storage_schemas import ( FileLocation, @@ -30,10 +29,10 @@ ) from yarl import URL -from ..constants import ASYNC_JOB_CLIENT_NAME from ..projects.models import ProjectDict from ..projects.utils import NodesMap from ..rabbitmq import get_rabbitmq_rpc_client +from ..utils import get_job_filter from .settings import StorageSettings, get_plugin_settings _logger = logging.getLogger(__name__) @@ -119,10 +118,9 @@ async def 
copy_data_folders_from_project( rabbitmq_client, method_name="copy_folders_from_project", rpc_namespace=STORAGE_RPC_NAMESPACE, - job_filter=AsyncJobFilter( + job_filter=get_job_filter( user_id=user_id, product_name=product_name, - client_name=ASYNC_JOB_CLIENT_NAME, ), body=TypeAdapter(FoldersBody).validate_python( { diff --git a/services/web/server/src/simcore_service_webserver/tasks/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_rest.py index 45c6457bc582..0ba2b2d65ec5 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_rest.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_rest.py @@ -15,7 +15,6 @@ TaskStatus, ) from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobId, ) from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE @@ -32,12 +31,12 @@ from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs from .._meta import API_VTAG -from ..constants import ASYNC_JOB_CLIENT_NAME from ..login.decorators import login_required from ..long_running_tasks.plugin import webserver_request_context_decorator from ..models import AuthenticatedRequestContext from ..rabbitmq import get_rabbitmq_rpc_client from ..security.decorators import permission_required +from ..utils import get_job_filter from ._exception_handlers import handle_export_data_exceptions log = logging.getLogger(__name__) @@ -71,12 +70,10 @@ async def get_async_jobs(request: web.Request) -> web.Response: user_async_jobs = await async_jobs.list_jobs( rabbitmq_rpc_client=rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, - job_filter=AsyncJobFilter( + job_filter=get_job_filter( user_id=_req_ctx.user_id, product_name=_req_ctx.product_name, - client_name=ASYNC_JOB_CLIENT_NAME, ), - filter_="", ) return create_data_response( [ @@ -122,10 +119,9 @@ async def get_async_job_status(request: web.Request) -> web.Response: rabbitmq_rpc_client=rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id=async_job_get.task_id, - job_filter=AsyncJobFilter( + job_filter=get_job_filter( user_id=_req_ctx.user_id, product_name=_req_ctx.product_name, - client_name=ASYNC_JOB_CLIENT_NAME, ), ) _task_id = f"{async_job_rpc_status.job_id}" @@ -159,10 +155,9 @@ async def cancel_async_job(request: web.Request) -> web.Response: rabbitmq_rpc_client=rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id=async_job_get.task_id, - job_filter=AsyncJobFilter( + job_filter=get_job_filter( user_id=_req_ctx.user_id, product_name=_req_ctx.product_name, - client_name=ASYNC_JOB_CLIENT_NAME, ), ) @@ -188,10 +183,9 @@ class _PathParams(BaseModel): rabbitmq_rpc_client=rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id=async_job_get.task_id, - job_filter=AsyncJobFilter( + job_filter=get_job_filter( user_id=_req_ctx.user_id, product_name=_req_ctx.product_name, - client_name=ASYNC_JOB_CLIENT_NAME, ), ) diff --git a/services/web/server/src/simcore_service_webserver/utils.py b/services/web/server/src/simcore_service_webserver/utils.py index 8928deb21c60..fee35eff5054 100644 --- a/services/web/server/src/simcore_service_webserver/utils.py +++ b/services/web/server/src/simcore_service_webserver/utils.py @@ -9,10 +9,15 @@ from datetime import datetime from common_library.error_codes import ErrorCodeStr +from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter +from models_library.products import ProductName +from models_library.users import UserID from typing_extensions import ( # 
https://docs.pydantic.dev/latest/api/standard_library_types/#typeddict TypedDict, ) +from ._meta import APP_NAME + _logger = logging.getLogger(__name__) DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" @@ -120,3 +125,9 @@ def compose_support_error_msg( ) return ". ".join(sentences) + + +def get_job_filter(*, user_id: UserID, product_name: ProductName) -> AsyncJobFilter: + return AsyncJobFilter( + user_id=user_id, product_name=product_name, client_name=APP_NAME + )
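For reference, the `AsyncJobFilter` to `TaskFilter` hop that helpers like `get_job_filter` above feed into looks like this end to end (a sketch; the filter values and client name are illustrative):

```python
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
from servicelib.celery.models import TaskFilter, Wildcard

job_filter = AsyncJobFilter(user_id=1, product_name="osparc", client_name="WEBSERVER")

# TaskFilter allows extra fields, so any AsyncJobFilter dump validates as-is
task_filter = TaskFilter.model_validate(job_filter.model_dump())

# the same filter then scopes task lookups for this user/product/client
assert (
    task_filter.create_task_id(task_uuid=Wildcard())
    == "client_name=WEBSERVER:product_name=osparc:user_id=1:task_uuid=*"
)
```

Deriving `client_name` from each service's `APP_NAME` replaces the per-service `ASYNC_JOB_CLIENT_NAME` constants that this PR removes.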