Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
from servicelib.logging_utils import log_context

from ...dsm import get_dsm_provider
from ...modules.celery.models import TaskId
from ...modules.celery.utils import get_fastapi_app

_logger = logging.getLogger(__name__)


async def compute_path_size(
task: Task, user_id: UserID, location_id: LocationID, path: Path
task: Task, task_id: TaskId, user_id: UserID, location_id: LocationID, path: Path
) -> ByteSize:
assert task_id # nosec
with log_context(
_logger,
logging.INFO,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import BaseModel

from ...models import FileMetaData
from ...modules.celery.models import TaskError


def _path_encoder(obj):
Expand Down Expand Up @@ -57,3 +58,4 @@ def register_celery_types() -> None:
_register_pydantic_types(FileUploadCompletionBody)
_register_pydantic_types(FileMetaData)
_register_pydantic_types(FoldersBody)
_register_pydantic_types(TaskError)
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from functools import wraps
from typing import Any, Concatenate, ParamSpec, TypeVar, overload

from celery import ( # type: ignore[import-untyped]
Celery,
Task,
)
from celery import Celery # type: ignore[import-untyped]
from celery.contrib.abortable import AbortableTask # type: ignore[import-untyped]
from celery.exceptions import Ignore # type: ignore[import-untyped]

Expand All @@ -22,7 +19,7 @@

def error_handling(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
def wrapper(task: AbortableTask, *args: Any, **kwargs: Any) -> Any:
try:
return func(task, *args, **kwargs)
except Exception as exc:
Expand All @@ -31,8 +28,9 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
exc_traceback = traceback.format_exc().split("\n")

_logger.exception(
"Task %s failed with exception: %s",
"Task %s failed with exception: %s:%s",
task.request.id,
exc_type,
exc_message,
)

Expand All @@ -57,14 +55,14 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
def _async_task_wrapper(
app: Celery,
) -> Callable[
[Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]],
Callable[Concatenate[Task, P], R],
[Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]]],
Callable[Concatenate[AbortableTask, P], R],
]:
def decorator(
coro: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[Task, P], R]:
coro: Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[AbortableTask, P], R]:
@wraps(coro)
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
fastapi_app = get_fastapi_app(app)
_logger.debug("task id: %s", task.request.id)
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
Expand All @@ -82,29 +80,29 @@ def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
@overload
def define_task(
app: Celery,
fn: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]],
fn: Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]],
task_name: str | None = None,
) -> None: ...


@overload
def define_task(
app: Celery,
fn: Callable[Concatenate[Task, P], R],
fn: Callable[Concatenate[AbortableTask, P], R],
task_name: str | None = None,
) -> None: ...


def define_task( # type: ignore[misc]
app: Celery,
fn: (
Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]
| Callable[Concatenate[Task, P], R]
Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]]
| Callable[Concatenate[AbortableTask, P], R]
),
task_name: str | None = None,
) -> None:
"""Decorator to define a celery task with error handling and abortable support"""
wrapped_fn: Callable[Concatenate[Task, P], R]
wrapped_fn: Callable[Concatenate[AbortableTask, P], R]
if asyncio.iscoroutinefunction(fn):
wrapped_fn = _async_task_wrapper(app)(fn)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import logging
from dataclasses import dataclass
from typing import Any, Final
from uuid import uuid4

Expand All @@ -12,6 +13,7 @@
from pydantic import ValidationError
from servicelib.logging_utils import log_context

from ...exceptions.errors import ConfigurationError
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,36 +55,44 @@ def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
)


@dataclass
class CeleryTaskQueueClient:
def __init__(self, celery_app: Celery):
self._celery_app = celery_app
_celery_app: Celery

@make_async()
def send_task(
self, task_name: str, *, task_context: TaskContext, **task_params
) -> TaskUUID:
task_uuid = uuid4()
task_id = _build_task_id(task_context, task_uuid)
with log_context(
_logger,
logging.DEBUG,
msg=f"Submitting task {task_name}: {task_id=} {task_params=}",
msg=f"Submit {task_name=}: {task_context=} {task_params=}",
):
task_uuid = uuid4()
task_id = _build_task_id(task_context, task_uuid)
self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
return task_uuid

@staticmethod
@make_async()
def abort_task( # pylint: disable=R6301
self, task_context: TaskContext, task_uuid: TaskUUID
) -> None:
task_id = _build_task_id(task_context, task_uuid)
_logger.info("Aborting task %s", task_id)
AbortableAsyncResult(task_id).abort()
def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
with log_context(
_logger,
logging.DEBUG,
msg=f"Abort task {task_uuid=}: {task_context=}",
):
task_id = _build_task_id(task_context, task_uuid)
AbortableAsyncResult(task_id).abort()

@make_async()
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
task_id = _build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result
with log_context(
_logger,
logging.DEBUG,
msg=f"Get task {task_uuid=}: {task_context=} result",
):
task_id = _build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result

def _get_progress_report(
self, task_context: TaskContext, task_uuid: TaskUUID
Expand Down Expand Up @@ -118,15 +128,30 @@ def get_task_status(

def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
redis = self._celery_app.backend.client
if hasattr(redis, "keys") and (keys := redis.keys(search_key + "*")):
backend_client = self._celery_app.backend.client
if hasattr(backend_client, "keys") and (
keys := backend_client.keys(f"{search_key}*")
):
return {
TaskUUID(
f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
)
for key in keys
}
return set()
if hasattr(backend_client, "cache"):
# NOTE: backend used in testing. It is a dict-like object
found_keys = set()
for key in backend_client.cache:
str_key = key.decode(_CELERY_TASK_ID_KEY_ENCODING)
if str_key.startswith(search_key):
found_keys.add(
TaskUUID(
f"{str_key.removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
)
)
return found_keys
msg = f"Unsupported backend {self._celery_app.backend.__class__.__name__}"
raise ConfigurationError(msg=msg)

@make_async()
def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
Expand Down
Loading