Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
e85ccfd
add celery routing queues
giancarloromeo Apr 3, 2025
f31a007
fix task queue name
giancarloromeo Apr 3, 2025
15b071a
fix queue
giancarloromeo Apr 3, 2025
3b69bc3
pylint
giancarloromeo Apr 3, 2025
4ab39c2
add celery routing queues
giancarloromeo Apr 3, 2025
0a42019
fix task queue name
giancarloromeo Apr 3, 2025
9aa33ee
fix queue
giancarloromeo Apr 3, 2025
7fa116a
pylint
giancarloromeo Apr 3, 2025
40d988d
locally working
sanderegg Apr 4, 2025
2f2b23c
Merge remote-tracking branch 'upstream/master' into add-celery-routin…
giancarloromeo Apr 7, 2025
19d5119
Merge branch 'add-celery-routing-queues' of github.com:giancarloromeo…
giancarloromeo Apr 7, 2025
8d96c9a
fix key prefix
giancarloromeo Apr 7, 2025
dd4c535
improve metadata store
giancarloromeo Apr 8, 2025
65557cb
fix get result
giancarloromeo Apr 8, 2025
054a553
remove task_queue param
giancarloromeo Apr 8, 2025
8d855c3
add expiration
giancarloromeo Apr 8, 2025
ec2325b
Merge branch 'master' into add-celery-routing-queues
giancarloromeo Apr 8, 2025
ad65bf3
rename
giancarloromeo Apr 8, 2025
3881595
Merge branch 'add-celery-routing-queues' of github.com:giancarloromeo…
giancarloromeo Apr 8, 2025
21883cd
add fn
giancarloromeo Apr 8, 2025
643c8b5
use constant
giancarloromeo Apr 8, 2025
4342c71
naming
giancarloromeo Apr 8, 2025
5befb9d
fix name
giancarloromeo Apr 8, 2025
0f2f451
name
giancarloromeo Apr 8, 2025
a15590a
Merge remote-tracking branch 'upstream/master' into add-celery-routin…
giancarloromeo Apr 8, 2025
34e9d36
fix submit
giancarloromeo Apr 8, 2025
56783f0
typecheck
giancarloromeo Apr 8, 2025
c5e189a
restore missing decorator
giancarloromeo Apr 8, 2025
46501b5
typecheck
giancarloromeo Apr 9, 2025
3b3da36
fix key prefix
giancarloromeo Apr 9, 2025
39fc47e
remove prefix when forgetting
giancarloromeo Apr 9, 2025
5730c93
set default queue in tests
giancarloromeo Apr 9, 2025
a2eea77
remove
giancarloromeo Apr 9, 2025
7c599bf
fix test
giancarloromeo Apr 9, 2025
2d76802
fix abort
giancarloromeo Apr 9, 2025
8fe8e17
improve progress
giancarloromeo Apr 9, 2025
bd42b4f
Merge branch 'master' into add-celery-routing-queues
giancarloromeo Apr 9, 2025
2068e50
fix test_start_export_data
giancarloromeo Apr 9, 2025
2abd37a
Merge branch 'add-celery-routing-queues' of github.com:giancarloromeo…
giancarloromeo Apr 9, 2025
6607fe9
skip test
giancarloromeo Apr 9, 2025
4f04d2f
Merge branch 'master' into add-celery-routing-queues
giancarloromeo Apr 9, 2025
9217f61
exclude worker
giancarloromeo Apr 9, 2025
4f5bcdc
Merge branch 'add-celery-routing-queues' of github.com:giancarloromeo…
giancarloromeo Apr 9, 2025
563c616
exclude sto-worker-cpu-bound
giancarloromeo Apr 9, 2025
22c83da
Merge branch 'master' into add-celery-routing-queues
giancarloromeo Apr 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/settings-library/src/settings_library/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class CelerySettings(BaseCustomSettings):
description="Time after which task results will be deleted (default to seconds, or see https://pydantic-docs.helpmanual.io/usage/types/#datetime-types for string formating)."
),
] = timedelta(days=7)
CELERY_EPHEMERAL_RESULT_EXPIRES: Annotated[
timedelta,
Field(
description="Time after which ephemeral task results will be deleted (default to seconds, or see https://pydantic-docs.helpmanual.io/usage/types/#datetime-types for string formating)."
),
] = timedelta(hours=1)
CELERY_RESULT_PERSISTENT: Annotated[
bool,
Field(
Expand Down
10 changes: 10 additions & 0 deletions services/docker-compose.devel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ services:
STORAGE_PROFILING : ${STORAGE_PROFILING}
STORAGE_LOGLEVEL: DEBUG

sto-worker-cpu-bound:
volumes:
- ./storage:/devel/services/storage
- ../packages:/devel/packages
- ${HOST_UV_CACHE_DIR}:/home/scu/.cache/uv
environment:
<<: *common-environment
STORAGE_PROFILING : ${STORAGE_PROFILING}
STORAGE_LOGLEVEL: DEBUG

agent:
environment:
<<: *common-environment
Expand Down
8 changes: 8 additions & 0 deletions services/docker-compose.local.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ services:
ports:
- "8080"
- "3021:3000"

sto-worker-cpu-bound:
environment:
<<: *common_environment
STORAGE_REMOTE_DEBUGGING_PORT : 3000
ports:
- "8080"
- "3022:3000"
webserver:
environment: &webserver_environment_local
<<: *common_environment
Expand Down
11 changes: 11 additions & 0 deletions services/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1193,10 +1193,21 @@ services:
image: ${DOCKER_REGISTRY:-itisfoundation}/storage:${DOCKER_IMAGE_TAG:-master-github-latest}
init: true
hostname: "sto-worker-{{.Node.Hostname}}-{{.Task.Slot}}"
environment:
<<: *storage_environment
STORAGE_WORKER_MODE: "true"
CELERY_CONCURRENCY: 100
networks: *storage_networks

sto-worker-cpu-bound:
image: ${DOCKER_REGISTRY:-itisfoundation}/storage:${DOCKER_IMAGE_TAG:-master-github-latest}
init: true
hostname: "sto-worker-cpu-bound-{{.Node.Hostname}}-{{.Task.Slot}}"
environment:
<<: *storage_environment
STORAGE_WORKER_MODE: "true"
CELERY_CONCURRENCY: 1
CELERY_QUEUES: "cpu-bound"
networks: *storage_networks

rabbit:
Expand Down
6 changes: 4 additions & 2 deletions services/storage/docker/boot.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ if [ "${STORAGE_WORKER_MODE}" = "true" ]; then
--app=simcore_service_storage.modules.celery.worker_main:app \
worker --pool=threads \
--loglevel="${SERVER_LOG_LEVEL}" \
--concurrency="${CELERY_CONCURRENCY}"
--concurrency="${CELERY_CONCURRENCY}" \
--queues="${CELERY_QUEUES:-default}"
else
exec celery \
--app=simcore_service_storage.modules.celery.worker_main:app \
worker --pool=threads \
--loglevel="${SERVER_LOG_LEVEL}" \
--concurrency="${CELERY_CONCURRENCY}"
--concurrency="${CELERY_CONCURRENCY}" \
--queues="${CELERY_QUEUES:-default}"
fi
else
if [ "${SC_BOOT_MODE}" = "debug" ]; then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ...dsm import get_dsm_provider
from ...exceptions.errors import FileAccessRightError
from ...modules.celery import get_celery_client
from ...modules.celery.models import TaskMetadata, TasksQueue
from ...modules.datcore_adapter.datcore_adapter_exceptions import DatcoreAdapterError
from ...simcore_s3_dsm import SimcoreS3DataManager

Expand Down Expand Up @@ -60,6 +61,10 @@ async def start_data_export(
task_uuid = await get_celery_client(app).send_task(
"export_data",
task_context=job_id_data.model_dump(),
task_metadata=TaskMetadata(
ephemeral=False,
queue=TasksQueue.CPU_BOUND,
),
files=data_export_start.file_and_folder_ids, # ANE: adapt here your signature
)
except CeleryError as exc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ...core.settings import get_application_settings
from ._celery_types import register_celery_types
from ._common import create_app
from .backends._redis import RedisTaskStore
from .backends._redis import RedisTaskMetadataStore
from .client import CeleryTaskQueueClient

_logger = logging.getLogger(__name__)
Expand All @@ -29,7 +29,9 @@ async def on_startup() -> None:
)

app.state.celery_client = CeleryTaskQueueClient(
celery_app, RedisTaskStore(redis_client_sdk)
celery_app,
celery_settings,
RedisTaskMetadataStore(redis_client_sdk),
)

register_celery_types()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def _celery_configure(celery_settings: CelerySettings) -> dict[str, Any]:
"result_expires": celery_settings.CELERY_RESULT_EXPIRES,
"result_extended": True,
"result_serializer": "json",
"task_default_queue": "default",
"task_send_sent_event": True,
"task_track_started": True,
"worker_send_task_events": True,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,44 @@
import logging
from datetime import timedelta
from typing import Final

from celery.result import AsyncResult
from servicelib.redis._client import RedisClientSDK

from ..models import TaskContext, TaskData, TaskID, TaskUUID, build_task_id_prefix
from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix

_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
_CELERY_TASK_METADATA_PREFIX: Final[str] = "celery-task-metadata-"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"

_logger = logging.getLogger(__name__)


class RedisTaskStore:
def _build_key(task_id: TaskID) -> str:
    """Build the Redis key under which task metadata is stored.

    Uses the metadata-specific prefix so that stored keys match the pattern
    scanned by ``RedisTaskMetadataStore.get_uuids`` (which searches with
    ``_CELERY_TASK_METADATA_PREFIX``) and do not collide with Celery's own
    result keys (``celery-task-meta-*``).
    """
    return _CELERY_TASK_METADATA_PREFIX + task_id


class RedisTaskMetadataStore:
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
    # Keep the SDK wrapper; every operation goes through its underlying
    # async Redis client (self._redis_client_sdk.redis).
    self._redis_client_sdk = redis_client_sdk

async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
search_key = build_task_id_prefix(task_context) + _CELERY_TASK_ID_KEY_SEPARATOR
async def exists(self, task_id: TaskID) -> bool:
    """Return True if metadata for *task_id* is currently stored in Redis."""
    key_count = await self._redis_client_sdk.redis.exists(_build_key(task_id))
    assert isinstance(key_count, int)  # nosec
    return bool(key_count)

async def get(self, task_id: TaskID) -> TaskMetadata | None:
    """Fetch and deserialize the metadata stored for *task_id*, or None if absent."""
    raw = await self._redis_client_sdk.redis.get(_build_key(task_id))
    if not raw:
        return None
    return TaskMetadata.model_validate_json(raw)

async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
search_key = (
_CELERY_TASK_METADATA_PREFIX
+ build_task_id_prefix(task_context)
+ _CELERY_TASK_ID_KEY_SEPARATOR
)
keys = set()
async for key in self._redis_client_sdk.redis.scan_iter(
match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
Expand All @@ -29,13 +52,15 @@ async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
keys.add(TaskUUID(_key.removeprefix(search_key)))
return keys

async def task_exists(self, task_id: TaskID) -> bool:
n = await self._redis_client_sdk.redis.exists(task_id)
assert isinstance(n, int) # nosec
return n > 0
async def remove(self, task_id: TaskID) -> None:
    """Delete the stored metadata and forget the Celery result for *task_id*.

    Celery's result backend already namespaces its keys with
    ``celery-task-meta-``, so ``AsyncResult`` must be given the raw task id;
    prepending the prefix ourselves would make ``forget()`` target a
    nonexistent ``celery-task-meta-celery-task-meta-…`` key.
    """
    await self._redis_client_sdk.redis.delete(_build_key(task_id))
    # NOTE(review): AsyncResult here binds to celery.current_app — confirm the
    # intended app is current when this store is used.
    AsyncResult(task_id).forget()

async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
async def set(
    self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
) -> None:
    """Persist *task_metadata* under the task's metadata key.

    *expiry* bounds how long the entry is kept (shorter for ephemeral tasks),
    so abandoned entries are garbage-collected by Redis automatically.
    """
    await self._redis_client_sdk.redis.set(
        _build_key(task_id),
        task_metadata.model_dump_json(),
        ex=expiry,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from models_library.progress_bar import ProgressReport
from pydantic import ValidationError
from servicelib.logging_utils import log_context
from settings_library.celery import CelerySettings

from .models import (
TaskContext,
TaskData,
TaskMetadata,
TaskMetadataStore,
TaskState,
TaskStatus,
TaskStore,
TaskUUID,
build_task_id,
)
Expand All @@ -43,10 +44,16 @@
@dataclass
class CeleryTaskQueueClient:
_celery_app: Celery
_task_store: TaskStore
_celery_settings: CelerySettings
_task_store: TaskMetadataStore

async def send_task(
self, task_name: str, *, task_context: TaskContext, **task_params
self,
task_name: str,
*,
task_context: TaskContext,
task_metadata: TaskMetadata | None = None,
**task_params,
) -> TaskUUID:
with log_context(
_logger,
Expand All @@ -55,10 +62,20 @@ async def send_task(
):
task_uuid = uuid4()
task_id = build_task_id(task_context, task_uuid)
self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
await self._task_store.set_task(
task_id, TaskData(status=TaskState.PENDING.name)
task_metadata = task_metadata or TaskMetadata()
self._celery_app.send_task(
task_name,
task_id=task_id,
kwargs=task_params,
queue=task_metadata.queue,
)

expiry = (
self._celery_settings.CELERY_EPHEMERAL_RESULT_EXPIRES
if task_metadata.ephemeral
else self._celery_settings.CELERY_RESULT_EXPIRES
)
await self._task_store.set(task_id, task_metadata, expiry=expiry)
return task_uuid

@staticmethod
Expand All @@ -72,15 +89,23 @@ def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
task_id = build_task_id(task_context, task_uuid)
AbortableAsyncResult(task_id).abort()

@make_async()
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
async def get_task_result(
    self, task_context: TaskContext, task_uuid: TaskUUID
) -> Any:
    """Return the task's result (whatever Celery currently holds for it).

    Once the task is ready, metadata of *ephemeral* tasks is removed so
    short-lived results do not linger until their Redis TTL expires;
    metadata of non-ephemeral tasks is kept and left to expire via its
    configured TTL.
    """
    with log_context(
        _logger,
        logging.DEBUG,
        msg=f"Get task result: {task_context=} {task_uuid=}",
    ):
        task_id = build_task_id(task_context, task_uuid)
        async_result = self._celery_app.AsyncResult(task_id)
        result = async_result.result
        if async_result.ready():
            task_metadata = await self._task_store.get(task_id)
            if task_metadata is not None and task_metadata.ephemeral:
                # single removal — the duplicated call was diff residue and
                # would have dropped non-ephemeral metadata unconditionally
                await self._task_store.remove(task_id)
        return result

def _get_progress_report(
self, task_context: TaskContext, task_uuid: TaskUUID
Expand Down Expand Up @@ -129,4 +154,4 @@ async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
logging.DEBUG,
msg=f"Getting task uuids: {task_context=}",
):
return await self._task_store.get_task_uuids(task_context)
return await self._task_store.get_uuids(task_context)
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from enum import StrEnum, auto
from typing import Any, Final, Protocol, TypeAlias
from uuid import UUID
Expand Down Expand Up @@ -32,19 +33,31 @@ class TaskState(StrEnum):
ABORTED = auto()


class TaskData(BaseModel):
status: str
class TasksQueue(StrEnum):
    """Names of the Celery routing queues a task can be dispatched to."""

    # NOTE(review): docker-compose starts the dedicated worker with
    # CELERY_QUEUES: "cpu-bound" (hyphen) while this value uses an
    # underscore — if these do not match, CPU-bound tasks are routed to a
    # queue no worker consumes. TODO confirm which spelling is intended.
    CPU_BOUND = "cpu_bound"
    DEFAULT = "default"


class TaskMetadata(BaseModel):
    """Client-side bookkeeping attached to a submitted Celery task."""

    # Ephemeral results are removed right after retrieval and stored with
    # the shorter CELERY_EPHEMERAL_RESULT_EXPIRES TTL.
    ephemeral: bool = True
    # Routing queue the task is sent to on submission.
    queue: TasksQueue = TasksQueue.DEFAULT


_TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}


class TaskStore(Protocol):
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...
class TaskMetadataStore(Protocol):
    """Structural interface of the per-task metadata store.

    Matches the concrete ``RedisTaskMetadataStore`` implementation. The
    leftover pre-rename methods (``task_exists``/``set_task``), which still
    referenced the removed ``TaskData`` type, are dropped.
    """

    async def exists(self, task_id: TaskID) -> bool: ...

    async def get(self, task_id: TaskID) -> TaskMetadata | None: ...

    async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...

    async def remove(self, task_id: TaskID) -> None: ...

    async def set(
        self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
    ) -> None: ...


class TaskStatus(BaseModel):
Expand Down