diff --git a/packages/celery-library/src/celery_library/backends/_redis.py b/packages/celery-library/src/celery_library/backends/_redis.py index 6b30b5e5fc45..2bd80ed76d87 100644 --- a/packages/celery-library/src/celery_library/backends/_redis.py +++ b/packages/celery-library/src/celery_library/backends/_redis.py @@ -19,7 +19,7 @@ _CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-" _CELERY_TASK_ID_KEY_ENCODING = "utf-8" _CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":" -_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000 +_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000 _CELERY_TASK_METADATA_KEY: Final[str] = "metadata" _CELERY_TASK_PROGRESS_KEY: Final[str] = "progress" @@ -51,11 +51,6 @@ async def create_task( expiry, ) - async def exists_task(self, task_id: TaskID) -> bool: - n = await self._redis_client_sdk.redis.exists(_build_key(task_id)) - assert isinstance(n, int) # nosec - return n > 0 - async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore if not raw_result: @@ -131,3 +126,8 @@ async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> No key=_CELERY_TASK_PROGRESS_KEY, value=report.model_dump_json(), ) # type: ignore + + async def task_exists(self, task_id: TaskID) -> bool: + n = await self._redis_client_sdk.redis.exists(_build_key(task_id)) + assert isinstance(n, int) # nosec + return n > 0 diff --git a/packages/celery-library/src/celery_library/errors.py b/packages/celery-library/src/celery_library/errors.py index 37b174189f81..e4ba148b8812 100644 --- a/packages/celery-library/src/celery_library/errors.py +++ b/packages/celery-library/src/celery_library/errors.py @@ -1,6 +1,8 @@ import base64 import pickle +from common_library.errors_classes import OsparcErrorMixin + class TransferrableCeleryError(Exception): def __repr__(self) -> str: @@ -22,3 +24,7 @@ def decode_celery_transferrable_error(error: TransferrableCeleryError) -> Except assert isinstance(error, TransferrableCeleryError) # nosec result: Exception = pickle.loads(base64.b64decode(error.args[0])) # noqa: S301 return result + + +class TaskNotFoundError(OsparcErrorMixin, Exception): + msg_template = "Task with id '{task_id}' was not found" diff --git a/packages/celery-library/src/celery_library/rpc/_async_jobs.py b/packages/celery-library/src/celery_library/rpc/_async_jobs.py index ea7cb5876a5d..6fb45336fdcd 100644 --- a/packages/celery-library/src/celery_library/rpc/_async_jobs.py +++ b/packages/celery-library/src/celery_library/rpc/_async_jobs.py @@ -13,6 +13,7 @@ from models_library.api_schemas_rpc_async_jobs.exceptions import ( JobAbortedError, JobError, + JobMissingError, JobNotDoneError, JobSchedulerError, ) @@ -22,6 +23,7 @@ from servicelib.rabbitmq import RPCRouter from ..errors import ( + TaskNotFoundError, TransferrableCeleryError, decode_celery_transferrable_error, ) @@ -30,7 +32,7 @@ router = RPCRouter() -@router.expose(reraise_if_error_type=(JobSchedulerError,)) +@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError)) async def cancel( task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter ): @@ -42,11 +44,13 @@ async def cancel( task_filter=task_filter, task_uuid=job_id, ) + except TaskNotFoundError as exc: + raise JobMissingError(job_id=job_id) from exc except CeleryError as exc: raise JobSchedulerError(exc=f"{exc}") from exc -@router.expose(reraise_if_error_type=(JobSchedulerError,)) +@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError)) async def status( task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter ) -> AsyncJobStatus: @@ -59,6 +63,8 @@ async def status( task_filter=task_filter, task_uuid=job_id, ) + except TaskNotFoundError as exc: + raise JobMissingError(job_id=job_id) from exc except CeleryError as exc: raise JobSchedulerError(exc=f"{exc}") from exc @@ -71,9 +77,10 @@ async def status( @router.expose( reraise_if_error_type=( + JobAbortedError, JobError, + JobMissingError, JobNotDoneError, - JobAbortedError, JobSchedulerError, ) ) @@ -97,11 +104,11 @@ async def result( task_filter=task_filter, task_uuid=job_id, ) + except TaskNotFoundError as exc: + raise JobMissingError(job_id=job_id) from exc except CeleryError as exc: raise JobSchedulerError(exc=f"{exc}") from exc - if _status.task_state == TaskState.ABORTED: - raise JobAbortedError(job_id=job_id) if _status.task_state == TaskState.FAILURE: # fallback exception to report exc_type = type(_result).__name__ diff --git a/packages/celery-library/src/celery_library/task.py b/packages/celery-library/src/celery_library/task.py index 075e10036bc5..c3efc7ead141 100644 --- a/packages/celery-library/src/celery_library/task.py +++ b/packages/celery-library/src/celery_library/task.py @@ -6,11 +6,7 @@ from functools import wraps from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload -from celery import Celery # type: ignore[import-untyped] -from celery.contrib.abortable import ( # type: ignore[import-untyped] - AbortableAsyncResult, - AbortableTask, -) +from celery import Celery, Task # type: ignore[import-untyped] from celery.exceptions import Ignore # type: ignore[import-untyped] from common_library.async_tools import cancel_wait_task from pydantic import NonNegativeInt @@ -39,42 +35,42 @@ class TaskAbortedError(Exception): ... def _async_task_wrapper( app: Celery, ) -> Callable[ - [Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]]], - Callable[Concatenate[AbortableTask, P], R], + [Callable[Concatenate[Task, P], Coroutine[Any, Any, R]]], + Callable[Concatenate[Task, P], R], ]: def decorator( - coro: Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]], - ) -> Callable[Concatenate[AbortableTask, P], R]: + coro: Callable[Concatenate[Task, P], Coroutine[Any, Any, R]], + ) -> Callable[Concatenate[Task, P], R]: @wraps(coro) - def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R: + def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R: app_server = get_app_server(app) # NOTE: task.request is a thread local object, so we need to pass the id explicitly assert task.request.id is not None # nosec - async def run_task(task_id: TaskID) -> R: + async def _run_task(task_id: TaskID) -> R: try: async with asyncio.TaskGroup() as tg: - main_task = tg.create_task( + async_io_task = tg.create_task( coro(task, *args, **kwargs), ) - async def abort_monitor(): - abortable_result = AbortableAsyncResult(task_id, app=app) - while not main_task.done(): - if abortable_result.is_aborted(): + async def _abort_monitor(): + while not async_io_task.done(): + if not await app_server.task_manager.task_exists( + task_id + ): await cancel_wait_task( - main_task, + async_io_task, max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(), ) - AbortableAsyncResult(task_id, app=app).forget() raise TaskAbortedError await asyncio.sleep( _DEFAULT_ABORT_TASK_TIMEOUT.total_seconds() ) - tg.create_task(abort_monitor()) + tg.create_task(_abort_monitor()) - return main_task.result() + return async_io_task.result() except BaseExceptionGroup as eg: task_aborted_errors, other_errors = eg.split(TaskAbortedError) @@ -88,7 +84,7 @@ async def abort_monitor(): raise other_errors.exceptions[0] from eg return asyncio.run_coroutine_threadsafe( - run_task(task.request.id), + _run_task(task.request.id), app_server.event_loop, ).result() @@ -102,14 +98,14 @@ def _error_handling( delay_between_retries: timedelta, dont_autoretry_for: tuple[type[Exception], ...], ) -> Callable[ - [Callable[Concatenate[AbortableTask, P], R]], - Callable[Concatenate[AbortableTask, P], R], + [Callable[Concatenate[Task, P], R]], + Callable[Concatenate[Task, P], R], ]: def decorator( - func: Callable[Concatenate[AbortableTask, P], R], - ) -> Callable[Concatenate[AbortableTask, P], R]: + func: Callable[Concatenate[Task, P], R], + ) -> Callable[Concatenate[Task, P], R]: @wraps(func) - def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R: + def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R: try: return func(task, *args, **kwargs) except TaskAbortedError as exc: @@ -144,7 +140,7 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R: @overload def register_task( app: Celery, - fn: Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]], + fn: Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]], task_name: str | None = None, timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT, max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES, @@ -156,7 +152,7 @@ def register_task( @overload def register_task( app: Celery, - fn: Callable[Concatenate[AbortableTask, P], R], + fn: Callable[Concatenate[Task, P], R], task_name: str | None = None, timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT, max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES, @@ -168,8 +164,8 @@ def register_task( def register_task( # type: ignore[misc] app: Celery, fn: ( - Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]] - | Callable[Concatenate[AbortableTask, P], R] + Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]] + | Callable[Concatenate[Task, P], R] ), task_name: str | None = None, timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT, @@ -186,7 +182,7 @@ def register_task( # type: ignore[misc] delay_between_retries -- dealy between each attempt in case of error (default: {_DEFAULT_WAIT_BEFORE_RETRY}) dont_autoretry_for -- exceptions that should not be retried when raised by the task """ - wrapped_fn: Callable[Concatenate[AbortableTask, P], R] + wrapped_fn: Callable[Concatenate[Task, P], R] if asyncio.iscoroutinefunction(fn): wrapped_fn = _async_task_wrapper(app)(fn) else: @@ -202,7 +198,6 @@ def register_task( # type: ignore[misc] app.task( name=task_name or fn.__name__, bind=True, - base=AbortableTask, time_limit=None if timeout is None else timeout.total_seconds(), pydantic=True, )(wrapped_fn) diff --git a/packages/celery-library/src/celery_library/task_manager.py b/packages/celery-library/src/celery_library/task_manager.py index 72ca039f6ca2..04e18a291583 100644 --- a/packages/celery-library/src/celery_library/task_manager.py +++ b/packages/celery-library/src/celery_library/task_manager.py @@ -4,12 +4,10 @@ from uuid import uuid4 from celery import Celery # type: ignore[import-untyped] -from celery.contrib.abortable import ( # type: ignore[import-untyped] - AbortableAsyncResult, -) from common_library.async_tools import make_async from models_library.progress_bar import ProgressReport from servicelib.celery.models import ( + TASK_DONE_STATES, Task, TaskFilter, TaskID, @@ -23,6 +21,7 @@ from servicelib.logging_utils import log_context from settings_library.celery import CelerySettings +from .errors import TaskNotFoundError from .utils import build_task_id _logger = logging.getLogger(__name__) @@ -69,10 +68,6 @@ async def submit_task( ) return task_uuid - @make_async() - def _abort_task(self, task_id: TaskID) -> None: - AbortableAsyncResult(task_id, app=self._celery_app).abort() - async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None: with log_context( _logger, @@ -80,13 +75,18 @@ async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> Non msg=f"task cancellation: {task_filter=} {task_uuid=}", ): task_id = build_task_id(task_filter, task_uuid) - if not (await self.get_task_status(task_filter, task_uuid)).is_done: - await self._abort_task(task_id) + if not await self.task_exists(task_id): + raise TaskNotFoundError(task_id=task_id) + await self._task_info_store.remove_task(task_id) + await self._forget_task(task_id) + + async def task_exists(self, task_id: TaskID) -> bool: + return await self._task_info_store.task_exists(task_id) @make_async() def _forget_task(self, task_id: TaskID) -> None: - AbortableAsyncResult(task_id, app=self._celery_app).forget() + self._celery_app.AsyncResult(task_id).forget() async def get_task_result( self, task_filter: TaskFilter, task_uuid: TaskUUID @@ -97,27 +97,27 @@ async def get_task_result( msg=f"Get task result: {task_filter=} {task_uuid=}", ): task_id = build_task_id(task_filter, task_uuid) + if not await self.task_exists(task_id): + raise TaskNotFoundError(task_id=task_id) + async_result = self._celery_app.AsyncResult(task_id) result = async_result.result if async_result.ready(): task_metadata = await self._task_info_store.get_task_metadata(task_id) if task_metadata is not None and task_metadata.ephemeral: - await self._forget_task(task_id) await self._task_info_store.remove_task(task_id) + await self._forget_task(task_id) return result async def _get_task_progress_report( - self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState + self, task_id: TaskID, task_state: TaskState ) -> ProgressReport: - if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED): - task_id = build_task_id(task_filter, task_uuid) + if task_state in (TaskState.STARTED, TaskState.RETRY): progress = await self._task_info_store.get_task_progress(task_id) if progress is not None: return progress - if task_state in ( - TaskState.SUCCESS, - TaskState.FAILURE, - ): + + if task_state in TASK_DONE_STATES: return ProgressReport( actual_value=_MAX_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE ) @@ -140,12 +140,15 @@ async def get_task_status( msg=f"Getting task status: {task_filter=} {task_uuid=}", ): task_id = build_task_id(task_filter, task_uuid) + if not await self.task_exists(task_id): + raise TaskNotFoundError(task_id=task_id) + task_state = await self._get_task_celery_state(task_id) return TaskStatus( task_uuid=task_uuid, task_state=task_state, progress_report=await self._get_task_progress_report( - task_filter, task_uuid, task_state + task_id, task_state ), ) diff --git a/packages/celery-library/tests/unit/test_async_jobs.py b/packages/celery-library/tests/unit/test_async_jobs.py index 4a646a1fdb46..cc72bd6b75ed 100644 --- a/packages/celery-library/tests/unit/test_async_jobs.py +++ b/packages/celery-library/tests/unit/test_async_jobs.py @@ -20,8 +20,8 @@ AsyncJobGet, ) from models_library.api_schemas_rpc_async_jobs.exceptions import ( - JobAbortedError, JobError, + JobMissingError, ) from models_library.products import ProductName from models_library.rabbitmq_basic_types import RPCNamespace @@ -308,12 +308,6 @@ async def test_async_jobs_cancel( job_filter=job_filter, ) - await _wait_for_job( - async_jobs_rabbitmq_rpc_client, - async_job_get=async_job_get, - job_filter=job_filter, - ) - jobs = await async_jobs.list_jobs( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, @@ -322,7 +316,15 @@ async def test_async_jobs_cancel( ) assert async_job_get.job_id not in [job.job_id for job in jobs] - with pytest.raises(JobAbortedError): + with pytest.raises(JobMissingError): + await async_jobs.status( + async_jobs_rabbitmq_rpc_client, + rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, + job_id=async_job_get.job_id, + job_filter=job_filter, + ) + + with pytest.raises(JobMissingError): await async_jobs.result( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, diff --git a/packages/celery-library/tests/unit/test_tasks.py b/packages/celery-library/tests/unit/test_tasks.py index 35da31aa1802..d3e768c9ff1b 100644 --- a/packages/celery-library/tests/unit/test_tasks.py +++ b/packages/celery-library/tests/unit/test_tasks.py @@ -12,8 +12,7 @@ import pytest from celery import Celery, Task # pylint: disable=no-name-in-module -from celery.contrib.abortable import AbortableTask # pylint: disable=no-name-in-module -from celery_library.errors import TransferrableCeleryError +from celery_library.errors import TaskNotFoundError, TransferrableCeleryError from celery_library.task import register_task from celery_library.task_manager import CeleryTaskManager from celery_library.utils import get_app_server @@ -72,7 +71,7 @@ def failure_task(task: Task, task_id: TaskID) -> None: raise MyError(msg=msg) -async def dreamer_task(task: AbortableTask, task_id: TaskID) -> list[int]: +async def dreamer_task(task: Task, task_id: TaskID) -> list[int]: numbers = [] for _ in range(30): numbers.append(randint(1, 90)) # noqa: S311 @@ -164,18 +163,8 @@ async def test_cancelling_a_running_task_aborts_and_deletes( await celery_task_manager.cancel_task(task_filter, task_uuid) - for attempt in Retrying( - retry=retry_if_exception_type(AssertionError), - wait=wait_fixed(1), - stop=stop_after_delay(30), - ): - with attempt: - progress = await celery_task_manager.get_task_status(task_filter, task_uuid) - assert progress.task_state == TaskState.ABORTED - - assert ( + with pytest.raises(TaskNotFoundError): await celery_task_manager.get_task_status(task_filter, task_uuid) - ).task_state == TaskState.ABORTED assert task_uuid not in await celery_task_manager.list_tasks(task_filter) diff --git a/packages/service-library/src/servicelib/celery/models.py b/packages/service-library/src/servicelib/celery/models.py index 0c46e1716b14..c35fc98504ed 100644 --- a/packages/service-library/src/servicelib/celery/models.py +++ b/packages/service-library/src/servicelib/celery/models.py @@ -1,6 +1,6 @@ import datetime from enum import StrEnum -from typing import Annotated, Protocol, TypeAlias +from typing import Annotated, Final, Protocol, TypeAlias from uuid import UUID from models_library.progress_bar import ProgressReport @@ -23,7 +23,12 @@ class TaskState(StrEnum): RETRY = "RETRY" SUCCESS = "SUCCESS" FAILURE = "FAILURE" - ABORTED = "ABORTED" + + +TASK_DONE_STATES: Final[tuple[TaskState, ...]] = ( + TaskState.SUCCESS, + TaskState.FAILURE, +) class TasksQueue(StrEnum): @@ -78,9 +83,6 @@ def _update_json_schema_extra(schema: JsonDict) -> None: model_config = ConfigDict(json_schema_extra=_update_json_schema_extra) -_TASK_DONE = {TaskState.SUCCESS, TaskState.FAILURE, TaskState.ABORTED} - - class TaskInfoStore(Protocol): async def create_task( self, @@ -89,7 +91,7 @@ async def create_task( expiry: datetime.timedelta, ) -> None: ... - async def exists_task(self, task_id: TaskID) -> bool: ... + async def task_exists(self, task_id: TaskID) -> bool: ... async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: ... @@ -138,4 +140,4 @@ def _update_json_schema_extra(schema: JsonDict) -> None: @property def is_done(self) -> bool: - return self.task_state in _TASK_DONE + return self.task_state in TASK_DONE_STATES diff --git a/packages/service-library/src/servicelib/celery/task_manager.py b/packages/service-library/src/servicelib/celery/task_manager.py index 93612e6845fe..68a62edbb8ae 100644 --- a/packages/service-library/src/servicelib/celery/task_manager.py +++ b/packages/service-library/src/servicelib/celery/task_manager.py @@ -21,6 +21,8 @@ async def cancel_task( self, task_filter: TaskFilter, task_uuid: TaskUUID ) -> None: ... + async def task_exists(self, task_id: TaskID) -> bool: ... + async def get_task_result( self, task_filter: TaskFilter, task_uuid: TaskUUID ) -> Any: ... diff --git a/services/api-server/openapi.json b/services/api-server/openapi.json index 57c07b5e4c45..17f6bd059ae3 100644 --- a/services/api-server/openapi.json +++ b/services/api-server/openapi.json @@ -8474,16 +8474,6 @@ } } }, - "409": { - "description": "Task is cancelled", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorGet" - } - } - } - }, "500": { "description": "Internal server error", "content": { diff --git a/services/api-server/src/simcore_service_api_server/_constants.py b/services/api-server/src/simcore_service_api_server/_constants.py index 512a987b640d..8d2ffb058564 100644 --- a/services/api-server/src/simcore_service_api_server/_constants.py +++ b/services/api-server/src/simcore_service_api_server/_constants.py @@ -10,3 +10,8 @@ "Something went wrong on our end. We've been notified and will resolve this issue as soon as possible. Thank you for your patience.", _version=2, ) + +MSG_CLIENT_ERROR_USER_FRIENDLY_TEMPLATE: Final[str] = user_message( + "Something went wrong with your request.", + _version=1, +) diff --git a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py index 9837a5f625f3..36663efccf2d 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py @@ -158,10 +158,6 @@ async def cancel_task( "description": "Task result not found", "model": ErrorGet, }, - status.HTTP_409_CONFLICT: { - "description": "Task is cancelled", - "model": ErrorGet, - }, **_DEFAULT_TASK_STATUS_CODES, }, description=create_route_description( @@ -191,11 +187,6 @@ async def get_task_result( status_code=status.HTTP_404_NOT_FOUND, detail="Task result not available yet", ) - if task_status.task_state == TaskState.ABORTED: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="Task was cancelled", - ) task_result = await task_manager.get_task_result( task_filter=task_filter, diff --git a/services/api-server/src/simcore_service_api_server/exceptions/handlers/__init__.py b/services/api-server/src/simcore_service_api_server/exceptions/handlers/__init__.py index adecb5d7203c..03e1533a7d42 100644 --- a/services/api-server/src/simcore_service_api_server/exceptions/handlers/__init__.py +++ b/services/api-server/src/simcore_service_api_server/exceptions/handlers/__init__.py @@ -1,6 +1,7 @@ from celery.exceptions import ( # type: ignore[import-untyped] #pylint: disable=no-name-in-module CeleryError, ) +from celery_library.errors import TaskNotFoundError from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from httpx import HTTPError as HttpxException @@ -8,7 +9,10 @@ from starlette import status from starlette.exceptions import HTTPException -from ..._constants import MSG_INTERNAL_ERROR_USER_FRIENDLY_TEMPLATE +from ..._constants import ( + MSG_CLIENT_ERROR_USER_FRIENDLY_TEMPLATE, + MSG_INTERNAL_ERROR_USER_FRIENDLY_TEMPLATE, +) from ...exceptions.backend_errors import BaseBackEndError from ..custom_errors import CustomBaseError from ..log_streaming_errors import LogStreamingBaseError @@ -41,6 +45,16 @@ def setup(app: FastAPI, *, is_debug: bool = False): ), ) + app.add_exception_handler( + TaskNotFoundError, + make_handler_for_exception( + TaskNotFoundError, + status.HTTP_404_NOT_FOUND, + error_message=MSG_CLIENT_ERROR_USER_FRIENDLY_TEMPLATE, + add_exception_to_message=True, + ), + ) + app.add_exception_handler( CeleryError, make_handler_for_exception( diff --git a/services/api-server/tests/unit/test_tasks.py b/services/api-server/tests/unit/test_tasks.py index 0a3439299fe5..02af824cc87d 100644 --- a/services/api-server/tests/unit/test_tasks.py +++ b/services/api-server/tests/unit/test_tasks.py @@ -154,28 +154,6 @@ async def test_get_task_result( None, status.HTTP_404_NOT_FOUND, ), - ( - "GET", - f"/v0/tasks/{_faker.uuid4()}/result", - None, - CeleryTaskStatus( - task_uuid=TaskUUID("123e4567-e89b-12d3-a456-426614174000"), - task_state=TaskState.ABORTED, - progress_report=ProgressReport( - actual_value=0.5, - total=1.0, - unit="Byte", - message=ProgressStructuredMessage.model_validate( - ProgressStructuredMessage.model_json_schema( - schema_generator=GenerateResolvedJsonSchema - )["examples"][0] - ), - ), - ), - None, - None, - status.HTTP_409_CONFLICT, - ), ], ) async def test_celery_error_propagation(