Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from asyncio import AbstractEventLoop

from fastapi import FastAPI
from servicelib.redis._client import RedisClientSDK
from settings_library.redis import RedisDatabase
from simcore_service_storage.modules.celery.backends._redis import RedisTaskStore

from ..._meta import APP_NAME
from ...core.settings import get_application_settings
from ._celery_types import register_celery_types
from ._common import create_app
Expand All @@ -13,10 +17,20 @@

def setup_celery_client(app: FastAPI) -> None:
async def on_startup() -> None:
celery_settings = get_application_settings(app).STORAGE_CELERY
application_settings = get_application_settings(app)
celery_settings = application_settings.STORAGE_CELERY
assert celery_settings # nosec
celery_app = create_app(celery_settings)
app.state.celery_client = CeleryTaskQueueClient(celery_app)
redis_client_sdk = RedisClientSDK(
celery_settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
RedisDatabase.CELERY_TASKS
),
client_name=f"{APP_NAME}.celery_tasks",
)

app.state.celery_client = CeleryTaskQueueClient(
celery_app, RedisTaskStore(redis_client_sdk)
)

register_celery_types()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Final

from servicelib.redis._client import RedisClientSDK

from ..models import TaskContext, TaskData, TaskID, TaskUUID, build_task_id_prefix

_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
# Hint for redis SCAN: how many keys the server inspects per iteration batch.
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"


class RedisTaskStore:
    """Redis-backed store for Celery task metadata.

    Task data is stored under keys of the form "<context-prefix>:<task-uuid>"
    (see ``build_task_id_prefix``). Satisfies the ``TaskStore`` protocol.
    """

    def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
        self._redis_client_sdk = redis_client_sdk

    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
        """Return the UUIDs of all tasks stored under *task_context*.

        The context prefix is stripped from each matching key to recover
        the task UUID.
        """
        search_key = build_task_id_prefix(task_context) + _CELERY_TASK_ID_KEY_SEPARATOR
        task_uuids: set[TaskUUID] = set()
        # SCAN iterates incrementally and does not block the server like KEYS would
        async for key in self._redis_client_sdk.redis.scan_iter(
            match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
        ):
            # keys are bytes unless the client was created with decode_responses=True
            raw_key = (
                key.decode(_CELERY_TASK_ID_KEY_ENCODING)
                if isinstance(key, bytes)
                else key
            )
            task_uuids.add(TaskUUID(raw_key.removeprefix(search_key)))
        return task_uuids

    async def task_exists(self, task_id: TaskID) -> bool:
        """Return True if a task with *task_id* is present in the store."""
        n = await self._redis_client_sdk.redis.exists(task_id)
        assert isinstance(n, int)  # nosec
        return n > 0

    async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
        """Create or overwrite the stored data for *task_id*."""
        await self._redis_client_sdk.redis.set(
            task_id,
            task_data.model_dump_json(),
        )
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
from pydantic import ValidationError
from servicelib.logging_utils import log_context

from ...exceptions.errors import ConfigurationError
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
from .models import (
TaskContext,
TaskData,
TaskState,
TaskStatus,
TaskStore,
TaskUUID,
build_task_id,
)

_logger = logging.getLogger(__name__)

_CELERY_INSPECT_TASK_STATUSES: Final[tuple[str, ...]] = (
"active",
"scheduled",
"revoked",
)
_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
_CELERY_STATES_MAPPING: Final[dict[str, TaskState]] = {
"PENDING": TaskState.PENDING,
"STARTED": TaskState.PENDING,
Expand All @@ -34,33 +35,17 @@
"FAILURE": TaskState.ERROR,
"ERROR": TaskState.ERROR,
}
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"

_MIN_PROGRESS_VALUE = 0.0
_MAX_PROGRESS_VALUE = 1.0


def _build_context_prefix(task_context: TaskContext) -> list[str]:
return [f"{task_context[key]}" for key in sorted(task_context)]


def _build_task_id_prefix(task_context: TaskContext) -> str:
return _CELERY_TASK_ID_KEY_SEPARATOR.join(_build_context_prefix(task_context))


def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
return _CELERY_TASK_ID_KEY_SEPARATOR.join(
[_build_task_id_prefix(task_context), f"{task_uuid}"]
)


@dataclass
class CeleryTaskQueueClient:
_celery_app: Celery
_task_store: TaskStore

