From 3793141a4ccc4b55eb07f7c55dbed05aaebde7b4 Mon Sep 17 00:00:00 2001
From: sanderegg <35365065+sanderegg@users.noreply.github.com>
Date: Wed, 26 Mar 2025 11:52:26 +0100
Subject: [PATCH] bad definition of function
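
The worker task callable was badly defined: it received only the Celery
task instance, but `task.request` is a thread-local object, so the task
id must now be passed explicitly and task callables are typed against
`AbortableTask`. A minimal sketch of a task under the corrected
signature (`do_work`, `name` and `celery_app` are illustrative names,
not part of this patch):

    from celery.contrib.abortable import AbortableTask
    from simcore_service_storage.modules.celery.models import TaskId

    async def do_work(task: AbortableTask, task_id: TaskId, name: str) -> str:
        assert task_id  # nosec
        return name.upper()

    # registration is unchanged; for coroutines the wrapper injects task_id:
    # define_task(celery_app, do_work)
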
---
 .../api/_worker_tasks/_paths.py               |  4 +-
 .../modules/celery/_celery_types.py           |  2 +
 .../modules/celery/_task.py                   | 30 +++++-----
 .../modules/celery/client.py                  | 57 +++++++++++++------
 4 files changed, 60 insertions(+), 33 deletions(-)

diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py
index f8208221f9fe..fae0bdc770c8 100644
--- a/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py
+++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py
@@ -8,14 +8,16 @@
 from servicelib.logging_utils import log_context

 from ...dsm import get_dsm_provider
+from ...modules.celery.models import TaskId
 from ...modules.celery.utils import get_fastapi_app

 _logger = logging.getLogger(__name__)


 async def compute_path_size(
-    task: Task, user_id: UserID, location_id: LocationID, path: Path
+    task: Task, task_id: TaskId, user_id: UserID, location_id: LocationID, path: Path
 ) -> ByteSize:
+    assert task_id  # nosec
     with log_context(
         _logger,
         logging.INFO,
diff --git a/services/storage/src/simcore_service_storage/modules/celery/_celery_types.py b/services/storage/src/simcore_service_storage/modules/celery/_celery_types.py
index 1ad45342a248..7fb44e087d57 100644
--- a/services/storage/src/simcore_service_storage/modules/celery/_celery_types.py
+++ b/services/storage/src/simcore_service_storage/modules/celery/_celery_types.py
@@ -10,6 +10,7 @@
 from pydantic import BaseModel

 from ...models import FileMetaData
+from ...modules.celery.models import TaskError


 def _path_encoder(obj):
@@ -57,3 +58,4 @@ def register_celery_types() -> None:
     _register_pydantic_types(FileUploadCompletionBody)
     _register_pydantic_types(FileMetaData)
     _register_pydantic_types(FoldersBody)
+    _register_pydantic_types(TaskError)
diff --git a/services/storage/src/simcore_service_storage/modules/celery/_task.py b/services/storage/src/simcore_service_storage/modules/celery/_task.py
index 02d5c6a5edd6..6e735a8be815 100644
--- a/services/storage/src/simcore_service_storage/modules/celery/_task.py
+++ b/services/storage/src/simcore_service_storage/modules/celery/_task.py
@@ -6,10 +6,7 @@
 from functools import wraps
 from typing import Any, Concatenate, ParamSpec, TypeVar, overload

-from celery import (  # type: ignore[import-untyped]
-    Celery,
-    Task,
-)
+from celery import Celery  # type: ignore[import-untyped]
 from celery.contrib.abortable import AbortableTask  # type: ignore[import-untyped]
 from celery.exceptions import Ignore  # type: ignore[import-untyped]

@@ -22,7 +19,7 @@

 def error_handling(func: Callable[..., Any]) -> Callable[..., Any]:
     @wraps(func)
-    def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
+    def wrapper(task: AbortableTask, *args: Any, **kwargs: Any) -> Any:
         try:
             return func(task, *args, **kwargs)
         except Exception as exc:
@@ -31,8 +28,9 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
             exc_traceback = traceback.format_exc().split("\n")

             _logger.exception(
-                "Task %s failed with exception: %s",
+                "Task %s failed with exception: %s:%s",
                 task.request.id,
+                exc_type,
                 exc_message,
             )

@@ -57,14 +55,14 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
 def _async_task_wrapper(
     app: Celery,
 ) -> Callable[
-    [Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]],
-    Callable[Concatenate[Task, P], R],
+    [Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]]],
+    Callable[Concatenate[AbortableTask, P], R],
 ]:
     def decorator(
-        coro: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]],
-    ) -> Callable[Concatenate[Task, P], R]:
+        coro: Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]],
+    ) -> Callable[Concatenate[AbortableTask, P], R]:
         @wraps(coro)
-        def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
+        def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
             fastapi_app = get_fastapi_app(app)
             _logger.debug("task id: %s", task.request.id)
             # NOTE: task.request is a thread local object, so we need to pass the id explicitly
@@ -82,7 +80,7 @@ def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
 @overload
 def define_task(
     app: Celery,
-    fn: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]],
+    fn: Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]],
     task_name: str | None = None,
 ) -> None: ...

@@ -90,7 +88,7 @@ def define_task(
 @overload
 def define_task(
     app: Celery,
-    fn: Callable[Concatenate[Task, P], R],
+    fn: Callable[Concatenate[AbortableTask, P], R],
     task_name: str | None = None,
 ) -> None: ...

@@ -98,13 +96,13 @@ def define_task(
 def define_task(  # type: ignore[misc]
     app: Celery,
     fn: (
-        Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]
-        | Callable[Concatenate[Task, P], R]
+        Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]]
+        | Callable[Concatenate[AbortableTask, P], R]
     ),
     task_name: str | None = None,
 ) -> None:
     """Decorator to define a celery task with error handling and abortable support"""
-    wrapped_fn: Callable[Concatenate[Task, P], R]
+    wrapped_fn: Callable[Concatenate[AbortableTask, P], R]
     if asyncio.iscoroutinefunction(fn):
         wrapped_fn = _async_task_wrapper(app)(fn)
     else:
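
NOTE (illustration only): the Concatenate/ParamSpec pattern used by
_async_task_wrapper above, reduced to a self-contained sketch without
Celery. The decorator consumes the injected first argument, so callers
see the wrapped signature without it; `asyncio.run` merely stands in
for however the real wrapper drives the coroutine, and all names here
are illustrative:

    import asyncio
    from collections.abc import Callable, Coroutine
    from functools import wraps
    from typing import Any, Concatenate, ParamSpec, TypeVar

    P = ParamSpec("P")
    R = TypeVar("R")

    def sync_task(
        coro: Callable[Concatenate[str, P], Coroutine[Any, Any, R]],
    ) -> Callable[P, R]:
        @wraps(coro)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # inject the first argument; the real code forwards task.request.id
            return asyncio.run(coro("fake-task-id", *args, **kwargs))
        return wrapper

    @sync_task
    async def greet(task_id: str, name: str) -> str:
        return f"{task_id}: hello {name}"

    assert greet("bob") == "fake-task-id: hello bob"
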
diff --git a/services/storage/src/simcore_service_storage/modules/celery/client.py b/services/storage/src/simcore_service_storage/modules/celery/client.py
index b6c2a0906917..d5b1d1b88af9 100644
--- a/services/storage/src/simcore_service_storage/modules/celery/client.py
+++ b/services/storage/src/simcore_service_storage/modules/celery/client.py
@@ -1,5 +1,6 @@
 import contextlib
 import logging
+from dataclasses import dataclass
 from typing import Any, Final
 from uuid import uuid4

@@ -12,6 +13,7 @@
 from pydantic import ValidationError
 from servicelib.logging_utils import log_context

+from ...exceptions.errors import ConfigurationError
 from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID

 _logger = logging.getLogger(__name__)
@@ -53,36 +55,44 @@ def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
     )


+@dataclass
 class CeleryTaskQueueClient:
-    def __init__(self, celery_app: Celery):
-        self._celery_app = celery_app
+    _celery_app: Celery

     @make_async()
     def send_task(
         self, task_name: str, *, task_context: TaskContext, **task_params
     ) -> TaskUUID:
-        task_uuid = uuid4()
-        task_id = _build_task_id(task_context, task_uuid)
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Submitting task {task_name}: {task_id=} {task_params=}",
+            msg=f"Submit {task_name=}: {task_context=} {task_params=}",
         ):
+            task_uuid = uuid4()
+            task_id = _build_task_id(task_context, task_uuid)
             self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
         return task_uuid

+    @staticmethod
     @make_async()
-    def abort_task(  # pylint: disable=R6301
-        self, task_context: TaskContext, task_uuid: TaskUUID
-    ) -> None:
-        task_id = _build_task_id(task_context, task_uuid)
-        _logger.info("Aborting task %s", task_id)
-        AbortableAsyncResult(task_id).abort()
+    def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
+        with log_context(
+            _logger,
+            logging.DEBUG,
+            msg=f"Abort task {task_uuid=}: {task_context=}",
+        ):
+            task_id = _build_task_id(task_context, task_uuid)
+            AbortableAsyncResult(task_id).abort()

     @make_async()
     def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
-        task_id = _build_task_id(task_context, task_uuid)
-        return self._celery_app.AsyncResult(task_id).result
+        with log_context(
+            _logger,
+            logging.DEBUG,
+            msg=f"Get task {task_uuid=}: {task_context=} result",
+        ):
+            task_id = _build_task_id(task_context, task_uuid)
+            return self._celery_app.AsyncResult(task_id).result

     def _get_progress_report(
         self, task_context: TaskContext, task_uuid: TaskUUID
@@ -118,15 +128,30 @@ def get_task_status(

     def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
         search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
-        redis = self._celery_app.backend.client
-        if hasattr(redis, "keys") and (keys := redis.keys(search_key + "*")):
+        backend_client = self._celery_app.backend.client
+        if hasattr(backend_client, "keys") and (
+            keys := backend_client.keys(f"{search_key}*")
+        ):
             return {
                 TaskUUID(
                     f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
                 )
                 for key in keys
             }
-        return set()
+        if hasattr(backend_client, "cache"):
+            # NOTE: backend used in testing. It is a dict-like object
+            found_keys = set()
+            for key in backend_client.cache:
+                str_key = key.decode(_CELERY_TASK_ID_KEY_ENCODING)
+                if str_key.startswith(search_key):
+                    found_keys.add(
+                        TaskUUID(
+                            f"{str_key.removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
+                        )
+                    )
+            return found_keys
+        msg = f"Unsupported backend {self._celery_app.backend.__class__.__name__}"
+        raise ConfigurationError(msg=msg)

     @make_async()
     def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
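
NOTE (illustration only): the completed-task scan in
_get_completed_task_uuids, restated as a standalone sketch. The actual
separator and encoding live in _CELERY_TASK_ID_KEY_SEPARATOR and
_CELERY_TASK_ID_KEY_ENCODING, whose values are not shown in this patch;
"::" and "utf-8" below are assumptions:

    def extract_task_uuids(
        raw_keys: list[bytes], search_key: str, sep: str = "::"
    ) -> set[str]:
        # keep only keys under this context prefix and strip the prefix,
        # leaving the bare task-UUID string
        found: set[str] = set()
        for key in raw_keys:
            str_key = key.decode("utf-8")
            if str_key.startswith(search_key):
                found.add(str_key.removeprefix(search_key + sep))
        return found

    assert extract_task_uuids(
        [b"celery-task-meta-user_id=42::1f0-abc", b"unrelated-key"],
        "celery-task-meta-user_id=42",
    ) == {"1f0-abc"}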