diff --git a/services/storage/src/simcore_service_storage/modules/celery/__init__.py b/services/storage/src/simcore_service_storage/modules/celery/__init__.py
index fc6ed86c7b54..41c15756be02 100644
--- a/services/storage/src/simcore_service_storage/modules/celery/__init__.py
+++ b/services/storage/src/simcore_service_storage/modules/celery/__init__.py
@@ -2,7 +2,11 @@
 from asyncio import AbstractEventLoop
 
 from fastapi import FastAPI
+from servicelib.redis._client import RedisClientSDK
+from settings_library.redis import RedisDatabase
+from simcore_service_storage.modules.celery.backends._redis import RedisTaskStore
 
+from ..._meta import APP_NAME
 from ...core.settings import get_application_settings
 from ._celery_types import register_celery_types
 from ._common import create_app
@@ -13,10 +17,20 @@ def setup_celery_client(app: FastAPI) -> None:
 
 
     async def on_startup() -> None:
-        celery_settings = get_application_settings(app).STORAGE_CELERY
+        application_settings = get_application_settings(app)
+        celery_settings = application_settings.STORAGE_CELERY
         assert celery_settings  # nosec
         celery_app = create_app(celery_settings)
-        app.state.celery_client = CeleryTaskQueueClient(celery_app)
+        redis_client_sdk = RedisClientSDK(
+            celery_settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
+                RedisDatabase.CELERY_TASKS
+            ),
+            client_name=f"{APP_NAME}.celery_tasks",
+        )
+
+        app.state.celery_client = CeleryTaskQueueClient(
+            celery_app, RedisTaskStore(redis_client_sdk)
+        )
 
         register_celery_types()
 
diff --git a/services/storage/src/simcore_service_storage/modules/celery/backends/__init__.py b/services/storage/src/simcore_service_storage/modules/celery/backends/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py b/services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py
new file mode 100644
index 000000000000..7da2c58714aa
--- /dev/null
+++ b/services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py
@@ -0,0 +1,38 @@
+from typing import Final
+
+from servicelib.redis._client import RedisClientSDK
+
+from ..models import TaskContext, TaskData, TaskID, TaskUUID, build_task_id_prefix
+
+_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
+_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
+_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
+
+
+class RedisTaskStore:
+    def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
+        self._redis_client_sdk = redis_client_sdk
+
+    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
+        search_key = (
+            _CELERY_TASK_META_PREFIX
+            + build_task_id_prefix(task_context)
+            + _CELERY_TASK_ID_KEY_SEPARATOR
+        )
+        keys = set()
+        async for key in self._redis_client_sdk.redis.scan_iter(
+            match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
+        ):
+            keys.add(TaskUUID(f"{key}".removeprefix(search_key)))
+        return keys
+
+    async def task_exists(self, task_id: TaskID) -> bool:
+        n = await self._redis_client_sdk.redis.exists(task_id)
+        assert isinstance(n, int)  # nosec
+        return n > 0
+
+    async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
+        await self._redis_client_sdk.redis.set(
+            task_id,
+            task_data.model_dump_json(),
+        )
diff --git a/services/storage/src/simcore_service_storage/modules/celery/client.py b/services/storage/src/simcore_service_storage/modules/celery/client.py
index 1b491ff197e8..14276720d313 100644
--- a/services/storage/src/simcore_service_storage/modules/celery/client.py
+++ b/services/storage/src/simcore_service_storage/modules/celery/client.py
@@ -13,17 +13,18 @@
 from pydantic import ValidationError
 from servicelib.logging_utils import log_context
 
-from ...exceptions.errors import ConfigurationError
-from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
+from .models import (
+    TaskContext,
+    TaskData,
+    TaskState,
+    TaskStatus,
+    TaskStore,
+    TaskUUID,
+    build_task_id,
+)
 
 _logger = logging.getLogger(__name__)
 
-_CELERY_INSPECT_TASK_STATUSES: Final[tuple[str, ...]] = (
-    "active",
-    "scheduled",
-    "revoked",
-)
-_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
 _CELERY_STATES_MAPPING: Final[dict[str, TaskState]] = {
     "PENDING": TaskState.PENDING,
     "STARTED": TaskState.PENDING,
@@ -34,33 +35,17 @@
     "FAILURE": TaskState.ERROR,
     "ERROR": TaskState.ERROR,
 }
-_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
-_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
 
 _MIN_PROGRESS_VALUE = 0.0
 _MAX_PROGRESS_VALUE = 1.0
 
 
-def _build_context_prefix(task_context: TaskContext) -> list[str]:
-    return [f"{task_context[key]}" for key in sorted(task_context)]
-
-
-def _build_task_id_prefix(task_context: TaskContext) -> str:
-    return _CELERY_TASK_ID_KEY_SEPARATOR.join(_build_context_prefix(task_context))
-
-
-def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
-    return _CELERY_TASK_ID_KEY_SEPARATOR.join(
-        [_build_task_id_prefix(task_context), f"{task_uuid}"]
-    )
-
-
 @dataclass
 class CeleryTaskQueueClient:
     _celery_app: Celery
+    _task_store: TaskStore
 
