diff --git a/packages/celery-library/src/celery_library/backends/_redis.py b/packages/celery-library/src/celery_library/backends/_redis.py
index 37a9a415cd5d..6b30b5e5fc45 100644
--- a/packages/celery-library/src/celery_library/backends/_redis.py
+++ b/packages/celery-library/src/celery_library/backends/_redis.py
@@ -7,7 +7,7 @@
 from pydantic import ValidationError
 from servicelib.celery.models import (
     Task,
-    TaskContext,
+    TaskFilter,
     TaskID,
     TaskMetadata,
     TaskUUID,
 )
@@ -82,10 +82,10 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
         )
         return None

-    async def list_tasks(self, task_context: TaskContext) -> list[Task]:
+    async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
         search_key = (
             _CELERY_TASK_INFO_PREFIX
-            + build_task_id_prefix(task_context)
+            + build_task_id_prefix(task_filter)
             + _CELERY_TASK_ID_KEY_SEPARATOR
         )
         search_key_len = len(search_key)
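The `_redis.py` change is mechanical: the search key is still built as `_CELERY_TASK_INFO_PREFIX + <filter prefix> + <separator>`, only the object the prefix is derived from changes type. Below is a minimal sketch of the prefix-scan idea, with a plain dict standing in for Redis; the constant's value and the store contents are assumptions for illustration, not taken from the hunk:

```python
from uuid import UUID, uuid4

_CELERY_TASK_INFO_PREFIX = "celery-task-info-"  # assumed value, not shown in the hunk
_SEP = ":"


def _prefix_from_filter(filter_fields: dict) -> str:
    # same rule as build_task_id_prefix: sort field names, join values with ":"
    return _SEP.join(f"{filter_fields[k]}" for k in sorted(filter_fields))


def list_task_uuids(store: dict[str, dict], filter_fields: dict) -> list[UUID]:
    search_key = _CELERY_TASK_INFO_PREFIX + _prefix_from_filter(filter_fields) + _SEP
    # the real backend scans Redis keys; here we scan dict keys the same way
    return [UUID(key[len(search_key):]) for key in store if key.startswith(search_key)]


task_uuid = uuid4()
store = {f"{_CELERY_TASK_INFO_PREFIX}osparc:42{_SEP}{task_uuid}": {"name": "export_data"}}
assert list_task_uuids(store, {"user_id": 42, "product_name": "osparc"}) == [task_uuid]
```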
diff --git a/packages/celery-library/src/celery_library/rpc/_async_jobs.py b/packages/celery-library/src/celery_library/rpc/_async_jobs.py
index 4972142a457f..ea7cb5876a5d 100644
--- a/packages/celery-library/src/celery_library/rpc/_async_jobs.py
+++ b/packages/celery-library/src/celery_library/rpc/_async_jobs.py
@@ -4,9 +4,9 @@
 from celery.exceptions import CeleryError  # type: ignore[import-untyped]
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
     AsyncJobId,
-    AsyncJobNameData,
     AsyncJobResult,
     AsyncJobStatus,
 )
@@ -16,7 +16,7 @@
     JobNotDoneError,
     JobSchedulerError,
 )
-from servicelib.celery.models import TaskState
+from servicelib.celery.models import TaskFilter, TaskState
 from servicelib.celery.task_manager import TaskManager
 from servicelib.logging_utils import log_catch
 from servicelib.rabbitmq import RPCRouter
@@ -32,13 +32,14 @@

 @router.expose(reraise_if_error_type=(JobSchedulerError,))
 async def cancel(
-    task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+    task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
 ):
     assert task_manager  # nosec
-    assert job_id_data  # nosec
+    assert job_filter  # nosec
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     try:
         await task_manager.cancel_task(
-            task_context=job_id_data.model_dump(),
+            task_filter=task_filter,
             task_uuid=job_id,
         )
     except CeleryError as exc:
@@ -47,14 +48,15 @@

 @router.expose(reraise_if_error_type=(JobSchedulerError,))
 async def status(
-    task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+    task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
 ) -> AsyncJobStatus:
     assert task_manager  # nosec
-    assert job_id_data  # nosec
+    assert job_filter  # nosec
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     try:
         task_status = await task_manager.get_task_status(
-            task_context=job_id_data.model_dump(),
+            task_filter=task_filter,
             task_uuid=job_id,
         )
     except CeleryError as exc:
@@ -76,21 +78,23 @@
     )
 )
 async def result(
-    task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+    task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
 ) -> AsyncJobResult:
     assert task_manager  # nosec
     assert job_id  # nosec
-    assert job_id_data  # nosec
+    assert job_filter  # nosec
+
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     try:
         _status = await task_manager.get_task_status(
-            task_context=job_id_data.model_dump(),
+            task_filter=task_filter,
             task_uuid=job_id,
         )
         if not _status.is_done:
             raise JobNotDoneError(job_id=job_id)
         _result = await task_manager.get_task_result(
-            task_context=job_id_data.model_dump(),
+            task_filter=task_filter,
             task_uuid=job_id,
         )
     except CeleryError as exc:
@@ -123,13 +127,14 @@

 @router.expose(reraise_if_error_type=(JobSchedulerError,))
 async def list_jobs(
-    task_manager: TaskManager, filter_: str, job_id_data: AsyncJobNameData
+    task_manager: TaskManager, filter_: str, job_filter: AsyncJobFilter
 ) -> list[AsyncJobGet]:
     _ = filter_
     assert task_manager  # nosec
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     try:
         tasks = await task_manager.list_tasks(
-            task_context=job_id_data.model_dump(),
+            task_filter=task_filter,
         )
     except CeleryError as exc:
         raise JobSchedulerError(exc=f"{exc}") from exc
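Every handler above repeats the same conversion idiom, `TaskFilter.model_validate(job_filter.model_dump())`. Shown standalone below with two local stand-in models mirroring the real ones (the real `AsyncJobFilter` lives in models-library, `TaskFilter` in servicelib): the RPC-side filter is a closed schema, the Celery-side filter an open one, so the dump/validate round-trip carries all fields across the boundary:

```python
from pydantic import BaseModel, ConfigDict


class AsyncJobFilter(BaseModel):  # stand-in mirroring the models-library schema
    model_config = ConfigDict(extra="forbid")
    product_name: str
    user_id: int
    client_name: str


class TaskFilter(BaseModel):  # stand-in mirroring servicelib.celery.models.TaskFilter
    model_config = ConfigDict(extra="allow")


job_filter = AsyncJobFilter(product_name="osparc", user_id=42, client_name="WEBSERVER")
task_filter = TaskFilter.model_validate(job_filter.model_dump())
# every field survives the round-trip as an "extra" field on the open model
assert task_filter.model_dump() == job_filter.model_dump()
```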
diff --git a/packages/celery-library/src/celery_library/task_manager.py b/packages/celery-library/src/celery_library/task_manager.py
index 7f14d4ddd34f..72ca039f6ca2 100644
--- a/packages/celery-library/src/celery_library/task_manager.py
+++ b/packages/celery-library/src/celery_library/task_manager.py
@@ -1,6 +1,6 @@
 import logging
 from dataclasses import dataclass
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4

 from celery import Celery  # type: ignore[import-untyped]
@@ -11,7 +11,7 @@
 from models_library.progress_bar import ProgressReport
 from servicelib.celery.models import (
     Task,
-    TaskContext,
+    TaskFilter,
     TaskID,
     TaskInfoStore,
     TaskMetadata,
@@ -19,6 +19,7 @@
     TaskStatus,
     TaskUUID,
 )
+from servicelib.celery.task_manager import TaskManager
 from servicelib.logging_utils import log_context
 from settings_library.celery import CelerySettings

@@ -41,16 +42,16 @@ async def submit_task(
         self,
         task_metadata: TaskMetadata,
         *,
-        task_context: TaskContext,
+        task_filter: TaskFilter,
         **task_params,
     ) -> TaskUUID:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Submit {task_metadata.name=}: {task_context=} {task_params=}",
+            msg=f"Submit {task_metadata.name=}: {task_filter=} {task_params=}",
         ):
             task_uuid = uuid4()
-            task_id = build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_filter, task_uuid)
             self._celery_app.send_task(
                 task_metadata.name,
                 task_id=task_id,
@@ -72,14 +73,14 @@ async def submit_task(
     def _abort_task(self, task_id: TaskID) -> None:
         AbortableAsyncResult(task_id, app=self._celery_app).abort()

-    async def cancel_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
+    async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"task cancellation: {task_context=} {task_uuid=}",
+            msg=f"task cancellation: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_context, task_uuid)
-            if not (await self.get_task_status(task_context, task_uuid)).is_done:
+            task_id = build_task_id(task_filter, task_uuid)
+            if not (await self.get_task_status(task_filter, task_uuid)).is_done:
                 await self._abort_task(task_id)
             await self._task_info_store.remove_task(task_id)

@@ -88,14 +89,14 @@ def _forget_task(self, task_id: TaskID) -> None:
         AbortableAsyncResult(task_id, app=self._celery_app).forget()

     async def get_task_result(
-        self, task_context: TaskContext, task_uuid: TaskUUID
+        self, task_filter: TaskFilter, task_uuid: TaskUUID
     ) -> Any:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Get task result: {task_context=} {task_uuid=}",
+            msg=f"Get task result: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_filter, task_uuid)
             async_result = self._celery_app.AsyncResult(task_id)
             result = async_result.result
             if async_result.ready():
@@ -106,10 +107,10 @@
             return result

     async def _get_task_progress_report(
-        self, task_context: TaskContext, task_uuid: TaskUUID, task_state: TaskState
+        self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState
     ) -> ProgressReport:
         if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
-            task_id = build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_filter, task_uuid)
             progress = await self._task_info_store.get_task_progress(task_id)
             if progress is not None:
                 return progress
@@ -131,33 +132,37 @@ def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
         return TaskState(self._celery_app.AsyncResult(task_id).state)

     async def get_task_status(
-        self, task_context: TaskContext, task_uuid: TaskUUID
+        self, task_filter: TaskFilter, task_uuid: TaskUUID
     ) -> TaskStatus:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Getting task status: {task_context=} {task_uuid=}",
+            msg=f"Getting task status: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_filter, task_uuid)
             task_state = await self._get_task_celery_state(task_id)
             return TaskStatus(
                 task_uuid=task_uuid,
                 task_state=task_state,
                 progress_report=await self._get_task_progress_report(
-                    task_context, task_uuid, task_state
+                    task_filter, task_uuid, task_state
                 ),
             )

-    async def list_tasks(self, task_context: TaskContext) -> list[Task]:
+    async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Listing tasks: {task_context=}",
+            msg=f"Listing tasks: {task_filter=}",
         ):
-            return await self._task_info_store.list_tasks(task_context)
+            return await self._task_info_store.list_tasks(task_filter)

     async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
         await self._task_info_store.set_task_progress(
             task_id=task_id,
             report=report,
         )
+
+
+if TYPE_CHECKING:
+    _: type[TaskManager] = CeleryTaskManager
diff --git a/packages/celery-library/src/celery_library/utils.py b/packages/celery-library/src/celery_library/utils.py
index 64da3e0c2483..79910df1c568 100644
--- a/packages/celery-library/src/celery_library/utils.py
+++ b/packages/celery-library/src/celery_library/utils.py
@@ -2,22 +2,23 @@
 from celery import Celery  # type: ignore[import-untyped]
 from servicelib.celery.app_server import BaseAppServer
-from servicelib.celery.models import TaskContext, TaskID, TaskUUID
+from servicelib.celery.models import TaskFilter, TaskID, TaskUUID

 _APP_SERVER_KEY = "app_server"

 _TASK_ID_KEY_DELIMITATOR: Final[str] = ":"


-def build_task_id_prefix(task_context: TaskContext) -> str:
+def build_task_id_prefix(task_filter: TaskFilter) -> str:
+    filter_dict = task_filter.model_dump()
     return _TASK_ID_KEY_DELIMITATOR.join(
-        [f"{task_context[key]}" for key in sorted(task_context)]
+        [f"{filter_dict[key]}" for key in sorted(filter_dict)]
     )


-def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
+def build_task_id(task_filter: TaskFilter, task_uuid: TaskUUID) -> TaskID:
     return _TASK_ID_KEY_DELIMITATOR.join(
-        [build_task_id_prefix(task_context), f"{task_uuid}"]
+        [build_task_id_prefix(task_filter), f"{task_uuid}"]
     )
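A worked example of the key layout produced by `build_task_id`: field names are sorted before joining, so the prefix is stable regardless of how the filter dict was assembled. Values are illustrative:

```python
from uuid import UUID

filter_fields = {"user_id": 42, "product_name": "osparc", "client_name": "WEBSERVER"}
task_uuid = UUID("12345678-1234-5678-1234-567812345678")

# sorted field names: client_name, product_name, user_id
prefix = ":".join(f"{filter_fields[k]}" for k in sorted(filter_fields))
task_id = f"{prefix}:{task_uuid}"
assert task_id == f"WEBSERVER:osparc:42:{task_uuid}"
```

One design caveat: a filter value that itself contains the `:` delimiter would make the id ambiguous; the `client_name` pattern introduced further down (no whitespace) does not rule that out by itself.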
diff --git a/packages/celery-library/tests/unit/test_async_jobs.py b/packages/celery-library/tests/unit/test_async_jobs.py
index 02c8362c1aae..4a646a1fdb46 100644
--- a/packages/celery-library/tests/unit/test_async_jobs.py
+++ b/packages/celery-library/tests/unit/test_async_jobs.py
@@ -16,8 +16,8 @@
 from common_library.errors_classes import OsparcErrorMixin
 from faker import Faker
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
-    AsyncJobNameData,
 )
 from models_library.api_schemas_rpc_async_jobs.exceptions import (
     JobAbortedError,
@@ -27,7 +27,7 @@
 from models_library.rabbitmq_basic_types import RPCNamespace
 from models_library.users import UserID
 from pydantic import TypeAdapter
-from servicelib.celery.models import TaskID, TaskMetadata
+from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata
 from servicelib.celery.task_manager import TaskManager
 from servicelib.rabbitmq import RabbitMQRPCClient, RPCRouter
 from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs
@@ -79,11 +79,12 @@ def product_name(faker: Faker) -> ProductName:

 @router.expose()
 async def rpc_sync_job(
-    task_manager: TaskManager, *, job_id_data: AsyncJobNameData, **kwargs: Any
+    task_manager: TaskManager, *, job_filter: AsyncJobFilter, **kwargs: Any
 ) -> AsyncJobGet:
     task_name = sync_job.__name__
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
-        TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs
+        TaskMetadata(name=task_name), task_filter=task_filter, **kwargs
     )

     return AsyncJobGet(job_id=task_uuid, job_name=task_name)
@@ -91,11 +92,12 @@

 @router.expose()
 async def rpc_async_job(
-    task_manager: TaskManager, *, job_id_data: AsyncJobNameData, **kwargs: Any
+    task_manager: TaskManager, *, job_filter: AsyncJobFilter, **kwargs: Any
 ) -> AsyncJobGet:
     task_name = async_job.__name__
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
-        TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs
+        TaskMetadata(name=task_name), task_filter=task_filter, **kwargs
     )

     return AsyncJobGet(job_id=task_uuid, job_name=task_name)
@@ -156,16 +158,18 @@ async def _start_task_via_rpc(
     user_id: UserID,
     product_name: ProductName,
     **kwargs: Any,
-) -> tuple[AsyncJobGet, AsyncJobNameData]:
-    job_id_data = AsyncJobNameData(user_id=user_id, product_name=product_name)
+) -> tuple[AsyncJobGet, AsyncJobFilter]:
+    job_filter = AsyncJobFilter(
+        user_id=user_id, product_name=product_name, client_name="pytest_client"
+    )
     async_job_get = await async_jobs.submit(
         rabbitmq_rpc_client=client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         method_name=rpc_task_name,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         **kwargs,
     )
-    return async_job_get, job_id_data
+    return async_job_get, job_filter


 @pytest.fixture
@@ -193,7 +197,7 @@ async def _wait_for_job(
     rpc_client: RabbitMQRPCClient,
     *,
     async_job_get: AsyncJobGet,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     stop_after: timedelta = timedelta(seconds=5),
 ) -> None:

@@ -208,7 +212,7 @@
             rpc_client,
             rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
             job_id=async_job_get.job_id,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
         )
         assert (
             result.done is True
@@ -242,7 +246,7 @@
     exposed_rpc_start: str,
     payload: Any,
 ):
-    async_job_get, job_id_data = await _start_task_via_rpc(
+    async_job_get, job_filter = await _start_task_via_rpc(
         async_jobs_rabbitmq_rpc_client,
         rpc_task_name=exposed_rpc_start,
         user_id=user_id,
@@ -255,21 +259,21 @@ async def test_async_jobs_workflow(
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         filter_="",  # currently not used
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )
     assert len(jobs) > 0

     await _wait_for_job(
         async_jobs_rabbitmq_rpc_client,
         async_job_get=async_job_get,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )

     async_job_result = await async_jobs.result(
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         job_id=async_job_get.job_id,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )
     assert async_job_result.result == payload

@@ -288,7 +292,7 @@ async def test_async_jobs_cancel(
     product_name: ProductName,
     exposed_rpc_start: str,
 ):
-    async_job_get, job_id_data = await _start_task_via_rpc(
+    async_job_get, job_filter = await _start_task_via_rpc(
         async_jobs_rabbitmq_rpc_client,
         rpc_task_name=exposed_rpc_start,
         user_id=user_id,
@@ -301,20 +305,20 @@
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         job_id=async_job_get.job_id,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )

     await _wait_for_job(
         async_jobs_rabbitmq_rpc_client,
         async_job_get=async_job_get,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )

     jobs = await async_jobs.list_jobs(
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         filter_="",  # currently not used
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )
     assert async_job_get.job_id not in [job.job_id for job in jobs]

@@ -323,7 +327,7 @@
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         job_id=async_job_get.job_id,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )


@@ -353,7 +357,7 @@ async def test_async_jobs_raises(
     exposed_rpc_start: str,
     error: Exception,
 ):
-    async_job_get, job_id_data = await _start_task_via_rpc(
+    async_job_get, job_filter = await _start_task_via_rpc(
         async_jobs_rabbitmq_rpc_client,
         rpc_task_name=exposed_rpc_start,
         user_id=user_id,
@@ -365,7 +369,7 @@
     await _wait_for_job(
         async_jobs_rabbitmq_rpc_client,
         async_job_get=async_job_get,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         stop_after=timedelta(minutes=1),
     )

@@ -374,7 +378,7 @@
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         job_id=async_job_get.job_id,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )
     assert exc.value.exc_type == type(error).__name__
     assert exc.value.exc_msg == f"{error}"
diff --git a/packages/celery-library/tests/unit/test_tasks.py b/packages/celery-library/tests/unit/test_tasks.py
index 4270efcc0655..a4edfb7540ae 100644
--- a/packages/celery-library/tests/unit/test_tasks.py
+++ b/packages/celery-library/tests/unit/test_tasks.py
@@ -20,7 +20,7 @@
 from common_library.errors_classes import OsparcErrorMixin
 from models_library.progress_bar import ProgressReport
 from servicelib.celery.models import (
-    TaskContext,
+    TaskFilter,
     TaskID,
     TaskMetadata,
     TaskState,
@@ -93,13 +93,13 @@ def _(celery_app: Celery) -> None:
 async def test_submitting_task_calling_async_function_results_with_success_state(
     celery_task_manager: CeleryTaskManager,
 ):
-    task_context = TaskContext(user_id=42)
+    task_filter = TaskFilter(user_id=42)

     task_uuid = await celery_task_manager.submit_task(
         TaskMetadata(
             name=fake_file_processor.__name__,
         ),
-        task_context=task_context,
+        task_filter=task_filter,
         files=[f"file{n}" for n in range(5)],
     )

@@ -109,27 +109,27 @@ async def test_submitting_task_calling_async_function_results_with_success_state
         stop=stop_after_delay(30),
     ):
         with attempt:
-            status = await celery_task_manager.get_task_status(task_context, task_uuid)
+            status = await celery_task_manager.get_task_status(task_filter, task_uuid)
             assert status.task_state == TaskState.SUCCESS

     assert (
-        await celery_task_manager.get_task_status(task_context, task_uuid)
+        await celery_task_manager.get_task_status(task_filter, task_uuid)
     ).task_state == TaskState.SUCCESS
     assert (
-        await celery_task_manager.get_task_result(task_context, task_uuid)
+        await celery_task_manager.get_task_result(task_filter, task_uuid)
     ) == "archive.zip"


 async def test_submitting_task_with_failure_results_with_error(
     celery_task_manager: CeleryTaskManager,
 ):
-    task_context = TaskContext(user_id=42)
+    task_filter = TaskFilter(user_id=42)

     task_uuid = await celery_task_manager.submit_task(
         TaskMetadata(
             name=failure_task.__name__,
         ),
-        task_context=task_context,
+        task_filter=task_filter,
     )

     for attempt in Retrying(
@@ -140,29 +140,29 @@
         with attempt:
             raw_result = await celery_task_manager.get_task_result(
-                task_context, task_uuid
+                task_filter, task_uuid
             )
             assert isinstance(raw_result, TransferrableCeleryError)

-    raw_result = await celery_task_manager.get_task_result(task_context, task_uuid)
+    raw_result = await celery_task_manager.get_task_result(task_filter, task_uuid)
     assert f"{raw_result}" == "Something strange happened: BOOM!"


 async def test_cancelling_a_running_task_aborts_and_deletes(
     celery_task_manager: CeleryTaskManager,
 ):
-    task_context = TaskContext(user_id=42)
+    task_filter = TaskFilter(user_id=42)

     task_uuid = await celery_task_manager.submit_task(
         TaskMetadata(
             name=dreamer_task.__name__,
         ),
-        task_context=task_context,
+        task_filter=task_filter,
     )

     await asyncio.sleep(3.0)

-    await celery_task_manager.cancel_task(task_context, task_uuid)
+    await celery_task_manager.cancel_task(task_filter, task_uuid)

     for attempt in Retrying(
         retry=retry_if_exception_type(AssertionError),
@@ -170,28 +170,26 @@
         stop=stop_after_delay(30),
     ):
         with attempt:
-            progress = await celery_task_manager.get_task_status(
-                task_context, task_uuid
-            )
+            progress = await celery_task_manager.get_task_status(task_filter, task_uuid)
             assert progress.task_state == TaskState.ABORTED

     assert (
-        await celery_task_manager.get_task_status(task_context, task_uuid)
+        await celery_task_manager.get_task_status(task_filter, task_uuid)
     ).task_state == TaskState.ABORTED

-    assert task_uuid not in await celery_task_manager.list_tasks(task_context)
+    assert task_uuid not in await celery_task_manager.list_tasks(task_filter)


 async def test_listing_task_uuids_contains_submitted_task(
     celery_task_manager: CeleryTaskManager,
 ):
-    task_context = TaskContext(user_id=42)
+    task_filter = TaskFilter(user_id=42)

     task_uuid = await celery_task_manager.submit_task(
         TaskMetadata(
             name=dreamer_task.__name__,
         ),
-        task_context=task_context,
+        task_filter=task_filter,
     )

     for attempt in Retrying(
@@ -200,8 +198,8 @@
         stop=stop_after_delay(10),
     ):
         with attempt:
-            tasks = await celery_task_manager.list_tasks(task_context)
+            tasks = await celery_task_manager.list_tasks(task_filter)
             assert any(task.uuid == task_uuid for task in tasks)

-    tasks = await celery_task_manager.list_tasks(task_context)
+    tasks = await celery_task_manager.list_tasks(task_filter)
     assert any(task.uuid == task_uuid for task in tasks)
diff --git a/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py b/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py
index 3b19513ca361..6c6c06093287 100644
--- a/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py
+++ b/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py
@@ -1,7 +1,7 @@
 from typing import Annotated, Any, TypeAlias
 from uuid import UUID

-from pydantic import BaseModel, StringConstraints
+from pydantic import BaseModel, ConfigDict, StringConstraints

 from ..products import ProductName
 from ..progress_bar import ProgressReport
@@ -13,6 +13,12 @@
 ]


+class AsyncJobFilterBase(BaseModel):
+    """Base class for async job filters"""
+
+    model_config = ConfigDict(extra="forbid")
+
+
 class AsyncJobStatus(BaseModel):
     job_id: AsyncJobId
     progress: ProgressReport
@@ -33,8 +39,12 @@ class AsyncJobAbort(BaseModel):
     job_id: AsyncJobId


-class AsyncJobNameData(BaseModel):
+class AsyncJobFilter(AsyncJobFilterBase):
     """Data for controlling access to an async job"""

     product_name: ProductName
     user_id: UserID
+    client_name: Annotated[
+        str,
+        StringConstraints(min_length=1, pattern=r"^[^\s]+$"),
+    ]
diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py b/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py
index 369396153efe..f9e4193e7d86 100644
--- a/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py
+++ b/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py
@@ -3,9 +3,9 @@
 from dataclasses import dataclass

 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
     AsyncJobId,
-    AsyncJobNameData,
     AsyncJobResult,
     AsyncJobStatus,
 )
@@ -28,7 +28,7 @@ async def cancel(
         *,
         rpc_namespace: RPCNamespace,
         job_id: AsyncJobId,
-        job_id_data: AsyncJobNameData,
+        job_filter: AsyncJobFilter,
     ) -> None:
         if self.exception is not None:
             raise self.exception
@@ -41,7 +41,7 @@ async def status(
         *,
         rpc_namespace: RPCNamespace,
         job_id: AsyncJobId,
-        job_id_data: AsyncJobNameData,
+        job_filter: AsyncJobFilter,
     ) -> AsyncJobStatus:
         if self.exception is not None:
             raise self.exception
@@ -63,7 +63,7 @@ async def result(
         *,
         rpc_namespace: RPCNamespace,
         job_id: AsyncJobId,
-        job_id_data: AsyncJobNameData,
+        job_filter: AsyncJobFilter,
     ) -> AsyncJobResult:
         if self.exception is not None:
             raise self.exception
@@ -75,7 +75,7 @@ async def list_jobs(
         rabbitmq_rpc_client: RabbitMQRPCClient | MockType,
         *,
         rpc_namespace: RPCNamespace,
-        job_id_data: AsyncJobNameData,
+        job_filter: AsyncJobFilter,
         filter_: str = "",
     ) -> list[AsyncJobGet]:
         if self.exception is not None:
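Quick sanity checks of what the new schema accepts and rejects: `client_name` must be non-empty with no whitespace, and `extra="forbid"` (inherited from `AsyncJobFilterBase`) closes the model. Runnable inside this repo's environment; values are illustrative:

```python
import pytest
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
from pydantic import ValidationError

AsyncJobFilter(user_id=42, product_name="osparc", client_name="WEBSERVER")  # ok

with pytest.raises(ValidationError):  # whitespace violates the client_name pattern
    AsyncJobFilter(user_id=42, product_name="osparc", client_name="web server")

with pytest.raises(ValidationError):  # unknown field rejected by extra="forbid"
    AsyncJobFilter(user_id=42, product_name="osparc", client_name="X", job_id="abc")
```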
diff --git a/packages/service-library/src/servicelib/celery/models.py b/packages/service-library/src/servicelib/celery/models.py
index 8bc744fcb3eb..407565533776 100644
--- a/packages/service-library/src/servicelib/celery/models.py
+++ b/packages/service-library/src/servicelib/celery/models.py
@@ -1,12 +1,11 @@
 import datetime
 from enum import StrEnum
-from typing import Annotated, Any, Protocol, TypeAlias
+from typing import Annotated, Protocol, TypeAlias
 from uuid import UUID

 from models_library.progress_bar import ProgressReport
-from pydantic import BaseModel, StringConstraints
+from pydantic import BaseModel, ConfigDict, StringConstraints

-TaskContext: TypeAlias = dict[str, Any]
 TaskID: TypeAlias = str
 TaskName: TypeAlias = Annotated[
     str, StringConstraints(strip_whitespace=True, min_length=1)
 ]
 TaskUUID: TypeAlias = UUID


+class TaskFilter(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+
 class TaskState(StrEnum):
     PENDING = "PENDING"
     STARTED = "STARTED"
@@ -56,7 +58,7 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: ...

     async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: ...

-    async def list_tasks(self, task_context: TaskContext) -> list[Task]: ...
+    async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ...

     async def remove_task(self, task_id: TaskID) -> None: ...
diff --git a/packages/service-library/src/servicelib/celery/task_manager.py b/packages/service-library/src/servicelib/celery/task_manager.py
index f8e178348c06..93612e6845fe 100644
--- a/packages/service-library/src/servicelib/celery/task_manager.py
+++ b/packages/service-library/src/servicelib/celery/task_manager.py
@@ -4,7 +4,7 @@
 from ..celery.models import (
     Task,
-    TaskContext,
+    TaskFilter,
     TaskID,
     TaskMetadata,
     TaskStatus,
@@ -14,22 +14,22 @@ class TaskManager(Protocol):
     async def submit_task(
-        self, task_metadata: TaskMetadata, *, task_context: TaskContext, **task_param
+        self, task_metadata: TaskMetadata, *, task_filter: TaskFilter, **task_param
     ) -> TaskUUID: ...

     async def cancel_task(
-        self, task_context: TaskContext, task_uuid: TaskUUID
+        self, task_filter: TaskFilter, task_uuid: TaskUUID
     ) -> None: ...

     async def get_task_result(
-        self, task_context: TaskContext, task_uuid: TaskUUID
+        self, task_filter: TaskFilter, task_uuid: TaskUUID
     ) -> Any: ...

     async def get_task_status(
-        self, task_context: TaskContext, task_uuid: TaskUUID
+        self, task_filter: TaskFilter, task_uuid: TaskUUID
     ) -> TaskStatus: ...

-    async def list_tasks(self, task_context: TaskContext) -> list[Task]: ...
+    async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ...

     async def set_task_progress(
         self, task_id: TaskID, report: ProgressReport
     ) -> None: ...
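`TaskFilter` stays deliberately field-less: with `extra="allow"` it keeps whatever the caller validates into it, which is what lets `build_task_id_prefix` serialize arbitrary combinations of filter fields. A two-line round-trip check, assuming the configuration shown above:

```python
from servicelib.celery.models import TaskFilter

task_filter = TaskFilter.model_validate({"product_name": "osparc", "user_id": 42})
assert task_filter.model_dump() == {"product_name": "osparc", "user_id": 42}
```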
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py
index f6e1954c9368..81cead539c9c 100644
--- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py
+++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py
@@ -6,9 +6,9 @@
 from attr import dataclass
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
     AsyncJobId,
-    AsyncJobNameData,
     AsyncJobResult,
     AsyncJobStatus,
 )
@@ -41,13 +41,13 @@ async def cancel(
     *,
     rpc_namespace: RPCNamespace,
     job_id: AsyncJobId,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ) -> None:
     await rabbitmq_rpc_client.request(
         rpc_namespace,
         TypeAdapter(RPCMethodName).validate_python("cancel"),
         job_id=job_id,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         timeout_s=_DEFAULT_TIMEOUT_S,
     )

@@ -57,13 +57,13 @@ async def status(
     *,
     rpc_namespace: RPCNamespace,
     job_id: AsyncJobId,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ) -> AsyncJobStatus:
     _result = await rabbitmq_rpc_client.request(
         rpc_namespace,
         TypeAdapter(RPCMethodName).validate_python("status"),
         job_id=job_id,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         timeout_s=_DEFAULT_TIMEOUT_S,
     )
     assert isinstance(_result, AsyncJobStatus)
@@ -75,13 +75,13 @@ async def result(
     *,
     rpc_namespace: RPCNamespace,
     job_id: AsyncJobId,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ) -> AsyncJobResult:
     _result = await rabbitmq_rpc_client.request(
         rpc_namespace,
         TypeAdapter(RPCMethodName).validate_python("result"),
         job_id=job_id,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         timeout_s=_DEFAULT_TIMEOUT_S,
     )
     assert isinstance(_result, AsyncJobResult)
@@ -93,13 +93,13 @@ async def list_jobs(
     *,
     rpc_namespace: RPCNamespace,
     filter_: str,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ) -> list[AsyncJobGet]:
     _result: list[AsyncJobGet] = await rabbitmq_rpc_client.request(
         rpc_namespace,
         TypeAdapter(RPCMethodName).validate_python("list_jobs"),
         filter_=filter_,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         timeout_s=_DEFAULT_TIMEOUT_S,
     )
     return _result
@@ -110,13 +110,13 @@ async def submit(
     *,
     rpc_namespace: RPCNamespace,
     method_name: str,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     **kwargs,
 ) -> AsyncJobGet:
     _result = await rabbitmq_rpc_client.request(
         rpc_namespace,
         TypeAdapter(RPCMethodName).validate_python(method_name),
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         **kwargs,
         timeout_s=_DEFAULT_TIMEOUT_S,
     )
@@ -140,7 +140,7 @@ async def _wait_for_completion(
     rpc_namespace: RPCNamespace,
     method_name: RPCMethodName,
     job_id: AsyncJobId,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     client_timeout: datetime.timedelta,
 ) -> AsyncGenerator[AsyncJobStatus, None]:
     try:
@@ -156,7 +156,7 @@
                 rabbitmq_rpc_client,
                 rpc_namespace=rpc_namespace,
                 job_id=job_id,
-                job_id_data=job_id_data,
+                job_filter=job_filter,
             )
             yield job_status
             if not job_status.done:
@@ -191,7 +191,7 @@ async def wait_and_get_result(
     rpc_namespace: RPCNamespace,
     method_name: str,
     job_id: AsyncJobId,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     client_timeout: datetime.timedelta,
 ) -> AsyncGenerator[AsyncJobComposedResult, None]:
     """when a job is already submitted this will wait for its completion
@@ -203,7 +203,7 @@
             rpc_namespace=rpc_namespace,
             method_name=method_name,
             job_id=job_id,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
             client_timeout=client_timeout,
         ):
             assert job_status is not None  # nosec
@@ -217,7 +217,7 @@
                     rabbitmq_rpc_client,
                     rpc_namespace=rpc_namespace,
                     job_id=job_id,
-                    job_id_data=job_id_data,
+                    job_filter=job_filter,
                 ),
             )
     except (TimeoutError, CancelledError) as error:
@@ -226,7 +226,7 @@
                 rabbitmq_rpc_client,
                 rpc_namespace=rpc_namespace,
                 job_id=job_id,
-                job_id_data=job_id_data,
+                job_filter=job_filter,
             )
         except Exception as exc:
             raise exc from error  # NOSONAR
@@ -238,7 +238,7 @@ async def submit_and_wait(
     *,
     rpc_namespace: RPCNamespace,
    method_name: str,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     client_timeout: datetime.timedelta,
     **kwargs,
 ) -> AsyncGenerator[AsyncJobComposedResult, None]:
@@ -248,7 +248,7 @@
             rabbitmq_rpc_client,
             rpc_namespace=rpc_namespace,
             method_name=method_name,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
             **kwargs,
         )
     except (TimeoutError, CancelledError) as error:
@@ -258,7 +258,7 @@
                 rabbitmq_rpc_client,
                 rpc_namespace=rpc_namespace,
                 job_id=async_job_rpc_get.job_id,
-                job_id_data=job_id_data,
+                job_filter=job_filter,
             )
         except Exception as exc:
             raise exc from error
@@ -269,7 +269,7 @@
         rpc_namespace=rpc_namespace,
         method_name=method_name,
         job_id=async_job_rpc_get.job_id,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         client_timeout=client_timeout,
     ):
         yield wait_and_
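For reference, how a caller is expected to drive `submit_and_wait` after this rename. A sketch, with placeholder namespace and method names; the `.done`/`.status`/`.result()` accessors on the composed result are the ones exercised by the tests further down:

```python
import datetime

from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import submit_and_wait


async def run_job(rpc_client, rpc_namespace) -> None:
    job_filter = AsyncJobFilter(user_id=42, product_name="osparc", client_name="EXAMPLE")
    async for composed in submit_and_wait(
        rpc_client,
        rpc_namespace=rpc_namespace,
        method_name="compute_path_size",  # placeholder: any exposed submit method
        job_filter=job_filter,
        client_timeout=datetime.timedelta(minutes=5),
    ):
        if composed.done:
            print(await composed.result())  # AsyncJobResult of the finished job
        else:
            print(composed.status.progress)  # intermediate progress reports
```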
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/_utils.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/_utils.py
new file mode 100644
index 000000000000..6330d16cd065
--- /dev/null
+++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/_utils.py
@@ -0,0 +1,15 @@
+from typing import Final
+
+from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
+from models_library.products import ProductName
+from models_library.users import UserID
+
+ASYNC_JOB_CLIENT_NAME: Final[str] = "STORAGE"
+
+
+def get_async_job_filter(user_id: UserID, product_name: ProductName) -> AsyncJobFilter:
+    return AsyncJobFilter(
+        user_id=user_id,
+        product_name=product_name,
+        client_name=ASYNC_JOB_CLIENT_NAME,
+    )
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py
index c1049bfc1bbb..3e9f30c5c5fb 100644
--- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py
+++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py
@@ -1,17 +1,19 @@
 from pathlib import Path

 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
-    AsyncJobNameData,
 )
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
 from models_library.products import ProductName
 from models_library.projects_nodes_io import LocationID
 from models_library.rabbitmq_basic_types import RPCMethodName
 from models_library.users import UserID
+from pydantic import TypeAdapter

 from ..._client_rpc import RabbitMQRPCClient
 from ..async_jobs.async_jobs import submit
+from ._utils import get_async_job_filter


 async def compute_path_size(
@@ -21,17 +23,17 @@ async def compute_path_size(
     product_name: ProductName,
     location_id: LocationID,
     path: Path,
-) -> tuple[AsyncJobGet, AsyncJobNameData]:
-    job_id_data = AsyncJobNameData(user_id=user_id, product_name=product_name)
+) -> tuple[AsyncJobGet, AsyncJobFilter]:
+    job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
         rabbitmq_rpc_client=client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
-        method_name=RPCMethodName("compute_path_size"),
-        job_id_data=job_id_data,
+        method_name=TypeAdapter(RPCMethodName).validate_python("compute_path_size"),
+        job_filter=job_filter,
         location_id=location_id,
         path=path,
     )
-    return async_job_rpc_get, job_id_data
+    return async_job_rpc_get, job_filter


 async def delete_paths(
@@ -41,14 +43,14 @@ async def delete_paths(
     product_name: ProductName,
     location_id: LocationID,
     paths: set[Path],
-) -> tuple[AsyncJobGet, AsyncJobNameData]:
-    job_id_data = AsyncJobNameData(user_id=user_id, product_name=product_name)
+) -> tuple[AsyncJobGet, AsyncJobFilter]:
+    job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
         rabbitmq_rpc_client=client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
-        method_name=RPCMethodName("delete_paths"),
-        job_id_data=job_id_data,
+        method_name=TypeAdapter(RPCMethodName).validate_python("delete_paths"),
+        job_filter=job_filter,
         location_id=location_id,
         paths=paths,
     )
-    return async_job_rpc_get, job_id_data
+    return async_job_rpc_get, job_filter
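The `paths.py` hunks also swap `RPCMethodName("...")` for `TypeAdapter(RPCMethodName).validate_python("...")`. Assuming `RPCMethodName` is an `Annotated` constrained string like the other models-library string types, calling the alias directly just builds a plain `str` and applies no constraints; only a validating entry point such as `TypeAdapter` enforces them. A stand-in alias makes the difference visible:

```python
from typing import Annotated

from pydantic import StringConstraints, TypeAdapter, ValidationError

RPCMethodName = Annotated[str, StringConstraints(min_length=1)]  # stand-in alias

assert RPCMethodName("") == ""  # direct call: no validation happens
try:
    TypeAdapter(RPCMethodName).validate_python("")  # enforced: min_length=1
except ValidationError:
    print("rejected as expected")
```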
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py
index df78448a5752..4b1a9cf18fce 100644
--- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py
+++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py
@@ -1,6 +1,6 @@
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
-    AsyncJobNameData,
 )
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
 from models_library.api_schemas_storage.storage_schemas import FoldersBody
@@ -12,6 +12,7 @@
 from ... import RabbitMQRPCClient
 from ..async_jobs.async_jobs import submit
+from ._utils import get_async_job_filter


 async def copy_folders_from_project(
@@ -20,16 +21,18 @@ async def copy_folders_from_project(
     user_id: UserID,
     product_name: ProductName,
     body: FoldersBody,
-) -> tuple[AsyncJobGet, AsyncJobNameData]:
-    job_id_data = AsyncJobNameData(user_id=user_id, product_name=product_name)
+) -> tuple[AsyncJobGet, AsyncJobFilter]:
+    job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
         rabbitmq_rpc_client=client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
-        method_name=RPCMethodName("copy_folders_from_project"),
-        job_id_data=job_id_data,
+        method_name=TypeAdapter(RPCMethodName).validate_python(
+            "copy_folders_from_project"
+        ),
+        job_filter=job_filter,
         body=body,
     )
-    return async_job_rpc_get, job_id_data
+    return async_job_rpc_get, job_filter


 async def start_export_data(
@@ -38,13 +41,13 @@ async def start_export_data(
     user_id: UserID,
     product_name: ProductName,
     paths_to_export: list[PathToExport],
-) -> tuple[AsyncJobGet, AsyncJobNameData]:
-    job_id_data = AsyncJobNameData(user_id=user_id, product_name=product_name)
+) -> tuple[AsyncJobGet, AsyncJobFilter]:
+    job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
         rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         method_name=TypeAdapter(RPCMethodName).validate_python("start_export_data"),
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         paths_to_export=paths_to_export,
     )
-    return async_job_rpc_get, job_id_data
+    return async_job_rpc_get, job_filter
diff --git a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py
index 40455ee6d7f2..a8d3ac3b6f7e 100644
--- a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py
+++ b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py
@@ -2,14 +2,15 @@
 import datetime
 from collections.abc import AsyncIterator
 from dataclasses import dataclass, field
+from typing import Final

 import pytest
 from common_library.async_tools import cancel_wait_task
 from faker import Faker
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
     AsyncJobId,
-    AsyncJobNameData,
     AsyncJobResult,
     AsyncJobStatus,
 )
@@ -28,6 +29,8 @@
     "rabbit",
 ]

+_ASYNC_JOB_CLIENT_NAME: Final[str] = "PYTEST_CLIENT_NAME"
+

 @pytest.fixture
 def method_name(faker: Faker) -> RPCMethodName:
@@ -35,10 +38,11 @@

 @pytest.fixture
-def job_id_data(faker: Faker) -> AsyncJobNameData:
-    return AsyncJobNameData(
+def job_filter(faker: Faker) -> AsyncJobFilter:
+    return AsyncJobFilter(
         user_id=faker.pyint(min_value=1),
         product_name=faker.word(),
+        client_name=_ASYNC_JOB_CLIENT_NAME,
     )

@@ -68,9 +72,9 @@ def _get_task(self, job_id: AsyncJobId) -> asyncio.Task:
         raise JobMissingError(job_id=f"{job_id}")

     async def status(
-        self, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+        self, job_id: AsyncJobId, job_filter: AsyncJobFilter
     ) -> AsyncJobStatus:
-        assert job_id_data
+        assert job_filter
         task = self._get_task(job_id)
         return AsyncJobStatus(
             job_id=job_id,
@@ -78,32 +82,30 @@
             done=task.done(),
         )

-    async def cancel(
-        self, job_id: AsyncJobId, job_id_data: AsyncJobNameData
-    ) -> None:
+    async def cancel(self, job_id: AsyncJobId, job_filter: AsyncJobFilter) -> None:
         assert job_id
-        assert job_id_data
+        assert job_filter
         task = self._get_task(job_id)
         task.cancel()

     async def result(
-        self, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+        self, job_id: AsyncJobId, job_filter: AsyncJobFilter
     ) -> AsyncJobResult:
-        assert job_id_data
+        assert job_filter
         task = self._get_task(job_id)
         assert task.done()
         return AsyncJobResult(
             result={
                 "data": task.result(),
                 "job_id": job_id,
-                "job_id_data": job_id_data,
+                "job_filter": job_filter,
             }
         )

     async def list_jobs(
-        self, filter_: str, job_id_data: AsyncJobNameData
+        self, filter_: str, job_filter: AsyncJobFilter
     ) -> list[AsyncJobGet]:
-        assert job_id_data
+        assert job_filter
         assert filter_ is not None

         return [
@@ -114,8 +116,8 @@
             for t in self.tasks
         ]

-    async def submit(self, job_id_data: AsyncJobNameData) -> AsyncJobGet:
-        assert job_id_data
+    async def submit(self, job_filter: AsyncJobFilter) -> AsyncJobGet:
+        assert job_filter
         job_id = faker.uuid4(cast_to=None)
         self.tasks.append(asyncio.create_task(_slow_task(), name=f"{job_id}"))
         return AsyncJobGet(job_id=job_id, job_name="fake_job_name")
@@ -145,7 +147,7 @@ async def test_async_jobs_methods(
     async_job_rpc_server: RabbitMQRPCClient,
     rpc_client: RabbitMQRPCClient,
     namespace: RPCNamespace,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     job_id: AsyncJobId,
     method: str,
 ):
@@ -157,7 +159,7 @@
             rpc_client,
             rpc_namespace=namespace,
             job_id=job_id,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
         )


@@ -166,13 +168,13 @@ async def test_list_jobs(
     rpc_client: RabbitMQRPCClient,
     namespace: RPCNamespace,
     method_name: RPCMethodName,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ):
     await list_jobs(
         rpc_client,
         rpc_namespace=namespace,
         filter_="",
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )


@@ -181,13 +183,13 @@ async def test_submit(
     rpc_client: RabbitMQRPCClient,
     namespace: RPCNamespace,
     method_name: RPCMethodName,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ):
     await submit(
         rpc_client,
         rpc_namespace=namespace,
         method_name=method_name,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
     )


@@ -195,14 +197,14 @@ async def test_submit_with_invalid_method_name(
     async_job_rpc_server: RabbitMQRPCClient,
     rpc_client: RabbitMQRPCClient,
     namespace: RPCNamespace,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ):
     with pytest.raises(RemoteMethodNotRegisteredError):
         await submit(
             rpc_client,
             rpc_namespace=namespace,
             method_name=RPCMethodName("invalid_method_name"),
-            job_id_data=job_id_data,
+            job_filter=job_filter,
         )


@@ -211,14 +213,14 @@ async def test_submit_and_wait_properly_timesout(
     rpc_client: RabbitMQRPCClient,
     namespace: RPCNamespace,
     method_name: RPCMethodName,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ):
     with pytest.raises(TimeoutError):  # noqa: PT012
         async for _job_composed_result in submit_and_wait(
             rpc_client,
             rpc_namespace=namespace,
             method_name=method_name,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
             client_timeout=datetime.timedelta(seconds=0.1),
         ):
             pass


@@ -229,13 +231,13 @@ async def test_submit_and_wait(
     rpc_client: RabbitMQRPCClient,
     namespace: RPCNamespace,
     method_name: RPCMethodName,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
 ):
     async for job_composed_result in submit_and_wait(
         rpc_client,
         rpc_namespace=namespace,
         method_name=method_name,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         client_timeout=datetime.timedelta(seconds=10),
     ):
         if not job_composed_result.done:
@@ -243,10 +245,11 @@
             await job_composed_result.result()
     assert job_composed_result.done
     assert job_composed_result.status.progress.actual_value == 1
-    assert await job_composed_result.result() == AsyncJobResult(
+    result = await job_composed_result.result()
+    assert result == AsyncJobResult(
         result={
             "data": None,
             "job_id": job_composed_result.status.job_id,
-            "job_id_data": job_id_data,
+            "job_filter": job_filter,
         }
     )
diff --git a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py
index 2c3f6b2f66cd..ad80cb5daf46 100644
--- a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py
+++ b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Annotated, Any
+from typing import Annotated, Any, Final

 from fastapi import APIRouter, Depends, FastAPI, status
 from models_library.api_schemas_long_running_tasks.base import TaskProgress
@@ -9,8 +9,8 @@
     TaskStatus,
 )
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobId,
-    AsyncJobNameData,
 )
 from models_library.products import ProductName
 from models_library.users import UserID
@@ -26,12 +26,16 @@
     create_route_description,
 )

+_ASYNC_JOB_CLIENT_NAME: Final[str] = "API_SERVER"
+
 router = APIRouter()
 _logger = logging.getLogger(__name__)


-def _get_job_id_data(user_id: UserID, product_name: ProductName) -> AsyncJobNameData:
-    return AsyncJobNameData(user_id=user_id, product_name=product_name)
+def _get_job_filter(user_id: UserID, product_name: ProductName) -> AsyncJobFilter:
+    return AsyncJobFilter(
+        user_id=user_id, product_name=product_name, client_name=_ASYNC_JOB_CLIENT_NAME
+    )


 _DEFAULT_TASK_STATUS_CODES: dict[int | str, dict[str, Any]] = {
@@ -61,7 +65,7 @@ async def list_tasks(
     async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)],
 ):
     user_async_jobs = await async_jobs.list_jobs(
-        job_id_data=_get_job_id_data(user_id, product_name),
+        job_filter=_get_job_filter(user_id, product_name),
         filter_="",
     )
     app_router = app.router
@@ -102,7 +106,7 @@ async def get_task_status(
 ):
     async_job_rpc_status = await async_jobs.status(
         job_id=task_id,
-        job_id_data=_get_job_id_data(user_id, product_name),
+        job_filter=_get_job_filter(user_id, product_name),
     )
     _task_id = f"{async_job_rpc_status.job_id}"
     return TaskStatus(
@@ -134,7 +138,7 @@ async def cancel_task(
 ):
     await async_jobs.cancel(
         job_id=task_id,
-        job_id_data=_get_job_id_data(user_id, product_name),
+        job_filter=_get_job_filter(user_id, product_name),
     )


@@ -168,6 +172,6 @@ async def get_task_result(
 ):
     async_job_rpc_result = await async_jobs.result(
         job_id=task_id,
-        job_id_data=_get_job_id_data(user_id, product_name),
+        job_filter=_get_job_filter(user_id, product_name),
     )
     return TaskResult(result=async_job_rpc_result.result, error=None)
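What the new `client_name` field buys: two services acting for the same user and product now produce distinct task-id prefixes, so one client's `list_jobs` cannot pick up the other's tasks. Illustrative values:

```python
def prefix(fields: dict) -> str:
    # same sorted-join rule as build_task_id_prefix
    return ":".join(f"{fields[k]}" for k in sorted(fields))


api_server = {"user_id": 42, "product_name": "osparc", "client_name": "API_SERVER"}
webserver = {"user_id": 42, "product_name": "osparc", "client_name": "WEBSERVER"}

assert prefix(api_server) != prefix(webserver)
print(prefix(api_server))  # API_SERVER:osparc:42
print(prefix(webserver))   # WEBSERVER:osparc:42
```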
diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py b/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py
index 9e263d755057..429a67ef1619 100644
--- a/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py
+++ b/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py
@@ -2,9 +2,9 @@
 from dataclasses import dataclass

 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
     AsyncJobId,
-    AsyncJobNameData,
     AsyncJobResult,
     AsyncJobStatus,
 )
@@ -40,14 +40,12 @@ class AsyncJobClient:
             JobSchedulerError: TaskSchedulerError,
         }
     )
-    async def cancel(
-        self, *, job_id: AsyncJobId, job_id_data: AsyncJobNameData
-    ) -> None:
+    async def cancel(self, *, job_id: AsyncJobId, job_filter: AsyncJobFilter) -> None:
         return await async_jobs.cancel(
             self._rabbitmq_rpc_client,
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             job_id=job_id,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
         )

     @_exception_mapper(
@@ -56,13 +54,13 @@ async def cancel(
         }
     )
     async def status(
-        self, *, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+        self, *, job_id: AsyncJobId, job_filter: AsyncJobFilter
     ) -> AsyncJobStatus:
         return await async_jobs.status(
             self._rabbitmq_rpc_client,
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             job_id=job_id,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
         )

     @_exception_mapper(
@@ -74,13 +72,13 @@ async def status(
         }
     )
     async def result(
-        self, *, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+        self, *, job_id: AsyncJobId, job_filter: AsyncJobFilter
     ) -> AsyncJobResult:
         return await async_jobs.result(
             self._rabbitmq_rpc_client,
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             job_id=job_id,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
         )

     @_exception_mapper(
@@ -89,11 +87,11 @@ async def result(
         }
     )
     async def list_jobs(
-        self, *, filter_: str, job_id_data: AsyncJobNameData
+        self, *, filter_: str, job_filter: AsyncJobFilter
     ) -> list[AsyncJobGet]:
         return await async_jobs.list_jobs(
             self._rabbitmq_rpc_client,
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             filter_=filter_,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
         )
diff --git a/services/storage/src/simcore_service_storage/api/rest/_files.py b/services/storage/src/simcore_service_storage/api/rest/_files.py
index 9092d8da0115..9981570fc5fa 100644
--- a/services/storage/src/simcore_service_storage/api/rest/_files.py
+++ b/services/storage/src/simcore_service_storage/api/rest/_files.py
@@ -3,7 +3,7 @@
 from urllib.parse import quote

 from fastapi import APIRouter, Depends, Header, Request
-from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobNameData
+from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
 from models_library.api_schemas_storage.storage_schemas import (
     FileMetaDataGet,
     FileMetaDataGetv010,
@@ -20,7 +20,7 @@
 from models_library.projects_nodes_io import LocationID, StorageFileID
 from pydantic import AnyUrl, ByteSize, TypeAdapter
 from servicelib.aiohttp import status
-from servicelib.celery.models import TaskMetadata, TaskUUID
+from servicelib.celery.models import TaskFilter, TaskMetadata, TaskUUID
 from servicelib.celery.task_manager import TaskManager
 from servicelib.logging_utils import log_context
 from yarl import URL
@@ -41,6 +41,9 @@
 from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file
 from .dependencies.celery import get_task_manager

+_ASYNC_JOB_CLIENT_NAME: Final[str] = "STORAGE"
+
+
 _logger = logging.getLogger(__name__)

 router = APIRouter(
@@ -284,16 +287,18 @@ async def complete_upload_file(
     # NOTE: completing a multipart upload on AWS can take up to several minutes
     # if it returns slow we return a 202 - Accepted, the client will have to check later
     # for completeness
-    async_job_name_data = AsyncJobNameData(
+    job_filter = AsyncJobFilter(
         user_id=query_params.user_id,
         product_name=_UNDEFINED_PRODUCT_NAME_FOR_WORKER_TASKS,  # NOTE: I would need to change the API here
+        client_name=_ASYNC_JOB_CLIENT_NAME,
     )
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
         TaskMetadata(
             name=remote_complete_upload_file.__name__,
         ),
-        task_context=async_job_name_data.model_dump(),
-        user_id=async_job_name_data.user_id,
+        task_filter=task_filter,
+        user_id=job_filter.user_id,
         location_id=location_id,
         file_id=file_id,
         body=body,
@@ -340,18 +345,20 @@ async def is_completed_upload_file(
     # therefore we wait a bit to see if it completes fast and return a 204
     # if it returns slow we return a 202 - Accepted, the client will have to check later
     # for completeness
-    async_job_name_data = AsyncJobNameData(
+    job_filter = AsyncJobFilter(
         user_id=query_params.user_id,
         product_name=_UNDEFINED_PRODUCT_NAME_FOR_WORKER_TASKS,  # NOTE: I would need to change the API here
+        client_name=_ASYNC_JOB_CLIENT_NAME,
     )
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_status = await task_manager.get_task_status(
-        task_context=async_job_name_data.model_dump(), task_uuid=TaskUUID(future_id)
+        task_filter=task_filter, task_uuid=TaskUUID(future_id)
     )
     # first check if the task is in the app
     if task_status.is_done:
         task_result = TypeAdapter(FileMetaData).validate_python(
             await task_manager.get_task_result(
-                task_context=async_job_name_data.model_dump(),
+                task_filter=task_filter,
                 task_uuid=TaskUUID(future_id),
             )
         )
diff --git a/services/storage/src/simcore_service_storage/api/rpc/_paths.py b/services/storage/src/simcore_service_storage/api/rpc/_paths.py
index f4b0eae297db..b34da2e7e7f8 100644
--- a/services/storage/src/simcore_service_storage/api/rpc/_paths.py
+++ b/services/storage/src/simcore_service_storage/api/rpc/_paths.py
@@ -2,11 +2,11 @@
 from pathlib import Path

 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
-    AsyncJobNameData,
 )
 from models_library.projects_nodes_io import LocationID
-from servicelib.celery.models import TaskMetadata
+from servicelib.celery.models import TaskFilter, TaskMetadata
 from servicelib.celery.task_manager import TaskManager
 from servicelib.rabbitmq import RPCRouter

@@ -20,17 +20,18 @@
 @router.expose(reraise_if_error_type=None)
 async def compute_path_size(
     task_manager: TaskManager,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     location_id: LocationID,
     path: Path,
 ) -> AsyncJobGet:
     task_name = remote_compute_path_size.__name__
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
         task_metadata=TaskMetadata(
             name=task_name,
         ),
-        task_context=job_id_data.model_dump(),
-        user_id=job_id_data.user_id,
+        task_filter=task_filter,
+        user_id=job_filter.user_id,
         location_id=location_id,
         path=path,
     )
@@ -41,17 +42,18 @@

 @router.expose(reraise_if_error_type=None)
 async def delete_paths(
     task_manager: TaskManager,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     location_id: LocationID,
     paths: set[Path],
 ) -> AsyncJobGet:
     task_name = remote_delete_paths.__name__
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
         task_metadata=TaskMetadata(
             name=task_name,
         ),
-        task_context=job_id_data.model_dump(),
-        user_id=job_id_data.user_id,
+        task_filter=task_filter,
+        user_id=job_filter.user_id,
         location_id=location_id,
         paths=paths,
     )
diff --git a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py
index bd144179cd23..6c9b77c5b99e 100644
--- a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py
+++ b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py
@@ -1,10 +1,10 @@
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
-    AsyncJobNameData,
 )
 from models_library.api_schemas_storage.storage_schemas import FoldersBody
 from models_library.api_schemas_webserver.storage import PathToExport
-from servicelib.celery.models import TaskMetadata, TasksQueue
+from servicelib.celery.models import TaskFilter, TaskMetadata, TasksQueue
 from servicelib.celery.task_manager import TaskManager
 from servicelib.rabbitmq import RPCRouter

@@ -16,16 +16,17 @@
 @router.expose(reraise_if_error_type=None)
 async def copy_folders_from_project(
     task_manager: TaskManager,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     body: FoldersBody,
 ) -> AsyncJobGet:
     task_name = deep_copy_files_from_project.__name__
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
         task_metadata=TaskMetadata(
             name=task_name,
         ),
-        task_context=job_id_data.model_dump(),
-        user_id=job_id_data.user_id,
+        task_filter=task_filter,
+        user_id=job_filter.user_id,
         body=body,
     )

@@ -35,18 +36,19 @@

 @router.expose()
 async def start_export_data(
     task_manager: TaskManager,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     paths_to_export: list[PathToExport],
 ) -> AsyncJobGet:
     task_name = export_data.__name__
+    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
         task_metadata=TaskMetadata(
             name=task_name,
             ephemeral=False,
             queue=TasksQueue.CPU_BOUND,
         ),
-        task_context=job_id_data.model_dump(),
-        user_id=job_id_data.user_id,
+        task_filter=task_filter,
+        user_id=job_filter.user_id,
         paths_to_export=paths_to_export,
     )
     return AsyncJobGet(job_id=task_uuid, job_name=task_name)
diff --git a/services/storage/tests/unit/test_rpc_handlers_paths.py b/services/storage/tests/unit/test_rpc_handlers_paths.py
index cd5db4140372..04cbb47692cf 100644
--- a/services/storage/tests/unit/test_rpc_handlers_paths.py
+++ b/services/storage/tests/unit/test_rpc_handlers_paths.py
@@ -17,7 +17,7 @@
 from faker import Faker
 from fastapi import FastAPI
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
-    AsyncJobNameData,
+    AsyncJobFilter,
     AsyncJobResult,
 )
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
@@ -81,7 +81,9 @@ async def _assert_compute_path_size(
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         method_name=RPCMethodName(compute_path_size.__name__),
         job_id=async_job.job_id,
-        job_id_data=AsyncJobNameData(user_id=user_id, product_name=product_name),
+        job_filter=AsyncJobFilter(
+            user_id=user_id, product_name=product_name, client_name="PYTEST_CLIENT_NAME"
+        ),
         client_timeout=datetime.timedelta(seconds=120),
     ):
         if job_composed_result.done:
@@ -115,7 +117,9 @@ async def _assert_delete_paths(
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         method_name=RPCMethodName(compute_path_size.__name__),
         job_id=async_job.job_id,
-        job_id_data=AsyncJobNameData(user_id=user_id, product_name=product_name),
+        job_filter=AsyncJobFilter(
+            user_id=user_id, product_name=product_name, client_name="PYTEST_CLIENT_NAME"
+        ),
         client_timeout=datetime.timedelta(seconds=120),
     ):
         if job_composed_result.done:
diff --git a/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py b/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py
index 7f0d87667f64..c3f8411ebea1 100644
--- a/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py
+++ b/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py
@@ -94,7 +94,7 @@ async def _request_copy_folders(
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         method_name=copy_folders_from_project.__name__,
         job_id=async_job_get.job_id,
-        job_id_data=async_job_name,
+        job_filter=async_job_name,
         client_timeout=client_timeout,
     ):
         ctx.logger.info("%s", f"<-- current state is {async_job_result=}")
@@ -533,7 +533,7 @@ async def _request_start_export_data(
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         method_name=start_export_data.__name__,
         job_id=async_job_get.job_id,
-        job_id_data=async_job_name,
+        job_filter=async_job_name,
         client_timeout=client_timeout,
     ):
         ctx.logger.info("%s", f"<-- current state is {async_job_result=}")
diff --git a/services/web/server/src/simcore_service_webserver/constants.py b/services/web/server/src/simcore_service_webserver/constants.py
index dbf03900a064..42e3a4f1102b 100644
--- a/services/web/server/src/simcore_service_webserver/constants.py
+++ b/services/web/server/src/simcore_service_webserver/constants.py
@@ -51,6 +51,9 @@
     "Please try again shortly. If the issue persists, contact support.", _version=1
 )

+ASYNC_JOB_CLIENT_NAME: Final[str] = "WEBSERVER"
+
+
 __all__: tuple[str, ...] = (
     "APP_AIOPG_ENGINE_KEY",
     "APP_CONFIG_KEY",
diff --git a/services/web/server/src/simcore_service_webserver/storage/api.py b/services/web/server/src/simcore_service_webserver/storage/api.py
index 868dd63ad935..48b5877a0e8f 100644
--- a/services/web/server/src/simcore_service_webserver/storage/api.py
+++ b/services/web/server/src/simcore_service_webserver/storage/api.py
@@ -7,7 +7,7 @@
 from typing import Any, Final

 from aiohttp import ClientError, ClientSession, ClientTimeout, web
-from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobNameData
+from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
 from models_library.api_schemas_storage.storage_schemas import (
     FileLocation,
@@ -30,6 +30,7 @@
 )
 from yarl import URL

+from ..constants import ASYNC_JOB_CLIENT_NAME
 from ..projects.models import ProjectDict
 from ..projects.utils import NodesMap
 from ..rabbitmq import get_rabbitmq_rpc_client
@@ -37,6 +38,7 @@

 _logger = logging.getLogger(__name__)

+
 _TOTAL_TIMEOUT_TO_COPY_DATA_SECS: Final[int] = 60 * 60

 _SIMCORE_LOCATION: Final[LocationID] = 0
@@ -117,7 +119,11 @@ async def copy_data_folders_from_project(
         rabbitmq_client,
         method_name="copy_folders_from_project",
         rpc_namespace=STORAGE_RPC_NAMESPACE,
-        job_id_data=AsyncJobNameData(user_id=user_id, product_name=product_name),
+        job_filter=AsyncJobFilter(
+            user_id=user_id,
+            product_name=product_name,
+            client_name=ASYNC_JOB_CLIENT_NAME,
+        ),
         body=TypeAdapter(FoldersBody).validate_python(
             {
                 "source": source_project,
diff --git a/services/web/server/src/simcore_service_webserver/tasks/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_rest.py
index 9b6ce8411d83..13176af833b9 100644
--- a/services/web/server/src/simcore_service_webserver/tasks/_rest.py
+++ b/services/web/server/src/simcore_service_webserver/tasks/_rest.py
@@ -15,8 +15,8 @@
     TaskStatus,
 )
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobId,
-    AsyncJobNameData,
 )
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
 from pydantic import BaseModel
@@ -32,6 +32,7 @@
 from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs

 from .._meta import API_VTAG
+from ..constants import ASYNC_JOB_CLIENT_NAME
 from ..login.decorators import login_required
 from ..long_running_tasks import webserver_request_context_decorator
 from ..models import AuthenticatedRequestContext
@@ -69,8 +70,10 @@ async def get_async_jobs(request: web.Request) -> web.Response:
     user_async_jobs = await async_jobs.list_jobs(
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
-        job_id_data=AsyncJobNameData(
-            user_id=_req_ctx.user_id, product_name=_req_ctx.product_name
+        job_filter=AsyncJobFilter(
+            user_id=_req_ctx.user_id,
+            product_name=_req_ctx.product_name,
+            client_name=ASYNC_JOB_CLIENT_NAME,
         ),
         filter_="",
     )
@@ -119,8 +122,10 @@ async def get_async_job_status(request: web.Request) -> web.Response:
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         job_id=async_job_get.task_id,
-        job_id_data=AsyncJobNameData(
-            user_id=_req_ctx.user_id, product_name=_req_ctx.product_name
+        job_filter=AsyncJobFilter(
+            user_id=_req_ctx.user_id,
+            product_name=_req_ctx.product_name,
+            client_name=ASYNC_JOB_CLIENT_NAME,
         ),
     )
     _task_id = f"{async_job_rpc_status.job_id}"
@@ -154,8 +159,10 @@ async def cancel_async_job(request: web.Request) -> web.Response:
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         job_id=async_job_get.task_id,
-        job_id_data=AsyncJobNameData(
-            user_id=_req_ctx.user_id, product_name=_req_ctx.product_name
+        job_filter=AsyncJobFilter(
+            user_id=_req_ctx.user_id,
+            product_name=_req_ctx.product_name,
+            client_name=ASYNC_JOB_CLIENT_NAME,
         ),
     )

@@ -181,8 +188,10 @@ class _PathParams(BaseModel):
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         job_id=async_job_get.task_id,
-        job_id_data=AsyncJobNameData(
-            user_id=_req_ctx.user_id, product_name=_req_ctx.product_name
+        job_filter=AsyncJobFilter(
+            user_id=_req_ctx.user_id,
+            product_name=_req_ctx.product_name,
+            client_name=ASYNC_JOB_CLIENT_NAME,
         ),
     )
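Taken together, the renamed pieces compose like this. A condensed end-to-end sketch using the helpers from the diff; values are illustrative, and it assumes the packages are importable as in this repo:

```python
from uuid import uuid4

from celery_library.utils import build_task_id
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
from servicelib.celery.models import TaskFilter

# 1. a client (webserver, api-server, ...) builds the closed RPC-side filter
job_filter = AsyncJobFilter(user_id=42, product_name="osparc", client_name="WEBSERVER")

# 2. the RPC handler widens it into the open Celery-side filter
task_filter = TaskFilter.model_validate(job_filter.model_dump())

# 3. the task manager derives the storage key from filter + fresh uuid
task_id = build_task_id(task_filter, uuid4())
print(task_id)  # WEBSERVER:osparc:42:<uuid>
```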