Merged

45 commits
e85ccfd
add celery routing queues
giancarloromeo Apr 3, 2025
f31a007
fix task queue name
giancarloromeo Apr 3, 2025
15b071a
fix queue
giancarloromeo Apr 3, 2025
3b69bc3
pylint
giancarloromeo Apr 3, 2025
4ab39c2
add celery routing queues
giancarloromeo Apr 3, 2025
0a42019
fix task queue name
giancarloromeo Apr 3, 2025
9aa33ee
fix queue
giancarloromeo Apr 3, 2025
7fa116a
pylint
giancarloromeo Apr 3, 2025
40d988d
localy working
sanderegg Apr 4, 2025
2f2b23c
Merge remote-tracking branch 'upstream/master' into add-celery-routin…
giancarloromeo Apr 7, 2025
19d5119
Merge branch 'add-celery-routing-queues' of github.com:giancarloromeo…
giancarloromeo Apr 7, 2025
8d96c9a
fix key prefix
giancarloromeo Apr 7, 2025
dd4c535
improve metadata store
giancarloromeo Apr 8, 2025
65557cb
fix get result
giancarloromeo Apr 8, 2025
054a553
remove task_queue param
giancarloromeo Apr 8, 2025
8d855c3
add expiration
giancarloromeo Apr 8, 2025
ec2325b
Merge branch 'master' into add-celery-routing-queues
giancarloromeo Apr 8, 2025
ad65bf3
rename
giancarloromeo Apr 8, 2025
3881595
Merge branch 'add-celery-routing-queues' of github.com:giancarloromeo…
giancarloromeo Apr 8, 2025
21883cd
add fn
giancarloromeo Apr 8, 2025
643c8b5
use constant
giancarloromeo Apr 8, 2025
4342c71
naming
giancarloromeo Apr 8, 2025
5befb9d
fix name
giancarloromeo Apr 8, 2025
0f2f451
name
giancarloromeo Apr 8, 2025
a15590a
Merge remote-tracking branch 'upstream/master' into add-celery-routin…
giancarloromeo Apr 8, 2025
34e9d36
fix submit
giancarloromeo Apr 8, 2025
56783f0
typecheck
giancarloromeo Apr 8, 2025
c5e189a
restore missing decorator
giancarloromeo Apr 8, 2025
46501b5
typecheck
giancarloromeo Apr 9, 2025
3b3da36
fix key prefix
giancarloromeo Apr 9, 2025
39fc47e
remove prefix when forgetting
giancarloromeo Apr 9, 2025
5730c93
set default queue in tests
giancarloromeo Apr 9, 2025
a2eea77
remove
giancarloromeo Apr 9, 2025
7c599bf
fix test
giancarloromeo Apr 9, 2025
2d76802
fix abort
giancarloromeo Apr 9, 2025
8fe8e17
improve progress
giancarloromeo Apr 9, 2025
bd42b4f
Merge branch 'master' into add-celery-routing-queues
giancarloromeo Apr 9, 2025
2068e50
fix test_start_export_data
giancarloromeo Apr 9, 2025
2abd37a
Merge branch 'add-celery-routing-queues' of github.com:giancarloromeo…
giancarloromeo Apr 9, 2025
6607fe9
skip test
giancarloromeo Apr 9, 2025
4f04d2f
Merge branch 'master' into add-celery-routing-queues
giancarloromeo Apr 9, 2025
9217f61
exclude worker
giancarloromeo Apr 9, 2025
4f5bcdc
Merge branch 'add-celery-routing-queues' of github.com:giancarloromeo…
giancarloromeo Apr 9, 2025
563c616
exclude sto-worker-cpu-bound
giancarloromeo Apr 9, 2025
22c83da
Merge branch 'master' into add-celery-routing-queues
giancarloromeo Apr 9, 2025
@@ -6,9 +6,9 @@
import json
import logging
import warnings
from collections.abc import Iterator
from dataclasses import dataclass
from io import StringIO
from typing import Iterator

import aiohttp
import pytest
@@ -38,6 +38,7 @@
"traefik",
"whoami",
"sto-worker",
"sto-worker-cpu-bound",
}
# TODO: unify healthcheck policies see https://github.com/ITISFoundation/osparc-simcore/pull/2281
SERVICE_PUBLISHED_PORT = {}
6 changes: 6 additions & 0 deletions packages/settings-library/src/settings_library/celery.py
@@ -22,6 +22,12 @@ class CelerySettings(BaseCustomSettings):
description="Time after which task results will be deleted (default to seconds, or see https://pydantic-docs.helpmanual.io/usage/types/#datetime-types for string formating)."
),
] = timedelta(days=7)
CELERY_EPHEMERAL_RESULT_EXPIRES: Annotated[
timedelta,
Field(
description="Time after which ephemeral task results will be deleted (default to seconds, or see https://pydantic-docs.helpmanual.io/usage/types/#datetime-types for string formating)."
),
] = timedelta(hours=1)
CELERY_RESULT_PERSISTENT: Annotated[
bool,
Field(
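A minimal sketch (not part of the diff) of how pydantic parses the new expiry value: plain numbers are read as seconds and strings as ISO 8601 durations, so both forms below yield the same timedelta.

from datetime import timedelta

from pydantic import TypeAdapter

adapter = TypeAdapter(timedelta)
assert adapter.validate_python(3600) == timedelta(hours=1)   # plain seconds
assert adapter.validate_python("PT1H") == timedelta(hours=1)  # ISO 8601 duration
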
10 changes: 10 additions & 0 deletions services/docker-compose.devel.yml
@@ -215,6 +215,16 @@ services:
STORAGE_PROFILING : ${STORAGE_PROFILING}
STORAGE_LOGLEVEL: DEBUG

sto-worker-cpu-bound:
volumes:
- ./storage:/devel/services/storage
- ../packages:/devel/packages
- ${HOST_UV_CACHE_DIR}:/home/scu/.cache/uv
environment:
<<: *common-environment
STORAGE_PROFILING : ${STORAGE_PROFILING}
STORAGE_LOGLEVEL: DEBUG

agent:
environment:
<<: *common-environment
8 changes: 8 additions & 0 deletions services/docker-compose.local.yml
@@ -133,6 +133,14 @@ services:
ports:
- "8080"
- "3021:3000"

sto-worker-cpu-bound:
environment:
<<: *common_environment
STORAGE_REMOTE_DEBUGGING_PORT : 3000
ports:
- "8080"
- "3022:3000"
webserver:
environment: &webserver_environment_local
<<: *common_environment
11 changes: 11 additions & 0 deletions services/docker-compose.yml
@@ -1193,10 +1193,21 @@ services:
image: ${DOCKER_REGISTRY:-itisfoundation}/storage:${DOCKER_IMAGE_TAG:-master-github-latest}
init: true
hostname: "sto-worker-{{.Node.Hostname}}-{{.Task.Slot}}"
environment:
<<: *storage_environment
STORAGE_WORKER_MODE: "true"
CELERY_CONCURRENCY: 100
networks: *storage_networks

sto-worker-cpu-bound:
image: ${DOCKER_REGISTRY:-itisfoundation}/storage:${DOCKER_IMAGE_TAG:-master-github-latest}
init: true
hostname: "sto-worker-cpu-bound-{{.Node.Hostname}}-{{.Task.Slot}}"
environment:
<<: *storage_environment
STORAGE_WORKER_MODE: "true"
CELERY_CONCURRENCY: 1
CELERY_QUEUES: "cpu-bound"
networks: *storage_networks

rabbit:
6 changes: 4 additions & 2 deletions services/storage/docker/boot.sh
@@ -54,13 +54,15 @@ if [ "${STORAGE_WORKER_MODE}" = "true" ]; then
--app=simcore_service_storage.modules.celery.worker_main:app \
worker --pool=threads \
--loglevel="${SERVER_LOG_LEVEL}" \
--concurrency="${CELERY_CONCURRENCY}"
--concurrency="${CELERY_CONCURRENCY}" \
--queues="${CELERY_QUEUES:-default}"
else
exec celery \
--app=simcore_service_storage.modules.celery.worker_main:app \
worker --pool=threads \
--loglevel="${SERVER_LOG_LEVEL}" \
--concurrency="${CELERY_CONCURRENCY}"
--concurrency="${CELERY_CONCURRENCY}" \
--queues="${CELERY_QUEUES:-default}"
fi
else
if [ "${SC_BOOT_MODE}" = "debug" ]; then
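For illustration, a hypothetical programmatic equivalent of the worker invocation above, showing how --queues maps onto Celery's API; the app import path is taken from boot.sh itself, while the env-var defaults are assumptions.

import os

from simcore_service_storage.modules.celery.worker_main import app  # path from boot.sh

app.worker_main(
    argv=[
        "worker",
        "--pool=threads",
        f"--loglevel={os.environ.get('SERVER_LOG_LEVEL', 'INFO')}",
        f"--concurrency={os.environ.get('CELERY_CONCURRENCY', '1')}",
        f"--queues={os.environ.get('CELERY_QUEUES', 'default')}",  # falls back to the default queue
    ]
)
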
@@ -21,7 +21,10 @@
from servicelib.rabbitmq import RPCRouter

from ...modules.celery import get_celery_client
from ...modules.celery.errors import decode_celery_transferrable_error
from ...modules.celery.errors import (
TransferrableCeleryError,
decode_celery_transferrable_error,
)
from ...modules.celery.models import TaskState

_logger = logging.getLogger(__name__)
@@ -102,6 +105,7 @@ async def result(
# try to recover the original error
exception = None
with log_catch(_logger, reraise=False):
assert isinstance(_result, TransferrableCeleryError) # nosec
exception = decode_celery_transferrable_error(_result)
exc_type = type(exception).__name__
exc_msg = f"{exception}"
@@ -8,6 +8,7 @@
from servicelib.rabbitmq import RPCRouter

from ...modules.celery import get_celery_client
from ...modules.celery.models import TaskMetadata, TasksQueue
from .._worker_tasks._simcore_s3 import deep_copy_files_from_project, export_data

router = RPCRouter()
@@ -36,6 +37,10 @@ async def start_export_data(
task_uuid = await get_celery_client(app).send_task(
export_data.__name__,
task_context=job_id_data.model_dump(),
task_metadata=TaskMetadata(
ephemeral=False,
queue=TasksQueue.CPU_BOUND,
),
user_id=job_id_data.user_id,
paths_to_export=paths_to_export,
)
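The call above relies on TaskMetadata and TasksQueue from ...modules.celery.models, which are not shown in this diff. A sketch of what they plausibly look like, with defaults inferred from the client code (the real definitions may differ):

from enum import Enum

from pydantic import BaseModel


class TasksQueue(str, Enum):
    DEFAULT = "default"
    CPU_BOUND = "cpu-bound"


class TaskMetadata(BaseModel):
    ephemeral: bool = True
    queue: TasksQueue = TasksQueue.DEFAULT

With ephemeral=False the export result is kept for the long CELERY_RESULT_EXPIRES window rather than the one-hour ephemeral expiry, and queue=TasksQueue.CPU_BOUND routes it to the dedicated sto-worker-cpu-bound worker.
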
@@ -9,7 +9,7 @@
from ...core.settings import get_application_settings
from ._celery_types import register_celery_types
from ._common import create_app
from .backends._redis import RedisTaskStore
from .backends._redis import RedisTaskMetadataStore
from .client import CeleryTaskQueueClient

_logger = logging.getLogger(__name__)
@@ -29,7 +29,9 @@ async def on_startup() -> None:
)

app.state.celery_client = CeleryTaskQueueClient(
celery_app, RedisTaskStore(redis_client_sdk)
celery_app,
celery_settings,
RedisTaskMetadataStore(redis_client_sdk),
)

register_celery_types()
@@ -15,6 +15,7 @@ def _celery_configure(celery_settings: CelerySettings) -> dict[str, Any]:
"result_expires": celery_settings.CELERY_RESULT_EXPIRES,
"result_extended": True,
"result_serializer": "json",
"task_default_queue": "default",
"task_send_sent_event": True,
"task_track_started": True,
"worker_send_task_events": True,
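A small sketch of what this entry does, assuming the returned dict is applied to the Celery app's configuration: tasks sent without an explicit queue land on "default", which is exactly the queue boot.sh workers subscribe to when CELERY_QUEUES is unset.

from celery import Celery

celery_app = Celery(broker="amqp://guest@localhost//")  # illustrative broker URL
celery_app.conf.update(task_default_queue="default")
assert celery_app.conf.task_default_queue == "default"
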
@@ -1,21 +1,43 @@
import logging
from datetime import timedelta
from typing import Final

from celery.result import AsyncResult # type: ignore[import-untyped]
from servicelib.redis._client import RedisClientSDK

from ..models import TaskContext, TaskData, TaskID, TaskUUID, build_task_id_prefix
from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix

_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
_CELERY_TASK_METADATA_PREFIX: Final[str] = "celery-task-metadata-"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"

_logger = logging.getLogger(__name__)


class RedisTaskStore:
def _build_key(task_id: TaskID) -> str:
return _CELERY_TASK_METADATA_PREFIX + task_id


class RedisTaskMetadataStore:
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
self._redis_client_sdk = redis_client_sdk

async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
search_key = build_task_id_prefix(task_context) + _CELERY_TASK_ID_KEY_SEPARATOR
async def exists(self, task_id: TaskID) -> bool:
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
assert isinstance(n, int) # nosec
return n > 0

async def get(self, task_id: TaskID) -> TaskMetadata | None:
result = await self._redis_client_sdk.redis.get(_build_key(task_id))
return TaskMetadata.model_validate_json(result) if result else None

async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
search_key = (
_CELERY_TASK_METADATA_PREFIX
+ build_task_id_prefix(task_context)
+ _CELERY_TASK_ID_KEY_SEPARATOR
)
keys = set()
async for key in self._redis_client_sdk.redis.scan_iter(
match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
@@ -29,13 +51,15 @@ async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
keys.add(TaskUUID(_key.removeprefix(search_key)))
return keys

async def task_exists(self, task_id: TaskID) -> bool:
n = await self._redis_client_sdk.redis.exists(task_id)
assert isinstance(n, int) # nosec
return n > 0
async def remove(self, task_id: TaskID) -> None:
await self._redis_client_sdk.redis.delete(_build_key(task_id))
AsyncResult(task_id).forget()

async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
async def set(
self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
) -> None:
await self._redis_client_sdk.redis.set(
task_id,
task_data.model_dump_json(),
_build_key(task_id),
task_metadata.model_dump_json(),
ex=expiry,
)
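Hypothetical usage of the new store, assuming a connected RedisClientSDK and the TaskMetadata sketch from earlier; the task-id format is illustrative only.

from datetime import timedelta


async def _example(redis_client_sdk: RedisClientSDK) -> None:
    store = RedisTaskMetadataStore(redis_client_sdk)
    task_id = "user_id=1:some-task-uuid"  # illustrative; real ids come from build_task_id
    await store.set(task_id, TaskMetadata(ephemeral=True), expiry=timedelta(hours=1))
    assert await store.exists(task_id)
    metadata = await store.get(task_id)  # TaskMetadata | None
    await store.remove(task_id)  # deletes the key and forgets the Celery result
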
@@ -12,13 +12,15 @@
from models_library.progress_bar import ProgressReport
from pydantic import ValidationError
from servicelib.logging_utils import log_context
from settings_library.celery import CelerySettings

from .models import (
TaskContext,
TaskData,
TaskID,
TaskMetadata,
TaskMetadataStore,
TaskState,
TaskStatus,
TaskStore,
TaskUUID,
build_task_id,
)
@@ -43,10 +45,16 @@
@dataclass
class CeleryTaskQueueClient:
_celery_app: Celery
_task_store: TaskStore
_celery_settings: CelerySettings
_task_store: TaskMetadataStore

async def send_task(
self, task_name: str, *, task_context: TaskContext, **task_params
self,
task_name: str,
*,
task_context: TaskContext,
task_metadata: TaskMetadata | None = None,
**task_params,
) -> TaskUUID:
with log_context(
_logger,
@@ -55,39 +63,59 @@ async def send_task(
):
task_uuid = uuid4()
task_id = build_task_id(task_context, task_uuid)
self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
await self._task_store.set_task(
task_id, TaskData(status=TaskState.PENDING.name)
task_metadata = task_metadata or TaskMetadata()
self._celery_app.send_task(
task_name,
task_id=task_id,
kwargs=task_params,
queue=task_metadata.queue.value,
)

expiry = (
self._celery_settings.CELERY_EPHEMERAL_RESULT_EXPIRES
if task_metadata.ephemeral
else self._celery_settings.CELERY_RESULT_EXPIRES
)
await self._task_store.set(task_id, task_metadata, expiry=expiry)
return task_uuid

@staticmethod
@make_async()
def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
def _abort_task(self, task_id: TaskID) -> None:
AbortableAsyncResult(task_id, app=self._celery_app).abort()

async def abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
with log_context(
_logger,
logging.DEBUG,
msg=f"Abort task: {task_context=} {task_uuid=}",
):
task_id = build_task_id(task_context, task_uuid)
AbortableAsyncResult(task_id).abort()
return await self._abort_task(task_id)

@make_async()
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
def _get_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
task_id = build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result

async def get_task_result(
self, task_context: TaskContext, task_uuid: TaskUUID
) -> Any:
with log_context(
_logger,
logging.DEBUG,
msg=f"Get task result: {task_context=} {task_uuid=}",
):
task_id = build_task_id(task_context, task_uuid)
return self._celery_app.AsyncResult(task_id).result
async_result = self._celery_app.AsyncResult(task_id)
result = async_result.result
if async_result.ready():
task_metadata = await self._task_store.get(task_id)
if task_metadata is not None and task_metadata.ephemeral:
await self._task_store.remove(task_id)
return result

def _get_progress_report(
self, task_context: TaskContext, task_uuid: TaskUUID
) -> ProgressReport:
task_id = build_task_id(task_context, task_uuid)
result = self._celery_app.AsyncResult(task_id).result
state = self._get_state(task_context, task_uuid)
@staticmethod
async def _get_progress_report(state, result) -> ProgressReport:
if result and state == TaskState.RUNNING:
with contextlib.suppress(ValidationError):
# avoids exception if result is not a ProgressReport (or overwritten by a Celery's state update)
@@ -104,23 +132,25 @@ def _get_progress_report(
actual_value=_MIN_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE
)

@make_async()
def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
task_id = build_task_id(task_context, task_uuid)
return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]

@make_async()
def get_task_status(
async def get_task_status(
self, task_context: TaskContext, task_uuid: TaskUUID
) -> TaskStatus:
with log_context(
_logger,
logging.DEBUG,
msg=f"Getting task status: {task_context=} {task_uuid=}",
):
task_state = await self._get_state(task_context, task_uuid)
result = await self._get_result(task_context, task_uuid)
return TaskStatus(
task_uuid=task_uuid,
task_state=self._get_state(task_context, task_uuid),
progress_report=self._get_progress_report(task_context, task_uuid),
task_state=task_state,
progress_report=await self._get_progress_report(task_state, result),
)

async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
@@ -129,4 +159,4 @@ async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
logging.DEBUG,
msg=f"Getting task uuids: {task_context=}",
):
return await self._task_store.get_task_uuids(task_context)
return await self._task_store.get_uuids(task_context)
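
Finally, a hedged end-to-end sketch of the client flow introduced in this PR, mirroring start_export_data above; the task-context shape and parameters are assumptions.

async def _example(client: CeleryTaskQueueClient) -> None:
    task_context = {"user_id": 1}  # assumed context shape
    task_uuid = await client.send_task(
        "export_data",
        task_context=task_context,
        task_metadata=TaskMetadata(ephemeral=False, queue=TasksQueue.CPU_BOUND),
        user_id=1,
        paths_to_export=[],
    )
    status = await client.get_task_status(task_context, task_uuid)
    # ephemeral task results are removed from the store once read;
    # this one is persistent, so it survives until CELERY_RESULT_EXPIRES
    result = await client.get_task_result(task_context, task_uuid)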