Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from asyncio import AbstractEventLoop

from fastapi import FastAPI
from servicelib.redis._client import RedisClientSDK
from settings_library.redis import RedisDatabase
from simcore_service_storage.modules.celery.backends._redis import RedisTaskStore

from ..._meta import APP_NAME
from ...core.settings import get_application_settings
from ._celery_types import register_celery_types
from ._common import create_app
Expand All @@ -13,10 +17,20 @@

def setup_celery_client(app: FastAPI) -> None:
async def on_startup() -> None:
celery_settings = get_application_settings(app).STORAGE_CELERY
application_settings = get_application_settings(app)
celery_settings = application_settings.STORAGE_CELERY
assert celery_settings # nosec
celery_app = create_app(celery_settings)
app.state.celery_client = CeleryTaskQueueClient(celery_app)
redis_client_sdk = RedisClientSDK(
celery_settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
RedisDatabase.CELERY_TASKS
),
client_name=f"{APP_NAME}.celery_tasks",
)

app.state.celery_client = CeleryTaskQueueClient(
celery_app, RedisTaskStore(redis_client_sdk)
)

register_celery_types()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Final

from servicelib.redis._client import RedisClientSDK

from ..models import TaskContext, TaskData, TaskID, TaskUUID, build_task_id_prefix

_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000


class RedisTaskStore:
    """Task store backed by Redis.

    Every operation goes through the ``redis`` attribute of the injected
    ``RedisClientSDK``.

    NOTE(review): ``set_task`` and ``task_exists`` use the bare task id as the
    Redis key, whereas ``get_task_uuids`` scans the ``celery-task-meta-``
    namespace. Confirm whether entries written by ``set_task`` are expected to
    be listed by ``get_task_uuids``.
    """

    def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
        self._redis_client_sdk = redis_client_sdk

    @staticmethod
    def _decode_key(key: bytes | str) -> str:
        """Return ``key`` as text.

        A Redis client that is not configured to decode responses yields
        ``bytes`` keys; formatting those with an f-string would produce
        ``"b'...'"`` and break the prefix stripping below.
        """
        if isinstance(key, bytes):
            return key.decode("utf-8")
        return f"{key}"

    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
        """Return the UUIDs of all tasks found for ``task_context``.

        Keys are matched against ``celery-task-meta-<context prefix>:*`` and
        whatever follows the separator is parsed as the task UUID.
        """
        # The trailing separator is part of the search key so that only keys of
        # exactly this context are matched and the remainder is the bare UUID.
        search_key = (
            _CELERY_TASK_META_PREFIX
            + build_task_id_prefix(task_context)
            + _CELERY_TASK_ID_KEY_SEPARATOR
        )
        return {
            TaskUUID(self._decode_key(key).removeprefix(search_key))
            async for key in self._redis_client_sdk.redis.scan_iter(
                match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
            )
        }

    async def task_exists(self, task_id: TaskID) -> bool:
        """Return ``True`` if a key named ``task_id`` exists in Redis."""
        n = await self._redis_client_sdk.redis.exists(task_id)
        assert isinstance(n, int)  # nosec
        return n > 0

    async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
        """Store ``task_data`` serialized as JSON under the key ``task_id``."""
        await self._redis_client_sdk.redis.set(
            task_id,
            task_data.model_dump_json(),
        )
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
from pydantic import ValidationError
from servicelib.logging_utils import log_context

from ...exceptions.errors import ConfigurationError
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
from .models import (
TaskContext,
TaskData,
TaskState,
TaskStatus,
TaskStore,
TaskUUID,
build_task_id,
)

_logger = logging.getLogger(__name__)

_CELERY_INSPECT_TASK_STATUSES: Final[tuple[str, ...]] = (
"active",
"scheduled",
"revoked",
)
_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
_CELERY_STATES_MAPPING: Final[dict[str, TaskState]] = {
"PENDING": TaskState.PENDING,
"STARTED": TaskState.PENDING,
Expand All @@ -34,33 +35,17 @@
"FAILURE": TaskState.ERROR,
"ERROR": TaskState.ERROR,
}
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"

_MIN_PROGRESS_VALUE = 0.0
_MAX_PROGRESS_VALUE = 1.0


def _build_context_prefix(task_context: TaskContext) -> list[str]:
    """Return the context values as strings, ordered by their sorted keys."""
    parts: list[str] = []
    for name in sorted(task_context):
        parts.append(f"{task_context[name]}")
    return parts


def _build_task_id_prefix(task_context: TaskContext) -> str:
    """Return the context part of a task id: the context values joined by the separator."""
    context_parts = _build_context_prefix(task_context)
    return _CELERY_TASK_ID_KEY_SEPARATOR.join(context_parts)


def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
    """Return the full task id: the context prefix, the separator, then the task UUID."""
    prefix = _build_task_id_prefix(task_context)
    return _CELERY_TASK_ID_KEY_SEPARATOR.join((prefix, f"{task_uuid}"))


@dataclass
class CeleryTaskQueueClient:
_celery_app: Celery
_task_store: TaskStore

