Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
_CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"

Expand Down
6 changes: 6 additions & 0 deletions packages/celery-library/src/celery_library/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
import pickle

from common_library.errors_classes import OsparcErrorMixin


class TransferrableCeleryError(Exception):
def __repr__(self) -> str:
Expand All @@ -22,3 +24,7 @@ def decode_celery_transferrable_error(error: TransferrableCeleryError) -> Except
assert isinstance(error, TransferrableCeleryError) # nosec
result: Exception = pickle.loads(base64.b64decode(error.args[0])) # noqa: S301
return result


# Error signalling that a task id is unknown to the task manager.
#
# In this change set it is raised by the task manager's cancel_task,
# get_task_result and get_task_status when `exists_task(task_id)` is false,
# always as `TaskNotFoundError(task_id=task_id)`. The async-jobs RPC handlers
# catch it and re-raise it as `JobMissingError(job_id=job_id)`.
class TaskNotFoundError(OsparcErrorMixin, Exception):
# `{task_id}` is filled from the `task_id=` keyword given at raise time
# (formatting is presumably done by OsparcErrorMixin — TODO confirm).
msg_template = "Task with id '{task_id}' not found"
17 changes: 12 additions & 5 deletions packages/celery-library/src/celery_library/rpc/_async_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from models_library.api_schemas_rpc_async_jobs.exceptions import (
JobAbortedError,
JobError,
JobMissingError,
JobNotDoneError,
JobSchedulerError,
)
Expand All @@ -22,6 +23,7 @@
from servicelib.rabbitmq import RPCRouter

from ..errors import (
TaskNotFoundError,
TransferrableCeleryError,
decode_celery_transferrable_error,
)
Expand All @@ -30,7 +32,7 @@
router = RPCRouter()


@router.expose(reraise_if_error_type=(JobSchedulerError,))
@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
async def cancel(
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
):
Expand All @@ -42,11 +44,13 @@ async def cancel(
task_filter=task_filter,
task_uuid=job_id,
)
except TaskNotFoundError as exc:
raise JobMissingError(job_id=job_id) from exc
except CeleryError as exc:
raise JobSchedulerError(exc=f"{exc}") from exc


@router.expose(reraise_if_error_type=(JobSchedulerError,))
@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
async def status(
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
) -> AsyncJobStatus:
Expand All @@ -59,6 +63,8 @@ async def status(
task_filter=task_filter,
task_uuid=job_id,
)
except TaskNotFoundError as exc:
raise JobMissingError(job_id=job_id) from exc
except CeleryError as exc:
raise JobSchedulerError(exc=f"{exc}") from exc

Expand All @@ -71,9 +77,10 @@ async def status(

@router.expose(
reraise_if_error_type=(
JobAbortedError,
JobError,
JobMissingError,
JobNotDoneError,
JobAbortedError,
JobSchedulerError,
)
)
Expand All @@ -97,11 +104,11 @@ async def result(
task_filter=task_filter,
task_uuid=job_id,
)
except TaskNotFoundError as exc:
raise JobMissingError(job_id=job_id) from exc
except CeleryError as exc:
raise JobSchedulerError(exc=f"{exc}") from exc

if _status.task_state == TaskState.ABORTED:
raise JobAbortedError(job_id=job_id)
if _status.task_state == TaskState.FAILURE:
# fallback exception to report
exc_type = type(_result).__name__
Expand Down
43 changes: 19 additions & 24 deletions packages/celery-library/src/celery_library/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
from functools import wraps
from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload

from celery import Celery # type: ignore[import-untyped]
from celery.contrib.abortable import ( # type: ignore[import-untyped]
AbortableAsyncResult,
AbortableTask,
)
from celery import Celery, Task # type: ignore[import-untyped]
from celery.exceptions import Ignore # type: ignore[import-untyped]
from common_library.async_tools import cancel_wait_task
from pydantic import NonNegativeInt
Expand Down Expand Up @@ -39,14 +35,14 @@ class TaskAbortedError(Exception): ...
def _async_task_wrapper(
app: Celery,
) -> Callable[
[Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]]],
Callable[Concatenate[AbortableTask, P], R],
[Callable[Concatenate[Task, P], Coroutine[Any, Any, R]]],
Callable[Concatenate[Task, P], R],
]:
def decorator(
coro: Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[AbortableTask, P], R]:
coro: Callable[Concatenate[Task, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[Task, P], R]:
@wraps(coro)
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
app_server = get_app_server(app)
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
assert task.request.id is not None # nosec
Expand All @@ -59,14 +55,14 @@ async def run_task(task_id: TaskID) -> R:
)

