Merged
Changes from 2 commits
43 changes: 19 additions & 24 deletions packages/celery-library/src/celery_library/task.py
@@ -6,11 +6,7 @@
 from functools import wraps
 from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload
 
-from celery import Celery  # type: ignore[import-untyped]
-from celery.contrib.abortable import (  # type: ignore[import-untyped]
-    AbortableAsyncResult,
-    AbortableTask,
-)
+from celery import Celery, Task  # type: ignore[import-untyped]
 from celery.exceptions import Ignore  # type: ignore[import-untyped]
 from common_library.async_tools import cancel_wait_task
 from pydantic import NonNegativeInt
@@ -39,14 +35,14 @@ class TaskAbortedError(Exception): ...
 def _async_task_wrapper(
     app: Celery,
 ) -> Callable[
-    [Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]]],
-    Callable[Concatenate[AbortableTask, P], R],
+    [Callable[Concatenate[Task, P], Coroutine[Any, Any, R]]],
+    Callable[Concatenate[Task, P], R],
 ]:
     def decorator(
-        coro: Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]],
-    ) -> Callable[Concatenate[AbortableTask, P], R]:
+        coro: Callable[Concatenate[Task, P], Coroutine[Any, Any, R]],
+    ) -> Callable[Concatenate[Task, P], R]:
         @wraps(coro)
-        def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
+        def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
             app_server = get_app_server(app)
             # NOTE: task.request is a thread local object, so we need to pass the id explicitly
             assert task.request.id is not None  # nosec
@@ -59,14 +55,14 @@ async def run_task(task_id: TaskID) -> R:
                         )
 
                         async def abort_monitor():
-                            abortable_result = AbortableAsyncResult(task_id, app=app)
                             while not main_task.done():
-                                if abortable_result.is_aborted():
+                                if not await app_server.task_manager.exists_task(
+                                    task_id
+                                ):
                                     await cancel_wait_task(
                                         main_task,
                                         max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
                                     )
-                                    AbortableAsyncResult(task_id, app=app).forget()
                                     raise TaskAbortedError
                                 await asyncio.sleep(
                                     _DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()
@@ -102,14 +98,14 @@ def _error_handling(
     delay_between_retries: timedelta,
     dont_autoretry_for: tuple[type[Exception], ...],
 ) -> Callable[
-    [Callable[Concatenate[AbortableTask, P], R]],
-    Callable[Concatenate[AbortableTask, P], R],
+    [Callable[Concatenate[Task, P], R]],
+    Callable[Concatenate[Task, P], R],
 ]:
     def decorator(
-        func: Callable[Concatenate[AbortableTask, P], R],
-    ) -> Callable[Concatenate[AbortableTask, P], R]:
+        func: Callable[Concatenate[Task, P], R],
+    ) -> Callable[Concatenate[Task, P], R]:
         @wraps(func)
-        def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
+        def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
             try:
                 return func(task, *args, **kwargs)
             except TaskAbortedError as exc:
@@ -144,7 +140,7 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
 @overload
 def register_task(
     app: Celery,
-    fn: Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]],
+    fn: Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]],
     task_name: str | None = None,
     timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
     max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
@@ -156,7 +152,7 @@ def register_task(
 @overload
 def register_task(
     app: Celery,
-    fn: Callable[Concatenate[AbortableTask, P], R],
+    fn: Callable[Concatenate[Task, P], R],
     task_name: str | None = None,
     timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
     max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
@@ -168,8 +164,8 @@ def register_task(  # type: ignore[misc]
 def register_task(  # type: ignore[misc]
     app: Celery,
     fn: (
-        Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]]
-        | Callable[Concatenate[AbortableTask, P], R]
+        Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]]
+        | Callable[Concatenate[Task, P], R]
     ),
     task_name: str | None = None,
     timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
@@ -186,7 +182,7 @@ def register_task(  # type: ignore[misc]
         delay_between_retries -- dealy between each attempt in case of error (default: {_DEFAULT_WAIT_BEFORE_RETRY})
         dont_autoretry_for -- exceptions that should not be retried when raised by the task
     """
-    wrapped_fn: Callable[Concatenate[AbortableTask, P], R]
+    wrapped_fn: Callable[Concatenate[Task, P], R]
     if asyncio.iscoroutinefunction(fn):
         wrapped_fn = _async_task_wrapper(app)(fn)
     else:
@@ -202,7 +198,6 @@ def register_task(  # type: ignore[misc]
     app.task(
         name=task_name or fn.__name__,
         bind=True,
-        base=AbortableTask,
         time_limit=None if timeout is None else timeout.total_seconds(),
         pydantic=True,
     )(wrapped_fn)
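
To make the new abort flow in _async_task_wrapper easier to follow, here is a minimal standalone sketch of the same idea: the worker-side monitor no longer asks Celery's AbortableAsyncResult whether the task was aborted; it polls whether the task still exists in the task info store and cancels the running coroutine when the entry disappears. Names such as InMemoryTaskStore and run_with_abort_monitor are hypothetical stand-ins, not part of the library.

import asyncio

_POLL_INTERVAL_S = 0.1  # stand-in for _DEFAULT_ABORT_TASK_TIMEOUT


class InMemoryTaskStore:
    """Hypothetical stand-in for the task info store queried via exists_task()."""

    def __init__(self) -> None:
        self._tasks: set[str] = set()

    def add(self, task_id: str) -> None:
        self._tasks.add(task_id)

    def remove(self, task_id: str) -> None:
        self._tasks.discard(task_id)

    async def exists_task(self, task_id: str) -> bool:
        return task_id in self._tasks


async def run_with_abort_monitor(store: InMemoryTaskStore, task_id: str) -> None:
    main_task = asyncio.create_task(asyncio.sleep(10))  # long-running task body

    async def abort_monitor() -> None:
        # Abort is detected by the task entry disappearing from the store,
        # not by AbortableAsyncResult.is_aborted() as before this PR.
        while not main_task.done():
            if not await store.exists_task(task_id):
                main_task.cancel()
                return
            await asyncio.sleep(_POLL_INTERVAL_S)

    monitor = asyncio.create_task(abort_monitor())
    try:
        await main_task
    except asyncio.CancelledError:
        print(f"{task_id} was aborted")
    finally:
        monitor.cancel()


async def main() -> None:
    store = InMemoryTaskStore()
    store.add("task-1")
    runner = asyncio.create_task(run_with_abort_monitor(store, "task-1"))
    await asyncio.sleep(0.3)
    store.remove("task-1")  # this is effectively what cancel_task() now does
    await runner


if __name__ == "__main__":
    asyncio.run(main())
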
29 changes: 13 additions & 16 deletions packages/celery-library/src/celery_library/task_manager.py
@@ -4,12 +4,10 @@
 from uuid import uuid4
 
 from celery import Celery  # type: ignore[import-untyped]
-from celery.contrib.abortable import (  # type: ignore[import-untyped]
-    AbortableAsyncResult,
-)
 from common_library.async_tools import make_async
 from models_library.progress_bar import ProgressReport
 from servicelib.celery.models import (
+    TASK_FINAL_STATES,
     Task,
     TaskFilter,
     TaskID,
@@ -69,24 +67,22 @@ async def submit_task(
             )
             return task_uuid
 
-    @make_async()
-    def _abort_task(self, task_id: TaskID) -> None:
-        AbortableAsyncResult(task_id, app=self._celery_app).abort()
-
     async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None:
         with log_context(
             _logger,
             logging.DEBUG,
             msg=f"task cancellation: {task_filter=} {task_uuid=}",
         ):
             task_id = build_task_id(task_filter, task_uuid)
-            if not (await self.get_task_status(task_filter, task_uuid)).is_done:
-                await self._abort_task(task_id)
+            await self._forget_task(task_id)
             await self._task_info_store.remove_task(task_id)
 
+    async def exists_task(self, task_id: TaskID) -> bool:
+        return await self._task_info_store.exists_task(task_id)
+
     @make_async()
     def _forget_task(self, task_id: TaskID) -> None:
-        AbortableAsyncResult(task_id, app=self._celery_app).forget()
+        self._celery_app.AsyncResult(task_id).forget()
 
     async def get_task_result(
         self, task_filter: TaskFilter, task_uuid: TaskUUID
@@ -109,15 +105,13 @@ async def get_task_result(
     async def _get_task_progress_report(
         self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState
     ) -> ProgressReport:
-        if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
+        if task_state in (TaskState.STARTED, TaskState.RETRY):
             task_id = build_task_id(task_filter, task_uuid)
             progress = await self._task_info_store.get_task_progress(task_id)
             if progress is not None:
                 return progress
-        if task_state in (
-            TaskState.SUCCESS,
-            TaskState.FAILURE,
-        ):
+
+        if task_state in TASK_FINAL_STATES:
             return ProgressReport(
                 actual_value=_MAX_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE
             )
@@ -140,7 +134,10 @@ async def get_task_status(
             msg=f"Getting task status: {task_filter=} {task_uuid=}",
         ):
             task_id = build_task_id(task_filter, task_uuid)
-            task_state = await self._get_task_celery_state(task_id)
+            if not await self.exists_task(task_id):
+                task_state = TaskState.ABORTED
+            else:
+                task_state = await self._get_task_celery_state(task_id)
             return TaskStatus(
                 task_uuid=task_uuid,
                 task_state=task_state,
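
Seen from the client side, cancellation now only forgets the Celery result and removes the entry from the task info store; get_task_status then reports ABORTED because the entry no longer exists. A hedged usage sketch follows (how the CeleryTaskManager instance and the filter/UUID values are obtained is assumed, and the import locations are assumed to be the servicelib.celery.models module shown further down in this diff):

from servicelib.celery.models import TaskFilter, TaskState, TaskUUID


async def cancel_and_verify(
    task_manager,  # a CeleryTaskManager instance, obtained elsewhere
    task_filter: TaskFilter,
    task_uuid: TaskUUID,
) -> None:
    await task_manager.cancel_task(task_filter, task_uuid)

    # With the entry gone from the task info store, get_task_status no longer
    # queries Celery and reports the task as ABORTED, which is a final state.
    status = await task_manager.get_task_status(task_filter, task_uuid)
    assert status.task_state == TaskState.ABORTED
    assert status.is_done
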
3 changes: 1 addition & 2 deletions packages/celery-library/tests/unit/test_tasks.py
@@ -12,7 +12,6 @@
 
 import pytest
 from celery import Celery, Task  # pylint: disable=no-name-in-module
-from celery.contrib.abortable import AbortableTask  # pylint: disable=no-name-in-module
 from celery_library.errors import TransferrableCeleryError
 from celery_library.task import register_task
 from celery_library.task_manager import CeleryTaskManager
@@ -72,7 +71,7 @@ def failure_task(task: Task, task_id: TaskID) -> None:
     raise MyError(msg=msg)
 
 
-async def dreamer_task(task: AbortableTask, task_id: TaskID) -> list[int]:
+async def dreamer_task(task: Task, task_id: TaskID) -> list[int]:
     numbers = []
     for _ in range(30):
         numbers.append(randint(1, 90))  # noqa: S311
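
Registration itself is unchanged apart from the handler's type hint. A hedged sketch of how such a coroutine handler might be wired up after this change (the handler name, broker and backend URLs are illustrative only):

from celery import Celery, Task  # pylint: disable=no-name-in-module
from celery_library.task import register_task
from servicelib.celery.models import TaskID


async def example_task(task: Task, task_id: TaskID) -> list[int]:
    # a long-running handler body would go here
    return []


app = Celery("tests", broker="memory://", backend="cache+memory://")
register_task(app, example_task)  # bound task, registered without the AbortableTask base
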
8 changes: 4 additions & 4 deletions packages/service-library/src/servicelib/celery/models.py
@@ -26,6 +26,9 @@ class TaskState(StrEnum):
     ABORTED = "ABORTED"
 
 
+TASK_FINAL_STATES = {TaskState.SUCCESS, TaskState.FAILURE, TaskState.ABORTED}
+
+
 class TasksQueue(StrEnum):
     CPU_BOUND = "cpu_bound"
     DEFAULT = "default"
@@ -78,9 +81,6 @@ def _update_json_schema_extra(schema: JsonDict) -> None:
     model_config = ConfigDict(json_schema_extra=_update_json_schema_extra)
 
 
-_TASK_DONE = {TaskState.SUCCESS, TaskState.FAILURE, TaskState.ABORTED}
-
-
 class TaskInfoStore(Protocol):
     async def create_task(
         self,
@@ -138,4 +138,4 @@ def _update_json_schema_extra(schema: JsonDict) -> None:
 
     @property
     def is_done(self) -> bool:
-        return self.task_state in _TASK_DONE
+        return self.task_state in TASK_FINAL_STATES
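
The net effect on the models is that the set of terminal states is now a public, shared constant instead of the private _TASK_DONE. Roughly, assuming the definitions above:

from servicelib.celery.models import TASK_FINAL_STATES, TaskState

# TaskStatus.is_done is now defined against the shared constant:
assert TaskState.SUCCESS in TASK_FINAL_STATES
assert TaskState.ABORTED in TASK_FINAL_STATES
assert TaskState.STARTED not in TASK_FINAL_STATES
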
2 changes: 2 additions & 0 deletions
@@ -21,6 +21,8 @@ async def cancel_task(
         self, task_filter: TaskFilter, task_uuid: TaskUUID
     ) -> None: ...
 
+    async def exists_task(self, task_id: TaskID) -> bool: ...
+
     async def get_task_result(
         self, task_filter: TaskFilter, task_uuid: TaskUUID
     ) -> Any: ...
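
Implementations and test doubles of this protocol therefore need the new method as well. A minimal hypothetical sketch of just that method (the remaining protocol methods are omitted, so this double only covers the exists_task part of the interface):

from servicelib.celery.models import TaskID


class _FakeTaskManager:
    """Hypothetical in-memory double; only the newly added method is sketched."""

    def __init__(self) -> None:
        self._known_task_ids: set[TaskID] = set()

    async def exists_task(self, task_id: TaskID) -> bool:
        return task_id in self._known_task_ids
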