Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,35 @@
from asyncio import AbstractEventLoop

from fastapi import FastAPI
from servicelib.redis._client import RedisClientSDK
from settings_library.redis import RedisDatabase

from ..._meta import APP_NAME
from ...core.settings import get_application_settings
from ._celery_types import register_celery_types
from ._common import create_app
from .backends._redis import RedisTaskStore
from .client import CeleryTaskQueueClient

_logger = logging.getLogger(__name__)


def setup_celery_client(app: FastAPI) -> None:
async def on_startup() -> None:
celery_settings = get_application_settings(app).STORAGE_CELERY
application_settings = get_application_settings(app)
celery_settings = application_settings.STORAGE_CELERY
assert celery_settings # nosec
celery_app = create_app(celery_settings)
app.state.celery_client = CeleryTaskQueueClient(celery_app)
redis_client_sdk = RedisClientSDK(
celery_settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
RedisDatabase.CELERY_TASKS
),
client_name=f"{APP_NAME}.celery_tasks",
)

app.state.celery_client = CeleryTaskQueueClient(
celery_app, RedisTaskStore(redis_client_sdk)
)

register_celery_types()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Final

from servicelib.redis._client import RedisClientSDK

from ..models import TaskContext, TaskData, TaskID, TaskUUID, build_task_id_prefix

_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"


class RedisTaskStore:
    """Task-metadata store backed by Redis via a RedisClientSDK.

    Keys are full task ids (``<context-prefix>:<uuid>``); values are the
    JSON-serialized TaskData.
    """

    def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
        self._redis_client_sdk = redis_client_sdk

    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
        """Return the UUIDs of all stored tasks belonging to *task_context*."""
        prefix = build_task_id_prefix(task_context) + _CELERY_TASK_ID_KEY_SEPARATOR
        uuids: set[TaskUUID] = set()
        async for raw_key in self._redis_client_sdk.redis.scan_iter(
            match=prefix + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
        ):
            # NOTE: fake redis (used in tests) yields bytes, a real redis
            # client yields str — normalize before stripping the prefix
            key = (
                raw_key.decode(_CELERY_TASK_ID_KEY_ENCODING)
                if isinstance(raw_key, bytes)
                else raw_key
            )
            uuids.add(TaskUUID(key.removeprefix(prefix)))
        return uuids

    async def task_exists(self, task_id: TaskID) -> bool:
        """Return True if an entry exists for *task_id*."""
        num_found = await self._redis_client_sdk.redis.exists(task_id)
        assert isinstance(num_found, int)  # nosec
        return num_found > 0

    async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
        """Create or overwrite the entry for *task_id* with *task_data*."""
        await self._redis_client_sdk.redis.set(task_id, task_data.model_dump_json())
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
from pydantic import ValidationError
from servicelib.logging_utils import log_context

from ...exceptions.errors import ConfigurationError
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
from .models import (
TaskContext,
TaskData,
TaskState,
TaskStatus,
TaskStore,
TaskUUID,
build_task_id,
)

_logger = logging.getLogger(__name__)

_CELERY_INSPECT_TASK_STATUSES: Final[tuple[str, ...]] = (
"active",
"scheduled",
"revoked",
)
_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
_CELERY_STATES_MAPPING: Final[dict[str, TaskState]] = {
"PENDING": TaskState.PENDING,
"STARTED": TaskState.PENDING,
Expand All @@ -34,33 +35,17 @@
"FAILURE": TaskState.ERROR,
"ERROR": TaskState.ERROR,
}
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"

_MIN_PROGRESS_VALUE = 0.0
_MAX_PROGRESS_VALUE = 1.0


def _build_context_prefix(task_context: TaskContext) -> list[str]:
return [f"{task_context[key]}" for key in sorted(task_context)]


def _build_task_id_prefix(task_context: TaskContext) -> str:
return _CELERY_TASK_ID_KEY_SEPARATOR.join(_build_context_prefix(task_context))


def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
return _CELERY_TASK_ID_KEY_SEPARATOR.join(
[_build_task_id_prefix(task_context), f"{task_uuid}"]
)


@dataclass
class CeleryTaskQueueClient:
_celery_app: Celery
_task_store: TaskStore

