Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/celery-library/src/celery_library/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
import pickle

from common_library.errors_classes import OsparcErrorMixin


class TransferrableCeleryError(Exception):
def __repr__(self) -> str:
Expand All @@ -22,3 +24,7 @@ def decode_celery_transferrable_error(error: TransferrableCeleryError) -> Except
assert isinstance(error, TransferrableCeleryError) # nosec
result: Exception = pickle.loads(base64.b64decode(error.args[0])) # noqa: S301
return result


class TaskNotFoundError(OsparcErrorMixin, Exception):
    # Signals that a task id could not be found.
    # NOTE(review): the template is presumably rendered by OsparcErrorMixin
    # using the `task_id` keyword supplied when the error is raised -- confirm.
    msg_template = "Task with id '{task_id}' not found"
17 changes: 12 additions & 5 deletions packages/celery-library/src/celery_library/rpc/_async_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from models_library.api_schemas_rpc_async_jobs.exceptions import (
JobAbortedError,
JobError,
JobMissingError,
JobNotDoneError,
JobSchedulerError,
)
Expand All @@ -22,6 +23,7 @@
from servicelib.rabbitmq import RPCRouter

from ..errors import (
TaskNotFoundError,
TransferrableCeleryError,
decode_celery_transferrable_error,
)
Expand All @@ -30,7 +32,7 @@
router = RPCRouter()


@router.expose(reraise_if_error_type=(JobSchedulerError,))
@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
async def cancel(
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
):
Expand All @@ -42,11 +44,13 @@ async def cancel(
task_filter=task_filter,
task_uuid=job_id,
)
except TaskNotFoundError as exc:
raise JobMissingError(job_id=job_id) from exc
except CeleryError as exc:
raise JobSchedulerError(exc=f"{exc}") from exc


@router.expose(reraise_if_error_type=(JobSchedulerError,))
@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
async def status(
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
) -> AsyncJobStatus:
Expand All @@ -59,6 +63,8 @@ async def status(
task_filter=task_filter,
task_uuid=job_id,
)
except TaskNotFoundError as exc:
raise JobMissingError(job_id=job_id) from exc
except CeleryError as exc:
raise JobSchedulerError(exc=f"{exc}") from exc

Expand All @@ -71,9 +77,10 @@ async def status(

@router.expose(
reraise_if_error_type=(
JobAbortedError,
JobError,
JobMissingError,
JobNotDoneError,
JobAbortedError,
JobSchedulerError,
)
)
Expand All @@ -97,11 +104,11 @@ async def result(
task_filter=task_filter,
task_uuid=job_id,
)
except TaskNotFoundError as exc:
raise JobMissingError(job_id=job_id) from exc
except CeleryError as exc:
raise JobSchedulerError(exc=f"{exc}") from exc

if _status.task_state == TaskState.ABORTED:
raise JobAbortedError(job_id=job_id)
if _status.task_state == TaskState.FAILURE:
# fallback exception to report
exc_type = type(_result).__name__
Expand Down
43 changes: 19 additions & 24 deletions packages/celery-library/src/celery_library/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
from functools import wraps
from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload

from celery import Celery # type: ignore[import-untyped]
from celery.contrib.abortable import ( # type: ignore[import-untyped]
AbortableAsyncResult,
AbortableTask,
)
from celery import Celery, Task # type: ignore[import-untyped]
from celery.exceptions import Ignore # type: ignore[import-untyped]
from common_library.async_tools import cancel_wait_task
from pydantic import NonNegativeInt
Expand Down Expand Up @@ -39,14 +35,14 @@ class TaskAbortedError(Exception): ...
def _async_task_wrapper(
app: Celery,
) -> Callable[
[Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]]],
Callable[Concatenate[AbortableTask, P], R],
[Callable[Concatenate[Task, P], Coroutine[Any, Any, R]]],
Callable[Concatenate[Task, P], R],
]:
def decorator(
coro: Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[AbortableTask, P], R]:
coro: Callable[Concatenate[Task, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[Task, P], R]:
@wraps(coro)
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
app_server = get_app_server(app)
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
assert task.request.id is not None # nosec
Expand All @@ -59,14 +55,14 @@ async def run_task(task_id: TaskID) -> R:
)