async def abort_monitor():
abortable_result = AbortableAsyncResult(task_id, app=app)
while not main_task.done():
if abortable_result.is_aborted():
if not await app_server.task_manager.exists_task(
task_id
):
await cancel_wait_task(
main_task,
max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
)
AbortableAsyncResult(task_id, app=app).forget()
raise TaskAbortedError
await asyncio.sleep(
_DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()
Expand Down Expand Up @@ -102,14 +98,14 @@ def _error_handling(
delay_between_retries: timedelta,
dont_autoretry_for: tuple[type[Exception], ...],
) -> Callable[
[Callable[Concatenate[AbortableTask, P], R]],
Callable[Concatenate[AbortableTask, P], R],
[Callable[Concatenate[Task, P], R]],
Callable[Concatenate[Task, P], R],
]:
def decorator(
func: Callable[Concatenate[AbortableTask, P], R],
) -> Callable[Concatenate[AbortableTask, P], R]:
func: Callable[Concatenate[Task, P], R],
) -> Callable[Concatenate[Task, P], R]:
@wraps(func)
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
try:
return func(task, *args, **kwargs)
except TaskAbortedError as exc:
Expand Down Expand Up @@ -144,7 +140,7 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
@overload
def register_task(
app: Celery,
fn: Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]],
fn: Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]],
task_name: str | None = None,
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
Expand All @@ -156,7 +152,7 @@ def register_task(
@overload
def register_task(
app: Celery,
fn: Callable[Concatenate[AbortableTask, P], R],
fn: Callable[Concatenate[Task, P], R],
task_name: str | None = None,
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
Expand All @@ -168,8 +164,8 @@ def register_task(
def register_task( # type: ignore[misc]
app: Celery,
fn: (
Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]]
| Callable[Concatenate[AbortableTask, P], R]
Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]]
| Callable[Concatenate[Task, P], R]
),
task_name: str | None = None,
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
Expand All @@ -186,7 +182,7 @@ def register_task( # type: ignore[misc]
delay_between_retries -- delay between each attempt in case of error (default: {_DEFAULT_WAIT_BEFORE_RETRY})
dont_autoretry_for -- exceptions that should not be retried when raised by the task
"""
wrapped_fn: Callable[Concatenate[AbortableTask, P], R]
wrapped_fn: Callable[Concatenate[Task, P], R]
if asyncio.iscoroutinefunction(fn):
wrapped_fn = _async_task_wrapper(app)(fn)
else:
Expand All @@ -202,7 +198,6 @@ def register_task( # type: ignore[misc]
app.task(
name=task_name or fn.__name__,
bind=True,
base=AbortableTask,
time_limit=None if timeout is None else timeout.total_seconds(),
pydantic=True,
)(wrapped_fn)
41 changes: 22 additions & 19 deletions packages/celery-library/src/celery_library/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
from uuid import uuid4

from celery import Celery # type: ignore[import-untyped]
from celery.contrib.abortable import ( # type: ignore[import-untyped]
AbortableAsyncResult,
)
from common_library.async_tools import make_async
from models_library.progress_bar import ProgressReport
from servicelib.celery.models import (
TASK_DONE_STATES,
Task,
TaskFilter,
TaskID,
Expand All @@ -23,6 +21,7 @@
from servicelib.logging_utils import log_context
from settings_library.celery import CelerySettings

from .errors import TaskNotFoundError
from .utils import build_task_id

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,24 +68,25 @@ async def submit_task(
)
return task_uuid

@make_async()
def _abort_task(self, task_id: TaskID) -> None:
AbortableAsyncResult(task_id, app=self._celery_app).abort()

async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None:
with log_context(
_logger,
logging.DEBUG,
msg=f"task cancellation: {task_filter=} {task_uuid=}",
):
task_id = build_task_id(task_filter, task_uuid)
if not (await self.get_task_status(task_filter, task_uuid)).is_done:
await self._abort_task(task_id)
if not await self.exists_task(task_id):
raise TaskNotFoundError(task_id=task_id)

await self._task_info_store.remove_task(task_id)
await self._forget_task(task_id)

# Report whether `task_id` is currently known to the task info store.
#
# Pure delegation to `self._task_info_store.exists_task`. The manager uses it
# as the guard before raising TaskNotFoundError in cancel_task,
# get_task_result and get_task_status. The worker-side abort monitor also
# polls it (`app_server.task_manager.exists_task(task_id)`) and cancels the
# running coroutine once it returns False, i.e. after cancel_task has removed
# the task from the store.
async def exists_task(self, task_id: TaskID) -> bool:
return await self._task_info_store.exists_task(task_id)

@make_async()
def _forget_task(self, task_id: TaskID) -> None:
AbortableAsyncResult(task_id, app=self._celery_app).forget()
self._celery_app.AsyncResult(task_id).forget()

async def get_task_result(
self, task_filter: TaskFilter, task_uuid: TaskUUID
Expand All @@ -97,27 +97,27 @@ async def get_task_result(
msg=f"Get task result: {task_filter=} {task_uuid=}",
):
task_id = build_task_id(task_filter, task_uuid)
if not await self.exists_task(task_id):
raise TaskNotFoundError(task_id=task_id)

async_result = self._celery_app.AsyncResult(task_id)
result = async_result.result
if async_result.ready():
task_metadata = await self._task_info_store.get_task_metadata(task_id)
if task_metadata is not None and task_metadata.ephemeral:
await self._forget_task(task_id)
await self._task_info_store.remove_task(task_id)
await self._forget_task(task_id)
return result

async def _get_task_progress_report(
self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState
self, task_id: TaskID, task_state: TaskState
) -> ProgressReport:
if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
task_id = build_task_id(task_filter, task_uuid)
if task_state in (TaskState.STARTED, TaskState.RETRY):
progress = await self._task_info_store.get_task_progress(task_id)
if progress is not None:
return progress
if task_state in (
TaskState.SUCCESS,
TaskState.FAILURE,
):

if task_state in TASK_DONE_STATES:
return ProgressReport(
actual_value=_MAX_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE
)
Expand All @@ -140,12 +140,15 @@ async def get_task_status(
msg=f"Getting task status: {task_filter=} {task_uuid=}",
):
task_id = build_task_id(task_filter, task_uuid)
if not await self.exists_task(task_id):
raise TaskNotFoundError(task_id=task_id)

task_state = await self._get_task_celery_state(task_id)
return TaskStatus(
task_uuid=task_uuid,
task_state=task_state,
progress_report=await self._get_task_progress_report(
task_filter, task_uuid, task_state
task_id, task_state
),
)

Expand Down
18 changes: 10 additions & 8 deletions packages/celery-library/tests/unit/test_async_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
AsyncJobGet,
)
from models_library.api_schemas_rpc_async_jobs.exceptions import (
JobAbortedError,
JobError,
JobMissingError,
)
from models_library.products import ProductName
from models_library.rabbitmq_basic_types import RPCNamespace
Expand Down Expand Up @@ -308,12 +308,6 @@ async def test_async_jobs_cancel(
job_filter=job_filter,
)

await _wait_for_job(
async_jobs_rabbitmq_rpc_client,
async_job_get=async_job_get,
job_filter=job_filter,
)

jobs = await async_jobs.list_jobs(
async_jobs_rabbitmq_rpc_client,
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
Expand All @@ -322,7 +316,15 @@ async def test_async_jobs_cancel(
)
assert async_job_get.job_id not in [job.job_id for job in jobs]

with pytest.raises(JobAbortedError):
with pytest.raises(JobMissingError):
await async_jobs.status(
async_jobs_rabbitmq_rpc_client,
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
job_id=async_job_get.job_id,
job_filter=job_filter,
)

with pytest.raises(JobMissingError):
await async_jobs.result(
async_jobs_rabbitmq_rpc_client,
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
Expand Down
Loading
Loading