@make_async()
def send_task(
async def send_task(
self, task_name: str, *, task_context: TaskContext, **task_params
) -> TaskUUID:
with log_context(
Expand All @@ -69,8 +54,11 @@ def send_task(
msg=f"Submit {task_name=}: {task_context=} {task_params=}",
):
task_uuid = uuid4()
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
await self._task_store.set_task(
task_id, TaskData(status=TaskState.PENDING.name)
)
return task_uuid

@staticmethod
Expand All @@ -79,25 +67,25 @@ def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
with log_context(
_logger,
logging.DEBUG,
msg=f"Abort task {task_uuid=}: {task_context=}",
msg=f"Abort task: {task_context=} {task_uuid=}",
):
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
AbortableAsyncResult(task_id).abort()

@make_async()
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
with log_context(
_logger,
logging.DEBUG,
msg=f"Get task {task_uuid=}: {task_context=} result",
msg=f"Get task result: {task_context=} {task_uuid=}",
):
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result

def _get_progress_report(
self, task_context: TaskContext, task_uuid: TaskUUID
) -> ProgressReport:
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
result = self._celery_app.AsyncResult(task_id).result
state = self._get_state(task_context, task_uuid)
if result and state == TaskState.RUNNING:
Expand All @@ -117,64 +105,28 @@ def _get_progress_report(
)

def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]

@make_async()
def get_task_status(
self, task_context: TaskContext, task_uuid: TaskUUID
) -> TaskStatus:
return TaskStatus(
task_uuid=task_uuid,
task_state=self._get_state(task_context, task_uuid),
progress_report=self._get_progress_report(task_context, task_uuid),
)

def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
backend_client = self._celery_app.backend.client
if hasattr(backend_client, "keys"):
if keys := backend_client.keys(f"{search_key}*"):
return {
TaskUUID(
f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
)
for key in keys
}
return set()
if hasattr(backend_client, "cache"):
# NOTE: backend used in testing. It is a dict-like object
found_keys = set()
for key in backend_client.cache:
str_key = key.decode(_CELERY_TASK_ID_KEY_ENCODING)
if str_key.startswith(search_key):
found_keys.add(
TaskUUID(
f"{str_key.removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
)
)
return found_keys
msg = f"Unsupported backend {self._celery_app.backend.__class__.__name__}"
raise ConfigurationError(msg=msg)

@make_async()
def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
task_uuids = self._get_completed_task_uuids(task_context)

task_id_prefix = _build_task_id_prefix(task_context)
inspect = self._celery_app.control.inspect()
for task_inspect_status in _CELERY_INSPECT_TASK_STATUSES:
tasks = getattr(inspect, task_inspect_status)() or {}

task_uuids.update(
TaskUUID(
task_info["id"].removeprefix(
task_id_prefix + _CELERY_TASK_ID_KEY_SEPARATOR
)
)
for tasks_per_worker in tasks.values()
for task_info in tasks_per_worker
if "id" in task_info
with log_context(
_logger,
logging.DEBUG,
msg=f"Getting task status: {task_context=} {task_uuid=}",
):
return TaskStatus(
task_uuid=task_uuid,
task_state=self._get_state(task_context, task_uuid),
progress_report=self._get_progress_report(task_context, task_uuid),
)

return task_uuids
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
    """Return the UUIDs of every known task for *task_context*.

    The lookup is delegated entirely to the configured task store.
    """
    debug_msg = f"Getting task uuids: {task_context=}"
    with log_context(_logger, logging.DEBUG, msg=debug_msg):
        return await self._task_store.get_task_uuids(task_context)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import StrEnum, auto
from typing import Any, Self, TypeAlias
from typing import Any, Final, Protocol, Self, TypeAlias
from uuid import UUID

from models_library.progress_bar import ProgressReport
Expand All @@ -9,6 +9,20 @@
TaskID: TypeAlias = str
TaskUUID: TypeAlias = UUID

_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"


def build_task_id_prefix(task_context: TaskContext) -> str:
    """Build the deterministic key prefix for *task_context*.

    Context values are joined with ':' in sorted-key order, so the same
    context always produces the same prefix.
    """
    ordered_values = [f"{task_context[key]}" for key in sorted(task_context)]
    return _CELERY_TASK_ID_KEY_SEPARATOR.join(ordered_values)


def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
    """Compose the full task id: ``<context-prefix>:<task-uuid>``."""
    prefix = build_task_id_prefix(task_context)
    return f"{prefix}{_CELERY_TASK_ID_KEY_SEPARATOR}{task_uuid}"


class TaskState(StrEnum):
PENDING = auto()
Expand All @@ -18,9 +32,21 @@ class TaskState(StrEnum):
ABORTED = auto()


class TaskData(BaseModel):
    """Metadata persisted in the task store for a submitted task."""

    # Name of a TaskState member (e.g. "PENDING"), stored as a plain string
    status: str


_TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}


class TaskStore(Protocol):
    """Structural (duck-typed) interface for a task-metadata store.

    Implementations persist per-task data keyed by TaskID; RedisTaskStore
    is one such implementation.
    """

    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
        """Return the UUIDs of all stored tasks matching *task_context*."""
        ...

    async def task_exists(self, task_id: TaskID) -> bool:
        """Return True if an entry exists for *task_id*."""
        ...

    async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
        """Create or overwrite the entry for *task_id*."""
        ...


class TaskStatus(BaseModel):
task_uuid: TaskUUID
task_state: TaskState
Expand Down
Loading