async def abort_monitor():
abortable_result = AbortableAsyncResult(task_id, app=app)
while not main_task.done():
if abortable_result.is_aborted():
if not await app_server.task_manager.exists_task(
task_id
):
await cancel_wait_task(
main_task,
max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
)
AbortableAsyncResult(task_id, app=app).forget()
raise TaskAbortedError
await asyncio.sleep(
_DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()
Expand Down Expand Up @@ -102,14 +98,14 @@ def _error_handling(
delay_between_retries: timedelta,
dont_autoretry_for: tuple[type[Exception], ...],
) -> Callable[
[Callable[Concatenate[AbortableTask, P], R]],
Callable[Concatenate[AbortableTask, P], R],
[Callable[Concatenate[Task, P], R]],
Callable[Concatenate[Task, P], R],
]:
def decorator(
func: Callable[Concatenate[AbortableTask, P], R],
) -> Callable[Concatenate[AbortableTask, P], R]:
func: Callable[Concatenate[Task, P], R],
) -> Callable[Concatenate[Task, P], R]:
@wraps(func)
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
try:
return func(task, *args, **kwargs)
except TaskAbortedError as exc:
Expand Down Expand Up @@ -144,7 +140,7 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
@overload
def register_task(
app: Celery,
fn: Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]],
fn: Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]],
task_name: str | None = None,
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
Expand All @@ -156,7 +152,7 @@ def register_task(
@overload
def register_task(
app: Celery,
fn: Callable[Concatenate[AbortableTask, P], R],
fn: Callable[Concatenate[Task, P], R],
task_name: str | None = None,
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
Expand All @@ -168,8 +164,8 @@ def register_task(
def register_task( # type: ignore[misc]
app: Celery,
fn: (
Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]]
| Callable[Concatenate[AbortableTask, P], R]
Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]]
| Callable[Concatenate[Task, P], R]
),
task_name: str | None = None,
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
Expand All @@ -186,7 +182,7 @@ def register_task( # type: ignore[misc]
delay_between_retries -- delay between each attempt in case of error (default: {_DEFAULT_WAIT_BEFORE_RETRY})
dont_autoretry_for -- exceptions that should not be retried when raised by the task
"""
wrapped_fn: Callable[Concatenate[AbortableTask, P], R]
wrapped_fn: Callable[Concatenate[Task, P], R]
if asyncio.iscoroutinefunction(fn):
wrapped_fn = _async_task_wrapper(app)(fn)
else:
Expand All @@ -202,7 +198,6 @@ def register_task( # type: ignore[misc]
app.task(
name=task_name or fn.__name__,
bind=True,
base=AbortableTask,
time_limit=None if timeout is None else timeout.total_seconds(),
pydantic=True,
)(wrapped_fn)
41 changes: 22 additions & 19 deletions packages/celery-library/src/celery_library/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
from uuid import uuid4

from celery import Celery # type: ignore[import-untyped]
from celery.contrib.abortable import ( # type: ignore[import-untyped]
AbortableAsyncResult,
)
from common_library.async_tools import make_async
from models_library.progress_bar import ProgressReport
from servicelib.celery.models import (
TASK_DONE_STATES,
Task,
TaskFilter,
TaskID,
Expand All @@ -23,6 +21,7 @@
from servicelib.logging_utils import log_context
from settings_library.celery import CelerySettings

from .errors import TaskNotFoundError
from .utils import build_task_id

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,24 +68,25 @@ async def submit_task(
)
return task_uuid

@make_async()
def _abort_task(self, task_id: TaskID) -> None:
AbortableAsyncResult(task_id, app=self._celery_app).abort()

async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None:
with log_context(
_logger,
logging.DEBUG,
msg=f"task cancellation: {task_filter=} {task_uuid=}",
):
task_id = build_task_id(task_filter, task_uuid)
if not (await self.get_task_status(task_filter, task_uuid)).is_done:
await self._abort_task(task_id)
if not await self.exists_task(task_id):
raise TaskNotFoundError(task_id=task_id)

await self._task_info_store.remove_task(task_id)
await self._forget_task(task_id)

async def exists_task(self, task_id: TaskID) -> bool:
    """Return whether the task info store reports an entry for ``task_id``.

    Thin delegate: the answer comes straight from
    ``self._task_info_store.exists_task``; nothing is cached or modified here.
    """
    is_known = await self._task_info_store.exists_task(task_id)
    return is_known

@make_async()
def _forget_task(self, task_id: TaskID) -> None:
AbortableAsyncResult(task_id, app=self._celery_app).forget()
self._celery_app.AsyncResult(task_id).forget()

async def get_task_result(
self, task_filter: TaskFilter, task_uuid: TaskUUID
Expand All @@ -97,27 +97,27 @@ async def get_task_result(
msg=f"Get task result: {task_filter=} {task_uuid=}",
):
task_id = build_task_id(task_filter, task_uuid)
if not await self.exists_task(task_id):
raise TaskNotFoundError(task_id=task_id)

async_result = self._celery_app.AsyncResult(task_id)
result = async_result.result
if async_result.ready():
task_metadata = await self._task_info_store.get_task_metadata(task_id)
if task_metadata is not None and task_metadata.ephemeral:
await self._forget_task(task_id)
await self._task_info_store.remove_task(task_id)
await self._forget_task(task_id)
return result

async def _get_task_progress_report(
self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState
self, task_id: TaskID, task_state: TaskState
) -> ProgressReport:
if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
task_id = build_task_id(task_filter, task_uuid)
if task_state in (TaskState.STARTED, TaskState.RETRY):
progress = await self._task_info_store.get_task_progress(task_id)
if progress is not None:
return progress
if task_state in (
TaskState.SUCCESS,
TaskState.FAILURE,
):

if task_state in TASK_DONE_STATES:
return ProgressReport(
actual_value=_MAX_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE
)
Expand All @@ -140,12 +140,15 @@ async def get_task_status(
msg=f"Getting task status: {task_filter=} {task_uuid=}",
):
task_id = build_task_id(task_filter, task_uuid)
if not await self.exists_task(task_id):
raise TaskNotFoundError(task_id=task_id)

task_state = await self._get_task_celery_state(task_id)
return TaskStatus(
task_uuid=task_uuid,
task_state=task_state,
progress_report=await self._get_task_progress_report(
task_filter, task_uuid, task_state
task_id, task_state
),
)

Expand Down
18 changes: 10 additions & 8 deletions packages/celery-library/tests/unit/test_async_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
AsyncJobGet,
)
from models_library.api_schemas_rpc_async_jobs.exceptions import (
JobAbortedError,
JobError,
JobMissingError,
)
from models_library.products import ProductName
from models_library.rabbitmq_basic_types import RPCNamespace
Expand Down Expand Up @@ -308,12 +308,6 @@ async def test_async_jobs_cancel(
job_filter=job_filter,
)

await _wait_for_job(
async_jobs_rabbitmq_rpc_client,
async_job_get=async_job_get,
job_filter=job_filter,
)

jobs = await async_jobs.list_jobs(
async_jobs_rabbitmq_rpc_client,
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
Expand All @@ -322,7 +316,15 @@ async def test_async_jobs_cancel(
)
assert async_job_get.job_id not in [job.job_id for job in jobs]

with pytest.raises(JobAbortedError):
with pytest.raises(JobMissingError):
await async_jobs.status(
async_jobs_rabbitmq_rpc_client,
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
job_id=async_job_get.job_id,
job_filter=job_filter,
)

with pytest.raises(JobMissingError):
await async_jobs.result(
async_jobs_rabbitmq_rpc_client,
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
Expand Down
17 changes: 3 additions & 14 deletions packages/celery-library/tests/unit/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@

import pytest
from celery import Celery, Task # pylint: disable=no-name-in-module
from celery.contrib.abortable import AbortableTask # pylint: disable=no-name-in-module
from celery_library.errors import TransferrableCeleryError
from celery_library.errors import TaskNotFoundError, TransferrableCeleryError
from celery_library.task import register_task
from celery_library.task_manager import CeleryTaskManager
from celery_library.utils import get_app_server
Expand Down Expand Up @@ -72,7 +71,7 @@ def failure_task(task: Task, task_id: TaskID) -> None:
raise MyError(msg=msg)


async def dreamer_task(task: AbortableTask, task_id: TaskID) -> list[int]:
async def dreamer_task(task: Task, task_id: TaskID) -> list[int]:
numbers = []
for _ in range(30):
numbers.append(randint(1, 90)) # noqa: S311
Expand Down Expand Up @@ -164,18 +163,8 @@ async def test_cancelling_a_running_task_aborts_and_deletes(

await celery_task_manager.cancel_task(task_filter, task_uuid)

for attempt in Retrying(
retry=retry_if_exception_type(AssertionError),
wait=wait_fixed(1),
stop=stop_after_delay(30),
):
with attempt:
progress = await celery_task_manager.get_task_status(task_filter, task_uuid)
assert progress.task_state == TaskState.ABORTED

assert (
with pytest.raises(TaskNotFoundError):
await celery_task_manager.get_task_status(task_filter, task_uuid)
).task_state == TaskState.ABORTED

assert task_uuid not in await celery_task_manager.list_tasks(task_filter)

Expand Down
Loading
Loading