diff --git a/services/storage/src/simcore_service_storage/api/rpc/_data_export.py b/services/storage/src/simcore_service_storage/api/rpc/_data_export.py
index 424fbc2f0d0d..aab9d7339f62 100644
--- a/services/storage/src/simcore_service_storage/api/rpc/_data_export.py
+++ b/services/storage/src/simcore_service_storage/api/rpc/_data_export.py
@@ -52,7 +52,7 @@ async def start_data_export(
         ) from err
 
     task_uuid = await get_celery_client(app).send_task(
-        "export_data_with_error",
+        "export_data",
         task_context=job_id_data.model_dump(),
         files=data_export_start.file_and_folder_ids,  # ANE: adapt here your signature
     )
diff --git a/services/storage/src/simcore_service_storage/modules/celery/_common.py b/services/storage/src/simcore_service_storage/modules/celery/_common.py
index f5a21145d301..52bb638772af 100644
--- a/services/storage/src/simcore_service_storage/modules/celery/_common.py
+++ b/services/storage/src/simcore_service_storage/modules/celery/_common.py
@@ -27,7 +27,10 @@ def create_app(celery_settings: CelerySettings) -> Celery:
     app.conf.result_expires = celery_settings.CELERY_RESULT_EXPIRES
     app.conf.result_extended = True  # original args are included in the results
     app.conf.result_serializer = "json"
+    app.conf.task_send_sent_event = True
     app.conf.task_track_started = True
+    app.conf.worker_send_task_events = True  # enable tasks monitoring
+
     return app
 
 
diff --git a/services/storage/src/simcore_service_storage/modules/celery/client.py b/services/storage/src/simcore_service_storage/modules/celery/client.py
index 42f93889ec99..b6c2a0906917 100644
--- a/services/storage/src/simcore_service_storage/modules/celery/client.py
+++ b/services/storage/src/simcore_service_storage/modules/celery/client.py
@@ -18,7 +18,6 @@
 _CELERY_INSPECT_TASK_STATUSES: Final[tuple[str, ...]] = (
     "active",
-    "registered",
     "scheduled",
     "revoked",
 )
 
@@ -131,21 +130,22 @@ def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
     @make_async()
     def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
-        all_task_ids = self._get_completed_task_uuids(task_context)
+        task_uuids = self._get_completed_task_uuids(task_context)
 
-        search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
+        task_id_prefix = _build_task_id_prefix(task_context)
+        inspect = self._celery_app.control.inspect()
         for task_inspect_status in _CELERY_INSPECT_TASK_STATUSES:
-            if task_ids := getattr(
-                self._celery_app.control.inspect(), task_inspect_status
-            )():
-                for values in task_ids.values():
-                    for value in values:
-                        all_task_ids.add(
-                            TaskUUID(
-                                value.removeprefix(
-                                    search_key + _CELERY_TASK_ID_KEY_SEPARATOR
-                                )
-                            )
-                        )
-
-        return all_task_ids
+            tasks = getattr(inspect, task_inspect_status)() or {}
+
+            task_uuids.update(
+                TaskUUID(
+                    task_info["id"].removeprefix(
+                        task_id_prefix + _CELERY_TASK_ID_KEY_SEPARATOR
+                    )
+                )
+                for tasks_per_worker in tasks.values()
+                for task_info in tasks_per_worker
+                if "id" in task_info
+            )
+
+        return task_uuids
diff --git a/services/storage/tests/unit/modules/celery/conftest.py b/services/storage/tests/unit/modules/celery/conftest.py
index 3cd06195b286..8bbb621ef0bd 100644
--- a/services/storage/tests/unit/modules/celery/conftest.py
+++ b/services/storage/tests/unit/modules/celery/conftest.py
@@ -43,6 +43,9 @@ def celery_conf() -> dict[str, Any]:
         "result_expires": timedelta(days=7),
         "result_extended": True,
         "pool": "threads",
+        "worker_send_task_events": True,
+        "task_track_started": True,
+        "task_send_sent_event": True,
     }
 
 
@@ -77,7 +80,13 @@ def celery_worker_controller(
 
     register_celery_tasks(celery_app)
 
-    with start_worker(celery_app, loglevel="info", perform_ping_check=False) as worker:
+    with start_worker(
+        celery_app,
+        pool="threads",
+        loglevel="info",
+        perform_ping_check=False,
+        worker_kwargs={"hostname": "celery@worker1"},
+    ) as worker:
         worker_init.send(sender=worker)
 
         yield worker
diff --git a/services/storage/tests/unit/modules/celery/test_celery.py b/services/storage/tests/unit/modules/celery/test_celery.py
index 99c3cc34263a..097e5b269ab9 100644
--- a/services/storage/tests/unit/modules/celery/test_celery.py
+++ b/services/storage/tests/unit/modules/celery/test_celery.py
@@ -4,17 +4,21 @@
 from collections.abc import Callable
 from random import randint
 
-from pydantic import TypeAdapter, ValidationError
 import pytest
 from celery import Celery, Task
 from celery.contrib.abortable import AbortableTask
 from common_library.errors_classes import OsparcErrorMixin
 from models_library.progress_bar import ProgressReport
+from pydantic import TypeAdapter, ValidationError
 from servicelib.logging_utils import log_context
 from simcore_service_storage.modules.celery import get_event_loop
 from simcore_service_storage.modules.celery._common import define_task
 from simcore_service_storage.modules.celery.client import CeleryTaskQueueClient
-from simcore_service_storage.modules.celery.models import TaskContext, TaskError, TaskState
+from simcore_service_storage.modules.celery.models import (
+    TaskContext,
+    TaskError,
+    TaskState,
+)
 from simcore_service_storage.modules.celery.utils import (
     get_celery_worker,
     get_fastapi_app,
@@ -54,7 +58,7 @@ def sync_archive(task: Task, files: list[str]) -> str:
 
 
 class MyError(OsparcErrorMixin, Exception):
-    msg_template = "Something strange happened: {msg}"
+    msg_template = "Something strange happened: {msg}"
 
 
 def failure_task(task: Task):
@@ -163,3 +167,25 @@ async def test_aborting_task_results_with_aborted_state(
     assert (
         await celery_client.get_task_status(task_context, task_uuid)
     ).task_state == TaskState.ABORTED
+
+
+@pytest.mark.usefixtures("celery_worker")
+async def test_listing_task_uuids_contains_submitted_task(
+    celery_client: CeleryTaskQueueClient,
+):
+    task_context = TaskContext(user_id=42)
+
+    task_uuid = await celery_client.send_task(
+        "dreamer_task",
+        task_context=task_context,
+    )
+
+    for attempt in Retrying(
+        retry=retry_if_exception_type(AssertionError),
+        wait=wait_fixed(1),
+        stop=stop_after_delay(10),
+    ):
+        with attempt:
+            assert task_uuid in await celery_client.get_task_uuids(task_context)
+
+    assert task_uuid in await celery_client.get_task_uuids(task_context)
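
A side note on the reply shape that the reworked `get_task_uuids` iterates over: Celery's remote-control inspection is a broadcast, so each status method returns `None` when no worker replies (which is what the new `or {}` fallback covers), and otherwise a mapping of worker hostname to a list of per-task entries. Below is a minimal sketch for the `active` status only; the broker URL is an assumption for illustration, not taken from this codebase.

```python
from celery import Celery

app = Celery(broker="amqp://localhost:5672")  # assumed broker URL, illustration only

# inspect() broadcasts to all running workers; a status method returns None
# when no worker replies in time, hence the `or {}` guard in the client code.
inspect = app.control.inspect()
replies = inspect.active() or {}

# Replies are keyed by worker hostname; each "active" entry is a dict of task
# metadata that includes the task "id". Other statuses may report entries in a
# different shape, which is why the client guards with `if "id" in task_info`.
for worker_name, task_infos in replies.items():
    for task_info in task_infos:
        print(worker_name, task_info["id"])
```

Note that the `task_send_sent_event` and `worker_send_task_events` settings enabled in this patch are a separate mechanism from inspection: they make clients and workers publish task lifecycle events for monitors such as Flower, matching the "enable tasks monitoring" comment in `_common.py`.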