Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def start_data_export(
) from err

task_uuid = await get_celery_client(app).send_task(
"export_data_with_error",
"export_data",
task_context=job_id_data.model_dump(),
files=data_export_start.file_and_folder_ids, # ANE: adapt here your signature
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ def create_app(celery_settings: CelerySettings) -> Celery:
app.conf.result_expires = celery_settings.CELERY_RESULT_EXPIRES
app.conf.result_extended = True # original args are included in the results
app.conf.result_serializer = "json"
app.conf.task_send_sent_event = True
app.conf.task_track_started = True
app.conf.worker_send_task_events = True # enable tasks monitoring

return app


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

_CELERY_INSPECT_TASK_STATUSES: Final[tuple[str, ...]] = (
"active",
"registered",
"scheduled",
"revoked",
)
Expand Down Expand Up @@ -131,21 +130,22 @@ def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:

@make_async()
def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
all_task_ids = self._get_completed_task_uuids(task_context)
task_uuids = self._get_completed_task_uuids(task_context)

search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
task_id_prefix = _build_task_id_prefix(task_context)
inspect = self._celery_app.control.inspect()
for task_inspect_status in _CELERY_INSPECT_TASK_STATUSES:
if task_ids := getattr(
self._celery_app.control.inspect(), task_inspect_status
)():
for values in task_ids.values():
for value in values:
all_task_ids.add(
TaskUUID(
value.removeprefix(
search_key + _CELERY_TASK_ID_KEY_SEPARATOR
)
)
)

return all_task_ids
tasks = getattr(inspect, task_inspect_status)() or {}

task_uuids.update(
TaskUUID(
task_info["id"].removeprefix(
task_id_prefix + _CELERY_TASK_ID_KEY_SEPARATOR
)
)
for tasks_per_worker in tasks.values()
for task_info in tasks_per_worker
if "id" in task_info
)

return task_uuids
11 changes: 10 additions & 1 deletion services/storage/tests/unit/modules/celery/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def celery_conf() -> dict[str, Any]:
"result_expires": timedelta(days=7),
"result_extended": True,
"pool": "threads",
"worker_send_task_events": True,
"task_track_started": True,
"task_send_sent_event": True,
}


Expand Down Expand Up @@ -77,7 +80,13 @@ def celery_worker_controller(

register_celery_tasks(celery_app)

with start_worker(celery_app, loglevel="info", perform_ping_check=False) as worker:
with start_worker(
celery_app,
pool="threads",
loglevel="info",
perform_ping_check=False,
worker_kwargs={"hostname": "celery@worker1"},
) as worker:
worker_init.send(sender=worker)

yield worker
Expand Down
32 changes: 29 additions & 3 deletions services/storage/tests/unit/modules/celery/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
from collections.abc import Callable
from random import randint

from pydantic import TypeAdapter, ValidationError
import pytest
from celery import Celery, Task
from celery.contrib.abortable import AbortableTask
from common_library.errors_classes import OsparcErrorMixin
from models_library.progress_bar import ProgressReport
from pydantic import TypeAdapter, ValidationError
from servicelib.logging_utils import log_context
from simcore_service_storage.modules.celery import get_event_loop
from simcore_service_storage.modules.celery._common import define_task
from simcore_service_storage.modules.celery.client import CeleryTaskQueueClient
from simcore_service_storage.modules.celery.models import TaskContext, TaskError, TaskState
from simcore_service_storage.modules.celery.models import (
TaskContext,
TaskError,
TaskState,
)
from simcore_service_storage.modules.celery.utils import (
get_celery_worker,
get_fastapi_app,
Expand Down Expand Up @@ -54,7 +58,7 @@ def sync_archive(task: Task, files: list[str]) -> str:


class MyError(OsparcErrorMixin, Exception):
msg_template = "Something strange happened: {msg}"
msg_template = "Something strange happened: {msg}"


def failure_task(task: Task):
Expand Down Expand Up @@ -163,3 +167,25 @@ async def test_aborting_task_results_with_aborted_state(
assert (
await celery_client.get_task_status(task_context, task_uuid)
).task_state == TaskState.ABORTED


@pytest.mark.usefixtures("celery_worker")
async def test_listing_task_uuids_contains_submitted_task(
celery_client: CeleryTaskQueueClient,
):
task_context = TaskContext(user_id=42)

task_uuid = await celery_client.send_task(
"dreamer_task",
task_context=task_context,
)

for attempt in Retrying(
retry=retry_if_exception_type(AssertionError),
wait=wait_fixed(1),
stop=stop_after_delay(10),
):
with attempt:
assert task_uuid in await celery_client.get_task_uuids(task_context)

assert task_uuid in await celery_client.get_task_uuids(task_context)
Loading