@make_async()
def send_task(
async def send_task(
self, task_name: str, *, task_context: TaskContext, **task_params
) -> TaskUUID:
with log_context(
Expand All @@ -69,35 +54,29 @@ def send_task(
msg=f"Submit {task_name=}: {task_context=} {task_params=}",
):
task_uuid = uuid4()
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
await self._task_store.set_task(
task_id, TaskData(status=TaskState.PENDING.name)
)
return task_uuid

@staticmethod
@make_async()
def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
with log_context(
_logger,
logging.DEBUG,
msg=f"Abort task {task_uuid=}: {task_context=}",
):
task_id = _build_task_id(task_context, task_uuid)
AbortableAsyncResult(task_id).abort()
task_id = build_task_id(task_context, task_uuid)
_logger.info("Aborting task %s", task_id)
AbortableAsyncResult(task_id).abort()

@make_async()
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
with log_context(
_logger,
logging.DEBUG,
msg=f"Get task {task_uuid=}: {task_context=} result",
):
task_id = _build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result
task_id = build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result

def _get_progress_report(
self, task_context: TaskContext, task_uuid: TaskUUID
) -> ProgressReport:
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
result = self._celery_app.AsyncResult(task_id).result
state = self._get_state(task_context, task_uuid)
if result and state == TaskState.RUNNING:
Expand All @@ -117,7 +96,7 @@ def _get_progress_report(
)

def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
task_id = _build_task_id(task_context, task_uuid)
task_id = build_task_id(task_context, task_uuid)
return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]

@make_async()
Expand All @@ -130,51 +109,5 @@ def get_task_status(
progress_report=self._get_progress_report(task_context, task_uuid),
)

def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
backend_client = self._celery_app.backend.client
if hasattr(backend_client, "keys"):
if keys := backend_client.keys(f"{search_key}*"):
return {
TaskUUID(
f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
)
for key in keys
}
return set()
if hasattr(backend_client, "cache"):
# NOTE: backend used in testing. It is a dict-like object
found_keys = set()
for key in backend_client.cache:
str_key = key.decode(_CELERY_TASK_ID_KEY_ENCODING)
if str_key.startswith(search_key):
found_keys.add(
TaskUUID(
f"{str_key.removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
)
)
return found_keys
msg = f"Unsupported backend {self._celery_app.backend.__class__.__name__}"
raise ConfigurationError(msg=msg)

@make_async()
def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
task_uuids = self._get_completed_task_uuids(task_context)

task_id_prefix = _build_task_id_prefix(task_context)
inspect = self._celery_app.control.inspect()
for task_inspect_status in _CELERY_INSPECT_TASK_STATUSES:
tasks = getattr(inspect, task_inspect_status)() or {}

task_uuids.update(
TaskUUID(
task_info["id"].removeprefix(
task_id_prefix + _CELERY_TASK_ID_KEY_SEPARATOR
)
)
for tasks_per_worker in tasks.values()
for task_info in tasks_per_worker
if "id" in task_info
)

return task_uuids
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
return await self._task_store.get_task_uuids(task_context)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import StrEnum, auto
from typing import Any, Self, TypeAlias
from typing import Any, Final, Protocol, Self, TypeAlias
from uuid import UUID

from models_library.progress_bar import ProgressReport
Expand All @@ -9,6 +9,20 @@
TaskID: TypeAlias = str
TaskUUID: TypeAlias = UUID

_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"


def build_task_id_prefix(task_context: TaskContext) -> str:
    """Build the deterministic key prefix for *task_context*.

    Context values are joined in sorted-key order, so the same context
    always yields the same prefix.
    """
    ordered_values = (f"{task_context[key]}" for key in sorted(task_context))
    return _CELERY_TASK_ID_KEY_SEPARATOR.join(ordered_values)


def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
    """Compose the full task id as "<context-prefix>:<task-uuid>"."""
    prefix = build_task_id_prefix(task_context)
    return f"{prefix}{_CELERY_TASK_ID_KEY_SEPARATOR}{task_uuid}"


class TaskState(StrEnum):
PENDING = auto()
Expand All @@ -18,9 +32,21 @@ class TaskState(StrEnum):
ABORTED = auto()


class TaskData(BaseModel):
    """Serializable task metadata persisted in the task store."""

    # Celery task state name (e.g. "PENDING"); set from TaskState.*.name on submission
    status: str


_TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}


class TaskStore(Protocol):
    """Structural (duck-typed) interface for persisting Celery task metadata.

    Implementations store ``TaskData`` keyed by task id; a Redis-backed
    implementation is expected, but any object with these coroutines works.
    """

    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
        """Return the UUIDs of all stored tasks belonging to *task_context*."""
        ...

    async def task_exists(self, task_id: TaskID) -> bool:
        """Return True if *task_id* is present in the store."""
        ...

    async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
        """Create or overwrite the stored *task_data* for *task_id*."""
        ...


class TaskStatus(BaseModel):
task_uuid: TaskUUID
task_state: TaskState
Expand Down
Loading