@make_async()
def send_task(
async def send_task(
self, task_name: str, *, task_context: TaskContext, **task_params
) -> TaskUUID:
with log_context(
Expand All @@ -69,35 +54,29 @@ def send_task(
msg=f"Submit {task_name=}: {task_context=} {task_params=}",
):
task_uuid = uuid4()
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
await self._task_store.set_task(
task_id, TaskData(status=TaskState.PENDING.name)
)
return task_uuid

@staticmethod
@make_async()
def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
with log_context(
_logger,
logging.DEBUG,
msg=f"Abort task {task_uuid=}: {task_context=}",
):
task_id = _build_task_id(task_context, task_uuid)
AbortableAsyncResult(task_id).abort()
task_id = build_task_id(task_context, task_uuid)
_logger.info("Aborting task %s", task_id)
AbortableAsyncResult(task_id).abort()

@make_async()
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
with log_context(
_logger,
logging.DEBUG,
msg=f"Get task {task_uuid=}: {task_context=} result",
):
task_id = _build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result
task_id = build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result

def _get_progress_report(
self, task_context: TaskContext, task_uuid: TaskUUID
) -> ProgressReport:
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
result = self._celery_app.AsyncResult(task_id).result
state = self._get_state(task_context, task_uuid)
if result and state == TaskState.RUNNING:
Expand All @@ -117,7 +96,7 @@ def _get_progress_report(
)

def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]

@make_async()
Expand All @@ -130,51 +109,5 @@ def get_task_status(
progress_report=self._get_progress_report(task_context, task_uuid),
)

def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
    """Return the UUIDs of the tasks of ``task_context`` that have a key in the result backend.

    Two kinds of backend client are supported: one exposing ``keys`` (queried
    with a glob pattern) and a dict-like one exposing ``cache``.

    Raises:
        ConfigurationError: if the backend client exposes neither ``keys``
            nor ``cache``.
    """
    # The separator belongs to the search key: without it a context whose
    # prefix is a leading substring of another one (e.g. ids 1 and 12) would
    # match foreign keys, whose remainder is then not a valid UUID.
    search_key = (
        _CELERY_TASK_META_PREFIX
        + _build_task_id_prefix(task_context)
        + _CELERY_TASK_ID_KEY_SEPARATOR
    )
    backend_client = self._celery_app.backend.client

    if hasattr(backend_client, "keys"):
        raw_keys = backend_client.keys(f"{search_key}*") or ()
    elif hasattr(backend_client, "cache"):
        # NOTE: backend used in testing. It is a dict-like object
        raw_keys = backend_client.cache
    else:
        msg = f"Unsupported backend {self._celery_app.backend.__class__.__name__}"
        raise ConfigurationError(msg=msg)

    task_uuids: set[TaskUUID] = set()
    for raw_key in raw_keys:
        str_key = raw_key.decode(_CELERY_TASK_ID_KEY_ENCODING)
        # Filter here as well: the dict-like backend is not pre-filtered.
        if str_key.startswith(search_key):
            task_uuids.add(TaskUUID(str_key.removeprefix(search_key)))
    return task_uuids

@make_async()
def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
    """Return the UUIDs of the tasks of ``task_context``.

    Combines the tasks found in the result backend with those reported by the
    workers for each status listed in ``_CELERY_INSPECT_TASK_STATUSES``.
    """
    task_uuids = self._get_completed_task_uuids(task_context)

    id_prefix = _build_task_id_prefix(task_context) + _CELERY_TASK_ID_KEY_SEPARATOR
    inspect = self._celery_app.control.inspect()
    for task_inspect_status in _CELERY_INSPECT_TASK_STATUSES:
        # An inspect call may return nothing; treat that as no tasks.
        tasks = getattr(inspect, task_inspect_status)() or {}

        # Workers report tasks of every context: keep only the ids carrying
        # this context's prefix, otherwise the remainder is not a valid UUID.
        task_uuids.update(
            TaskUUID(task_info["id"].removeprefix(id_prefix))
            for tasks_per_worker in tasks.values()
            for task_info in tasks_per_worker
            if "id" in task_info and task_info["id"].startswith(id_prefix)
        )

    return task_uuids
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
    """Return the task UUIDs that the task store reports for ``task_context``."""
    task_uuids = await self._task_store.get_task_uuids(task_context)
    return task_uuids
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import StrEnum, auto
from typing import Any, Self, TypeAlias
from typing import Any, Final, Protocol, Self, TypeAlias
from uuid import UUID

from models_library.progress_bar import ProgressReport
Expand All @@ -9,6 +9,20 @@
TaskID: TypeAlias = str
TaskUUID: TypeAlias = UUID

_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"


def build_task_id_prefix(task_context: TaskContext) -> str:
    """Return the values of ``task_context``, ordered by sorted key, joined by the id separator."""
    ordered_values = (f"{task_context[name]}" for name in sorted(task_context))
    return _CELERY_TASK_ID_KEY_SEPARATOR.join(ordered_values)


def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
    """Return the task id: the context prefix, the id separator, then ``task_uuid``."""
    prefix = build_task_id_prefix(task_context)
    return f"{prefix}{_CELERY_TASK_ID_KEY_SEPARATOR}{task_uuid}"


class TaskState(StrEnum):
PENDING = auto()
Expand All @@ -18,9 +32,21 @@ class TaskState(StrEnum):
ABORTED = auto()


class TaskData(BaseModel):
    # Data kept by a task store for a single task.
    # NOTE(review): `status` is a free-form string; a stricter type may be
    # preferable — confirm which values callers are expected to pass.
    status: str
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type?



_TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}


class TaskStore(Protocol):
    """Structural interface of a store that keeps track of tasks."""

    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
        """Return the UUIDs of the tasks stored for ``task_context``."""

    async def task_exists(self, task_id: TaskID) -> bool:
        """Return whether ``task_id`` is present in the store."""

    async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
        """Store ``task_data`` for ``task_id``."""


class TaskStatus(BaseModel):
task_uuid: TaskUUID
task_state: TaskState
Expand Down
Loading