diff --git a/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py b/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py index 3fb24ae952dc..3b19513ca361 100644 --- a/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py +++ b/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py @@ -1,12 +1,16 @@ -from typing import Any, TypeAlias +from typing import Annotated, Any, TypeAlias from uuid import UUID -from models_library.users import UserID -from pydantic import BaseModel +from pydantic import BaseModel, StringConstraints +from ..products import ProductName from ..progress_bar import ProgressReport +from ..users import UserID AsyncJobId: TypeAlias = UUID +AsyncJobName: TypeAlias = Annotated[ + str, StringConstraints(strip_whitespace=True, min_length=1) +] class AsyncJobStatus(BaseModel): @@ -21,6 +25,7 @@ class AsyncJobResult(BaseModel): class AsyncJobGet(BaseModel): job_id: AsyncJobId + job_name: AsyncJobName class AsyncJobAbort(BaseModel): @@ -31,5 +36,5 @@ class AsyncJobAbort(BaseModel): class AsyncJobNameData(BaseModel): """Data for controlling access to an async job""" + product_name: ProductName user_id: UserID - product_name: str diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py index f9dd7945358f..8a9f0a941ccb 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py @@ -26,7 +26,7 @@ async def _task_progress_cb( ) -> None: worker = get_celery_worker(task.app) assert task.name # nosec - await worker.set_progress( + await worker.set_task_progress( task_id=task_id, report=report, ) @@ -87,7 +87,7 @@ async def export_data( async def _progress_cb(report: ProgressReport) -> None: assert task.name # nosec - await get_celery_worker(task.app).set_progress(task_id, report) + await get_celery_worker(task.app).set_task_progress(task_id, report) _logger.debug("'%s' progress %s", task_id, report.percent_value) async with ProgressBarData( diff --git a/services/storage/src/simcore_service_storage/api/rest/_files.py b/services/storage/src/simcore_service_storage/api/rest/_files.py index dd3805024f43..f47818415700 100644 --- a/services/storage/src/simcore_service_storage/api/rest/_files.py +++ b/services/storage/src/simcore_service_storage/api/rest/_files.py @@ -35,7 +35,7 @@ UploadLinks, ) from ...modules.celery.client import CeleryTaskClient -from ...modules.celery.models import TaskUUID +from ...modules.celery.models import TaskMetadata, TaskUUID from ...simcore_s3_dsm import SimcoreS3DataManager from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file from .dependencies.celery import get_celery_client @@ -284,8 +284,10 @@ async def complete_upload_file( user_id=query_params.user_id, product_name=_UNDEFINED_PRODUCT_NAME_FOR_WORKER_TASKS, # NOTE: I would need to change the API here ) - task_uuid = await celery_client.send_task( - remote_complete_upload_file.__name__, + task_uuid = await celery_client.submit_task( + TaskMetadata( + name=remote_complete_upload_file.__name__, + ), task_context=async_job_name_data.model_dump(), user_id=async_job_name_data.user_id, location_id=location_id, diff --git a/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py b/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py index 690c18d37e3c..080d5edf0457 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py +++ b/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py @@ -127,10 +127,12 @@ async def list_jobs( _ = filter_ assert app # nosec try: - task_uuids = await get_celery_client(app).get_task_uuids( + tasks = await get_celery_client(app).list_tasks( task_context=job_id_data.model_dump(), ) except CeleryError as exc: raise JobSchedulerError(exc=f"{exc}") from exc - return [AsyncJobGet(job_id=task_uuid) for task_uuid in task_uuids] + return [ + AsyncJobGet(job_id=task.uuid, job_name=task.metadata.name) for task in tasks + ] diff --git a/services/storage/src/simcore_service_storage/api/rpc/_paths.py b/services/storage/src/simcore_service_storage/api/rpc/_paths.py index 0390156dac40..db0e69af38d7 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_paths.py +++ b/services/storage/src/simcore_service_storage/api/rpc/_paths.py @@ -10,6 +10,7 @@ from servicelib.rabbitmq import RPCRouter from ...modules.celery import get_celery_client +from ...modules.celery.models import TaskMetadata from .._worker_tasks._paths import compute_path_size as remote_compute_path_size from .._worker_tasks._paths import delete_paths as remote_delete_paths @@ -24,15 +25,18 @@ async def compute_path_size( location_id: LocationID, path: Path, ) -> AsyncJobGet: - task_uuid = await get_celery_client(app).send_task( - remote_compute_path_size.__name__, + task_name = remote_compute_path_size.__name__ + task_uuid = await get_celery_client(app).submit_task( + task_metadata=TaskMetadata( + name=task_name, + ), task_context=job_id_data.model_dump(), user_id=job_id_data.user_id, location_id=location_id, path=path, ) - return AsyncJobGet(job_id=task_uuid) + return AsyncJobGet(job_id=task_uuid, job_name=task_name) @router.expose(reraise_if_error_type=None) @@ -42,11 +46,14 @@ async def delete_paths( location_id: LocationID, paths: set[Path], ) -> AsyncJobGet: - task_uuid = await get_celery_client(app).send_task( - remote_delete_paths.__name__, + task_name = remote_delete_paths.__name__ + task_uuid = await get_celery_client(app).submit_task( + task_metadata=TaskMetadata( + name=task_name, + ), task_context=job_id_data.model_dump(), user_id=job_id_data.user_id, location_id=location_id, paths=paths, ) - return AsyncJobGet(job_id=task_uuid) + return AsyncJobGet(job_id=task_uuid, job_name=task_name) diff --git a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py index cd23f0ef8df1..ba3830c03298 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py +++ b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py @@ -20,28 +20,32 @@ async def copy_folders_from_project( job_id_data: AsyncJobNameData, body: FoldersBody, ) -> AsyncJobGet: - task_uuid = await get_celery_client(app).send_task( - deep_copy_files_from_project.__name__, + task_name = deep_copy_files_from_project.__name__ + task_uuid = await get_celery_client(app).submit_task( + task_metadata=TaskMetadata( + name=task_name, + ), task_context=job_id_data.model_dump(), user_id=job_id_data.user_id, body=body, ) - return AsyncJobGet(job_id=task_uuid) + return AsyncJobGet(job_id=task_uuid, job_name=task_name) @router.expose() async def start_export_data( app: FastAPI, job_id_data: AsyncJobNameData, paths_to_export: list[PathToExport] ) -> AsyncJobGet: - task_uuid = await get_celery_client(app).send_task( - export_data.__name__, - task_context=job_id_data.model_dump(), + task_name = export_data.__name__ + task_uuid = await get_celery_client(app).submit_task( task_metadata=TaskMetadata( + name=task_name, ephemeral=False, queue=TasksQueue.CPU_BOUND, ), + task_context=job_id_data.model_dump(), user_id=job_id_data.user_id, paths_to_export=paths_to_export, ) - return AsyncJobGet(job_id=task_uuid) + return AsyncJobGet(job_id=task_uuid, job_name=task_name) diff --git a/services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py b/services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py index 2a6dcce333f3..3fd9984fb2ab 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py +++ b/services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py @@ -1,12 +1,20 @@ +import contextlib import logging from datetime import timedelta from typing import Final -from celery.result import AsyncResult # type: ignore[import-untyped] from models_library.progress_bar import ProgressReport +from pydantic import ValidationError from servicelib.redis._client import RedisClientSDK -from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix +from ..models import ( + Task, + TaskContext, + TaskID, + TaskMetadata, + TaskUUID, + build_task_id_prefix, +) _CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-" _CELERY_TASK_ID_KEY_ENCODING = "utf-8" @@ -26,26 +34,64 @@ class RedisTaskInfoStore: def __init__(self, redis_client_sdk: RedisClientSDK) -> None: self._redis_client_sdk = redis_client_sdk - async def exists(self, task_id: TaskID) -> bool: + async def create_task( + self, + task_id: TaskID, + task_metadata: TaskMetadata, + expiry: timedelta, + ) -> None: + task_key = _build_key(task_id) + await self._redis_client_sdk.redis.hset( + name=task_key, + key=_CELERY_TASK_METADATA_KEY, + value=task_metadata.model_dump_json(), + ) # type: ignore + await self._redis_client_sdk.redis.expire( + task_key, + expiry, + ) + + async def exists_task(self, task_id: TaskID) -> bool: n = await self._redis_client_sdk.redis.exists(_build_key(task_id)) assert isinstance(n, int) # nosec return n > 0 - async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None: - result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore - return TaskMetadata.model_validate_json(result) if result else None + async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: + raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore + if not raw_result: + return None + + try: + return TaskMetadata.model_validate_json(raw_result) + except ValidationError as exc: + _logger.debug( + "Failed to deserialize task metadata for task %s: %s", task_id, f"{exc}" + ) + return None - async def get_progress(self, task_id: TaskID) -> ProgressReport | None: - result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore - return ProgressReport.model_validate_json(result) if result else None + async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: + raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore + if not raw_result: + return None - async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]: + try: + return ProgressReport.model_validate_json(raw_result) + except ValidationError as exc: + _logger.debug( + "Failed to deserialize task progress for task %s: %s", task_id, f"{exc}" + ) + return None + + async def list_tasks(self, task_context: TaskContext) -> list[Task]: search_key = ( _CELERY_TASK_INFO_PREFIX + build_task_id_prefix(task_context) + _CELERY_TASK_ID_KEY_SEPARATOR ) - keys = set() + search_key_len = len(search_key) + + keys: list[str] = [] + pipeline = self._redis_client_sdk.redis.pipeline() async for key in self._redis_client_sdk.redis.scan_iter( match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH ): @@ -55,27 +101,31 @@ async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]: if isinstance(key, bytes) else key ) - keys.add(TaskUUID(_key.removeprefix(search_key))) - return keys + keys.append(_key) + pipeline.hget(_key, _CELERY_TASK_METADATA_KEY) - async def remove(self, task_id: TaskID) -> None: - await self._redis_client_sdk.redis.delete(_build_key(task_id)) - AsyncResult(task_id).forget() + results = await pipeline.execute() - async def set_metadata( - self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta - ) -> None: - await self._redis_client_sdk.redis.hset( - name=_build_key(task_id), - key=_CELERY_TASK_METADATA_KEY, - value=task_metadata.model_dump_json(), - ) # type: ignore - await self._redis_client_sdk.redis.expire( - _build_key(task_id), - expiry, - ) + tasks = [] + for key, raw_metadata in zip(keys, results, strict=True): + if raw_metadata is None: + continue + + with contextlib.suppress(ValidationError): + task_metadata = TaskMetadata.model_validate_json(raw_metadata) + tasks.append( + Task( + uuid=TaskUUID(key[search_key_len:]), + metadata=task_metadata, + ) + ) + + return tasks + + async def remove_task(self, task_id: TaskID) -> None: + await self._redis_client_sdk.redis.delete(_build_key(task_id)) - async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None: + async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None: await self._redis_client_sdk.redis.hset( name=_build_key(task_id), key=_CELERY_TASK_PROGRESS_KEY, diff --git a/services/storage/src/simcore_service_storage/modules/celery/client.py b/services/storage/src/simcore_service_storage/modules/celery/client.py index 53cd13b2df13..305731f946af 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/client.py +++ b/services/storage/src/simcore_service_storage/modules/celery/client.py @@ -13,6 +13,7 @@ from settings_library.celery import CelerySettings from .models import ( + Task, TaskContext, TaskID, TaskInfoStore, @@ -36,24 +37,22 @@ class CeleryTaskClient: _celery_settings: CelerySettings _task_store: TaskInfoStore - async def send_task( + async def submit_task( self, - task_name: str, + task_metadata: TaskMetadata, *, task_context: TaskContext, - task_metadata: TaskMetadata | None = None, **task_params, ) -> TaskUUID: with log_context( _logger, logging.DEBUG, - msg=f"Submit {task_name=}: {task_context=} {task_params=}", + msg=f"Submit {task_metadata.name=}: {task_context=} {task_params=}", ): task_uuid = uuid4() task_id = build_task_id(task_context, task_uuid) - task_metadata = task_metadata or TaskMetadata() self._celery_app.send_task( - task_name, + task_metadata.name, task_id=task_id, kwargs=task_params, queue=task_metadata.queue.value, @@ -64,12 +63,14 @@ async def send_task( if task_metadata.ephemeral else self._celery_settings.CELERY_RESULT_EXPIRES ) - await self._task_store.set_metadata(task_id, task_metadata, expiry=expiry) + await self._task_store.create_task(task_id, task_metadata, expiry=expiry) return task_uuid @make_async() - def _abort_task(self, task_id: TaskID) -> None: - AbortableAsyncResult(task_id, app=self._celery_app).abort() + def _abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None: + AbortableAsyncResult( + build_task_id(task_context, task_uuid), app=self._celery_app + ).abort() async def abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None: with log_context( @@ -77,8 +78,11 @@ async def abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> No logging.DEBUG, msg=f"Abort task: {task_context=} {task_uuid=}", ): - task_id = build_task_id(task_context, task_uuid) - await self._abort_task(task_id) + await self._abort_task(task_context, task_uuid) + + @make_async() + def _forget_task(self, task_id: TaskID) -> None: + AbortableAsyncResult(task_id, app=self._celery_app).forget() async def get_task_result( self, task_context: TaskContext, task_uuid: TaskUUID @@ -92,19 +96,21 @@ async def get_task_result( async_result = self._celery_app.AsyncResult(task_id) result = async_result.result if async_result.ready(): - task_metadata = await self._task_store.get_metadata(task_id) + task_metadata = await self._task_store.get_task_metadata(task_id) if task_metadata is not None and task_metadata.ephemeral: - await self._task_store.remove(task_id) + await self._task_store.remove_task(task_id) + await self._forget_task(task_id) return result - async def _get_progress_report( - self, task_id: TaskID, state: TaskState + async def _get_task_progress_report( + self, task_context: TaskContext, task_uuid: TaskUUID, task_state: TaskState ) -> ProgressReport: - if state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED): - progress = await self._task_store.get_progress(task_id) + if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED): + task_id = build_task_id(task_context, task_uuid) + progress = await self._task_store.get_task_progress(task_id) if progress is not None: return progress - if state in ( + if task_state in ( TaskState.SUCCESS, TaskState.FAILURE, ): @@ -118,7 +124,9 @@ async def _get_progress_report( ) @make_async() - def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState: + def _get_task_celery_state( + self, task_context: TaskContext, task_uuid: TaskUUID + ) -> TaskState: task_id = build_task_id(task_context, task_uuid) return TaskState(self._celery_app.AsyncResult(task_id).state) @@ -130,18 +138,19 @@ async def get_task_status( logging.DEBUG, msg=f"Getting task status: {task_context=} {task_uuid=}", ): - task_state = await self._get_state(task_context, task_uuid) - task_id = build_task_id(task_context, task_uuid) + task_state = await self._get_task_celery_state(task_context, task_uuid) return TaskStatus( task_uuid=task_uuid, task_state=task_state, - progress_report=await self._get_progress_report(task_id, task_state), + progress_report=await self._get_task_progress_report( + task_context, task_uuid, task_state + ), ) - async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]: + async def list_tasks(self, task_context: TaskContext) -> list[Task]: with log_context( _logger, logging.DEBUG, - msg=f"Getting task uuids: {task_context=}", + msg=f"Listing tasks: {task_context=}", ): - return await self._task_store.get_uuids(task_context) + return await self._task_store.list_tasks(task_context) diff --git a/services/storage/src/simcore_service_storage/modules/celery/models.py b/services/storage/src/simcore_service_storage/modules/celery/models.py index 57c70e6da363..8b19d124ff17 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/models.py +++ b/services/storage/src/simcore_service_storage/modules/celery/models.py @@ -1,13 +1,16 @@ from datetime import timedelta from enum import StrEnum -from typing import Any, Final, Protocol, TypeAlias +from typing import Annotated, Any, Final, Protocol, TypeAlias from uuid import UUID from models_library.progress_bar import ProgressReport -from pydantic import BaseModel +from pydantic import BaseModel, StringConstraints TaskContext: TypeAlias = dict[str, Any] TaskID: TypeAlias = str +TaskName: TypeAlias = Annotated[ + str, StringConstraints(strip_whitespace=True, min_length=1) +] TaskUUID: TypeAlias = UUID _CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":" @@ -40,29 +43,40 @@ class TasksQueue(StrEnum): class TaskMetadata(BaseModel): + name: TaskName ephemeral: bool = True queue: TasksQueue = TasksQueue.DEFAULT +class Task(BaseModel): + uuid: TaskUUID + metadata: TaskMetadata + + _TASK_DONE = {TaskState.SUCCESS, TaskState.FAILURE, TaskState.ABORTED} class TaskInfoStore(Protocol): - async def exists(self, task_id: TaskID) -> bool: ... + async def create_task( + self, + task_id: TaskID, + task_metadata: TaskMetadata, + expiry: timedelta, + ) -> None: ... - async def get_progress(self, task_id: TaskID) -> ProgressReport | None: ... + async def exists_task(self, task_id: TaskID) -> bool: ... - async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None: ... + async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: ... - async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ... + async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: ... - async def remove(self, task_id: TaskID) -> None: ... + async def list_tasks(self, task_context: TaskContext) -> list[Task]: ... - async def set_metadata( - self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta - ) -> None: ... + async def remove_task(self, task_id: TaskID) -> None: ... - async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None: ... + async def set_task_progress( + self, task_id: TaskID, report: ProgressReport + ) -> None: ... class TaskStatus(BaseModel): diff --git a/services/storage/src/simcore_service_storage/modules/celery/worker.py b/services/storage/src/simcore_service_storage/modules/celery/worker.py index ef71d23365c8..a5e98ac09df7 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/worker.py +++ b/services/storage/src/simcore_service_storage/modules/celery/worker.py @@ -12,8 +12,8 @@ class CeleryTaskWorker: _task_info_store: TaskInfoStore - async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None: - await self._task_info_store.set_progress( + async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None: + await self._task_info_store.set_task_progress( task_id=task_id, report=report, ) diff --git a/services/storage/tests/unit/test_async_jobs.py b/services/storage/tests/unit/test_async_jobs.py index 4df44c41b640..95319a6533ff 100644 --- a/services/storage/tests/unit/test_async_jobs.py +++ b/services/storage/tests/unit/test_async_jobs.py @@ -29,6 +29,7 @@ from simcore_service_storage.api.rpc.routes import get_rabbitmq_rpc_server from simcore_service_storage.modules.celery import get_celery_client from simcore_service_storage.modules.celery._task import register_task +from simcore_service_storage.modules.celery.client import TaskMetadata from simcore_service_storage.modules.celery.models import TaskID from simcore_service_storage.modules.celery.worker import CeleryTaskWorker from tenacity import ( @@ -52,22 +53,24 @@ async def rpc_sync_job( app: FastAPI, *, job_id_data: AsyncJobNameData, **kwargs: Any ) -> AsyncJobGet: - task_uuid = await get_celery_client(app).send_task( - sync_job.__name__, task_context=job_id_data.model_dump(), **kwargs + task_name = sync_job.__name__ + task_uuid = await get_celery_client(app).submit_task( + TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs ) - return AsyncJobGet(job_id=task_uuid) + return AsyncJobGet(job_id=task_uuid, job_name=task_name) @router.expose() async def rpc_async_job( app: FastAPI, *, job_id_data: AsyncJobNameData, **kwargs: Any ) -> AsyncJobGet: - task_uuid = await get_celery_client(app).send_task( - async_job.__name__, task_context=job_id_data.model_dump(), **kwargs + task_name = async_job.__name__ + task_uuid = await get_celery_client(app).submit_task( + TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs ) - return AsyncJobGet(job_id=task_uuid) + return AsyncJobGet(job_id=task_uuid, job_name=task_name) ################################# diff --git a/services/storage/tests/unit/test_modules_celery.py b/services/storage/tests/unit/test_modules_celery.py index 8fb1de6d31fe..d5f3ce70b980 100644 --- a/services/storage/tests/unit/test_modules_celery.py +++ b/services/storage/tests/unit/test_modules_celery.py @@ -18,11 +18,16 @@ from models_library.progress_bar import ProgressReport from servicelib.logging_utils import log_context from simcore_service_storage.modules.celery import get_celery_client, get_event_loop -from simcore_service_storage.modules.celery._task import register_task +from simcore_service_storage.modules.celery._task import ( + AbortableAsyncResult, + register_task, +) from simcore_service_storage.modules.celery.client import CeleryTaskClient from simcore_service_storage.modules.celery.errors import TransferrableCeleryError from simcore_service_storage.modules.celery.models import ( TaskContext, + TaskID, + TaskMetadata, TaskState, ) from simcore_service_storage.modules.celery.utils import ( @@ -56,7 +61,7 @@ def sleep_for(seconds: float) -> None: for n, file in enumerate(files, start=1): with log_context(_logger, logging.INFO, msg=f"Processing file {file}"): - await worker.set_progress( + await worker.set_task_progress( task_id=task_id, report=ProgressReport(actual_value=n / len(files)), ) @@ -84,10 +89,10 @@ def failure_task(task: Task): raise MyError(msg=msg) -async def dreamer_task(task: AbortableTask) -> list[int]: +async def dreamer_task(task: AbortableTask, task_id: TaskID) -> list[int]: numbers = [] for _ in range(30): - if task.is_aborted(): + if AbortableAsyncResult(task_id, app=task.app).is_aborted(): _logger.warning("Alarm clock") return numbers numbers.append(randint(1, 90)) # noqa: S311 @@ -110,8 +115,10 @@ async def test_submitting_task_calling_async_function_results_with_success_state ): task_context = TaskContext(user_id=42) - task_uuid = await celery_client.send_task( - fake_file_processor.__name__, + task_uuid = await celery_client.submit_task( + TaskMetadata( + name=fake_file_processor.__name__, + ), task_context=task_context, files=[f"file{n}" for n in range(5)], ) @@ -138,8 +145,11 @@ async def test_submitting_task_with_failure_results_with_error( ): task_context = TaskContext(user_id=42) - task_uuid = await celery_client.send_task( - failure_task.__name__, task_context=task_context + task_uuid = await celery_client.submit_task( + TaskMetadata( + name=failure_task.__name__, + ), + task_context=task_context, ) for attempt in Retrying( @@ -161,8 +171,10 @@ async def test_aborting_task_results_with_aborted_state( ): task_context = TaskContext(user_id=42) - task_uuid = await celery_client.send_task( - dreamer_task.__name__, + task_uuid = await celery_client.submit_task( + TaskMetadata( + name=dreamer_task.__name__, + ), task_context=task_context, ) @@ -187,8 +199,10 @@ async def test_listing_task_uuids_contains_submitted_task( ): task_context = TaskContext(user_id=42) - task_uuid = await celery_client.send_task( - dreamer_task.__name__, + task_uuid = await celery_client.submit_task( + TaskMetadata( + name=dreamer_task.__name__, + ), task_context=task_context, ) @@ -198,6 +212,10 @@ async def test_listing_task_uuids_contains_submitted_task( stop=stop_after_delay(10), ): with attempt: - assert task_uuid in await celery_client.get_task_uuids(task_context) + tasks = await celery_client.list_tasks(task_context) + assert len(tasks) == 1 + assert task_uuid == tasks[0].uuid - assert task_uuid in await celery_client.get_task_uuids(task_context) + tasks = await celery_client.list_tasks(task_context) + assert len(tasks) == 1 + assert task_uuid == tasks[0].uuid diff --git a/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py b/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py index d24a5785d598..751cdae4f2f8 100644 --- a/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py +++ b/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py @@ -547,7 +547,7 @@ async def _request_start_export_data( @pytest.fixture def task_progress_spy(mocker: MockerFixture) -> Mock: - return mocker.spy(CeleryTaskWorker, "set_progress") + return mocker.spy(CeleryTaskWorker, "set_task_progress") @pytest.mark.parametrize( diff --git a/services/web/server/src/simcore_service_webserver/tasks/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_rest.py index a20c88c352e6..71850d627a78 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_rest.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_rest.py @@ -85,7 +85,7 @@ async def get_async_jobs(request: web.Request) -> web.Response: [ TaskGet( task_id=f"{job.job_id}", - task_name=f"{job.job_id}", + task_name=job.job_name, status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=str(job.job_id))))}", abort_href=f"{request.url.with_path(str(request.app.router['abort_async_job'].url_for(task_id=str(job.job_id))))}", result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=str(job.job_id))))}", diff --git a/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py b/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py index 9f74803b5636..73b6c3c086a7 100644 --- a/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py +++ b/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py @@ -162,7 +162,9 @@ def side_effect(*args, **kwargs): @pytest.mark.parametrize( "backend_result_or_exception", [ - AsyncJobGet(job_id=AsyncJobId(f"{_faker.uuid4()}")), + AsyncJobGet( + job_id=AsyncJobId(f"{_faker.uuid4()}"), job_name="compute_path_size" + ), ], ids=lambda x: type(x).__name__, ) @@ -204,7 +206,9 @@ async def test_compute_path_size( @pytest.mark.parametrize( "backend_result_or_exception", [ - AsyncJobGet(job_id=AsyncJobId(f"{_faker.uuid4()}")), + AsyncJobGet( + job_id=AsyncJobId(f"{_faker.uuid4()}"), job_name="batch_delete_paths" + ), ], ids=lambda x: type(x).__name__, ) @@ -429,7 +433,12 @@ def side_effect(*args, **kwargs): "backend_result_or_exception, expected_status", [ ( - (AsyncJobGet(job_id=AsyncJobId(f"{_faker.uuid4()}")), None), + ( + AsyncJobGet( + job_id=AsyncJobId(f"{_faker.uuid4()}"), job_name="export_data" + ), + None, + ), status.HTTP_202_ACCEPTED, ), ( @@ -590,6 +599,7 @@ async def test_get_async_job_result( [ AsyncJobGet( job_id=AsyncJobId(_faker.uuid4()), + job_name="task_name", ) ], status.HTTP_200_OK, @@ -668,7 +678,13 @@ async def test_get_async_job_links( create_storage_rpc_client_mock( "simcore_service_webserver.storage._rest", start_export_data.__name__, - (AsyncJobGet(job_id=AsyncJobId(f"{_faker.uuid4()}")), None), + ( + AsyncJobGet( + job_id=AsyncJobId(f"{_faker.uuid4()}"), + job_name="export_data", + ), + None, + ), ) _body = DataExportPost(