-    @make_async()
-    def send_task(
+    async def send_task(
         self, task_name: str, *, task_context: TaskContext, **task_params
     ) -> TaskUUID:
         with log_context(
@@ -69,35 +54,29 @@ def send_task(
             msg=f"Submit {task_name=}: {task_context=} {task_params=}",
         ):
             task_uuid = uuid4()
-            task_id = _build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_context, task_uuid)
             self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
+            await self._task_store.set_task(
+                task_id, TaskData(status=TaskState.PENDING.name)
+            )
             return task_uuid
 
     @staticmethod
     @make_async()
     def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
-        with log_context(
-            _logger,
-            logging.DEBUG,
-            msg=f"Abort task {task_uuid=}: {task_context=}",
-        ):
-            task_id = _build_task_id(task_context, task_uuid)
-            AbortableAsyncResult(task_id).abort()
+        task_id = build_task_id(task_context, task_uuid)
+        _logger.info("Aborting task %s", task_id)
+        AbortableAsyncResult(task_id).abort()
 
     @make_async()
     def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
-        with log_context(
-            _logger,
-            logging.DEBUG,
-            msg=f"Get task {task_uuid=}: {task_context=} result",
-        ):
-            task_id = _build_task_id(task_context, task_uuid)
-            return self._celery_app.AsyncResult(task_id).result
+        task_id = build_task_id(task_context, task_uuid)
+        return self._celery_app.AsyncResult(task_id).result
 
     def _get_progress_report(
         self, task_context: TaskContext, task_uuid: TaskUUID
    ) -> ProgressReport:
-        task_id = _build_task_id(task_context, task_uuid)
+        task_id = build_task_id(task_context, task_uuid)
         result = self._celery_app.AsyncResult(task_id).result
         state = self._get_state(task_context, task_uuid)
         if result and state == TaskState.RUNNING:
@@ -117,7 +96,7 @@ def _get_progress_report(
         )
 
     def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
-        task_id = _build_task_id(task_context, task_uuid)
+        task_id = build_task_id(task_context, task_uuid)
         return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]
 
     @make_async()
@@ -130,51 +109,5 @@ def get_task_status(
             progress_report=self._get_progress_report(task_context, task_uuid),
         )
 
-    def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
-        search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
-        backend_client = self._celery_app.backend.client
-        if hasattr(backend_client, "keys"):
-            if keys := backend_client.keys(f"{search_key}*"):
-                return {
-                    TaskUUID(
-                        f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
-                    )
-                    for key in keys
-                }
-            return set()
-        if hasattr(backend_client, "cache"):
-            # NOTE: backend used in testing. It is a dict-like object
-            found_keys = set()
-            for key in backend_client.cache:
-                str_key = key.decode(_CELERY_TASK_ID_KEY_ENCODING)
-                if str_key.startswith(search_key):
-                    found_keys.add(
-                        TaskUUID(
-                            f"{str_key.removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
-                        )
-                    )
-            return found_keys
-        msg = f"Unsupported backend {self._celery_app.backend.__class__.__name__}"
-        raise ConfigurationError(msg=msg)
-
-    @make_async()
-    def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
-        task_uuids = self._get_completed_task_uuids(task_context)
-
-        task_id_prefix = _build_task_id_prefix(task_context)
-        inspect = self._celery_app.control.inspect()
-        for task_inspect_status in _CELERY_INSPECT_TASK_STATUSES:
-            tasks = getattr(inspect, task_inspect_status)() or {}
-
-            task_uuids.update(
-                TaskUUID(
-                    task_info["id"].removeprefix(
-                        task_id_prefix + _CELERY_TASK_ID_KEY_SEPARATOR
-                    )
-                )
-                for tasks_per_worker in tasks.values()
-                for task_info in tasks_per_worker
-                if "id" in task_info
-            )
-
-        return task_uuids
+    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
+        return await self._task_store.get_task_uuids(task_context)
diff --git a/services/storage/src/simcore_service_storage/modules/celery/models.py b/services/storage/src/simcore_service_storage/modules/celery/models.py
index 94f961a0e29a..6f2193b2da6e 100644
--- a/services/storage/src/simcore_service_storage/modules/celery/models.py
+++ b/services/storage/src/simcore_service_storage/modules/celery/models.py
@@ -1,5 +1,5 @@
 from enum import StrEnum, auto
-from typing import Any, Self, TypeAlias
+from typing import Any, Final, Protocol, Self, TypeAlias
 from uuid import UUID
 
 from models_library.progress_bar import ProgressReport
@@ -9,6 +9,20 @@
 TaskID: TypeAlias = str
 TaskUUID: TypeAlias = UUID
 
+_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
+
+
+def build_task_id_prefix(task_context: TaskContext) -> str:
+    return _CELERY_TASK_ID_KEY_SEPARATOR.join(
+        [f"{task_context[key]}" for key in sorted(task_context)]
+    )
+
+
+def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
+    return _CELERY_TASK_ID_KEY_SEPARATOR.join(
+        [build_task_id_prefix(task_context), f"{task_uuid}"]
+    )
+
 
 class TaskState(StrEnum):
     PENDING = auto()
@@ -18,9 +32,21 @@ class TaskState(StrEnum):
     ABORTED = auto()
 
 
+class TaskData(BaseModel):
+    status: str
+
+
 _TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}
 
 
+class TaskStore(Protocol):
+    async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...
+
+    async def task_exists(self, task_id: TaskID) -> bool: ...
+
+    async def set_task(self, task_id: TaskID, task_data: TaskData) -> None: ...
+
+
 class TaskStatus(BaseModel):
     task_uuid: TaskUUID
     task_state: TaskState