Skip to content

Commit ec2b890

Browse files
add task store protocol
1 parent 4e959a6 commit ec2b890

File tree

6 files changed

+88
-48
lines changed

6 files changed

+88
-48
lines changed

services/storage/src/simcore_service_storage/modules/celery/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from fastapi import FastAPI
55
from servicelib.redis._client import RedisClientSDK
66
from settings_library.redis import RedisDatabase
7+
from simcore_service_storage.modules.celery.backends._redis import RedisTaskStore
78

89
from ..._meta import APP_NAME
910
from ...core.settings import get_application_settings
@@ -26,7 +27,10 @@ async def on_startup() -> None:
2627
),
2728
client_name=f"{APP_NAME}.celery_tasks",
2829
)
29-
app.state.celery_client = CeleryTaskQueueClient(celery_app, redis_client_sdk)
30+
31+
app.state.celery_client = CeleryTaskQueueClient(
32+
celery_app, RedisTaskStore(redis_client_sdk)
33+
)
3034

3135
register_celery_types()
3236

services/storage/src/simcore_service_storage/modules/celery/_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def create_app(celery_settings: CelerySettings) -> Celery:
1818
),
1919
)
2020
app.conf.broker_connection_retry_on_startup = True
21-
# NOTE: disable SSL cert validation (https://github.com/ITISFoundation/osparc-simcore/pull/7407)
2221
if celery_settings.CELERY_REDIS_RESULT_BACKEND.REDIS_SECURE:
22+
# NOTE: disable SSL cert validation (https://github.com/ITISFoundation/osparc-simcore/pull/7407)
2323
app.conf.redis_backend_use_ssl = {"ssl_cert_reqs": ssl.CERT_NONE}
2424
app.conf.result_expires = celery_settings.CELERY_RESULT_EXPIRES
2525
app.conf.result_extended = True # original args are included in the results

services/storage/src/simcore_service_storage/modules/celery/backends/__init__.py

Whitespace-only changes.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Final
2+
3+
from servicelib.redis._client import RedisClientSDK
4+
5+
from ..models import TaskContext, TaskData, TaskID, TaskUUID, build_task_id_prefix
6+
7+
_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
8+
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
9+
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
10+
11+
12+
class RedisTaskStore:
13+
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
14+
self._redis_client_sdk = redis_client_sdk
15+
16+
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
17+
search_key = (
18+
_CELERY_TASK_META_PREFIX
19+
+ build_task_id_prefix(task_context)
20+
+ _CELERY_TASK_ID_KEY_SEPARATOR
21+
)
22+
keys = set()
23+
async for key in self._redis_client_sdk.redis.scan_iter(
24+
match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
25+
):
26+
keys.add(TaskUUID(f"{key}".removeprefix(search_key)))
27+
return keys
28+
29+
async def task_exists(self, task_id: TaskID) -> bool:
30+
return await self._redis_client_sdk.redis.exists(task_id) > 0
31+
32+
async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
33+
await self._redis_client_sdk.redis.set(
34+
task_id,
35+
task_data.model_dump_json(),
36+
)
Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import contextlib
2-
import json
32
import logging
43
from typing import Any, Final
54
from uuid import uuid4
@@ -12,13 +11,19 @@
1211
from models_library.progress_bar import ProgressReport
1312
from pydantic import ValidationError
1413
from servicelib.logging_utils import log_context
15-
from servicelib.redis._client import RedisClientSDK
1614

17-
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
15+
from .models import (
16+
TaskContext,
17+
TaskData,
18+
TaskState,
19+
TaskStatus,
20+
TaskStore,
21+
TaskUUID,
22+
build_task_id,
23+
)
1824

1925
_logger = logging.getLogger(__name__)
2026

21-
_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
2227
_CELERY_STATES_MAPPING: Final[dict[str, TaskState]] = {
2328
"PENDING": TaskState.PENDING,
2429
"STARTED": TaskState.PENDING,
@@ -29,70 +34,49 @@
2934
"FAILURE": TaskState.ERROR,
3035
"ERROR": TaskState.ERROR,
3136
}
32-
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
33-
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
3437

3538
_MIN_PROGRESS_VALUE = 0.0
3639
_MAX_PROGRESS_VALUE = 100.0
3740

3841

39-
def _build_context_prefix(task_context: TaskContext) -> list[str]:
40-
return [f"{task_context[key]}" for key in sorted(task_context)]
41-
42-
43-
def _build_task_id_prefix(task_context: TaskContext) -> str:
44-
return _CELERY_TASK_ID_KEY_SEPARATOR.join(_build_context_prefix(task_context))
45-
46-
47-
def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
48-
return _CELERY_TASK_ID_KEY_SEPARATOR.join(
49-
[_build_task_id_prefix(task_context), f"{task_uuid}"]
50-
)
51-
52-
5342
class CeleryTaskQueueClient:
54-
def __init__(self, celery_app: Celery, redis_client_sdk: RedisClientSDK) -> None:
43+
def __init__(self, celery_app: Celery, task_store: TaskStore) -> None:
5544
self._celery_app = celery_app
56-
self._redis_client_sdk = redis_client_sdk
45+
self._task_store = task_store
5746

5847
async def send_task(
5948
self, task_name: str, *, task_context: TaskContext, **task_params
6049
) -> TaskUUID:
6150
task_uuid = uuid4()
62-
task_id = _build_task_id(task_context, task_uuid)
51+
task_id = build_task_id(task_context, task_uuid)
6352
with log_context(
6453
_logger,
6554
logging.DEBUG,
6655
msg=f"Submitting task {task_name}: {task_id=} {task_params=}",
6756
):
6857
self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
69-
await self._redis_client_sdk.redis.set(
70-
_CELERY_TASK_META_PREFIX + task_id,
71-
json.dumps(
72-
{
73-
"status": "PENDING",
74-
}
75-
),
58+
await self._task_store.set_task(
59+
task_id, TaskData(status=TaskState.PENDING.name)
7660
)
7761
return task_uuid
7862

7963
@make_async()
8064
def abort_task( # pylint: disable=R6301
8165
self, task_context: TaskContext, task_uuid: TaskUUID
8266
) -> None:
83-
task_id = _build_task_id(task_context, task_uuid)
67+
task_id = build_task_id(task_context, task_uuid)
8468
_logger.info("Aborting task %s", task_id)
8569
AbortableAsyncResult(task_id).abort()
8670

8771
@make_async()
8872
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
89-
task_id = _build_task_id(task_context, task_uuid)
73+
task_id = build_task_id(task_context, task_uuid)
9074
return self._celery_app.AsyncResult(task_id).result
9175

9276
def _get_progress_report(
9377
self, task_context: TaskContext, task_uuid: TaskUUID
9478
) -> ProgressReport:
95-
task_id = _build_task_id(task_context, task_uuid)
79+
task_id = build_task_id(task_context, task_uuid)
9680
result = self._celery_app.AsyncResult(task_id).result
9781
state = self._get_state(task_context, task_uuid)
9882
if result and state == TaskState.RUNNING:
@@ -108,7 +92,7 @@ def _get_progress_report(
10892
return ProgressReport(actual_value=_MIN_PROGRESS_VALUE)
10993

11094
def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
111-
task_id = _build_task_id(task_context, task_uuid)
95+
task_id = build_task_id(task_context, task_uuid)
11296
return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]
11397

11498
@make_async()
@@ -122,14 +106,4 @@ def get_task_status(
122106
)
123107

124108
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
125-
search_key = (
126-
_CELERY_TASK_META_PREFIX
127-
+ _build_task_id_prefix(task_context)
128-
+ _CELERY_TASK_ID_KEY_SEPARATOR
129-
)
130-
keys = set()
131-
async for key in self._redis_client_sdk.redis.scan_iter(
132-
match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
133-
):
134-
keys.add(TaskUUID(f"{key}".removeprefix(search_key)))
135-
return keys
109+
return await self._task_store.get_task_uuids(task_context)

services/storage/src/simcore_service_storage/modules/celery/models.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import StrEnum, auto
2-
from typing import Any, Final, Self, TypeAlias
2+
from typing import Any, Final, Protocol, Self, TypeAlias
33
from uuid import UUID
44

55
from models_library.progress_bar import ProgressReport
@@ -12,6 +12,20 @@
1212
_MIN_PROGRESS: Final[float] = 0.0
1313
_MAX_PROGRESS: Final[float] = 100.0
1414

15+
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
16+
17+
18+
def build_task_id_prefix(task_context: TaskContext) -> str:
19+
return _CELERY_TASK_ID_KEY_SEPARATOR.join(
20+
[f"{task_context[key]}" for key in sorted(task_context)]
21+
)
22+
23+
24+
def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
25+
return _CELERY_TASK_ID_KEY_SEPARATOR.join(
26+
[build_task_id_prefix(task_context), f"{task_uuid}"]
27+
)
28+
1529

1630
class TaskState(StrEnum):
1731
PENDING = auto()
@@ -21,9 +35,21 @@ class TaskState(StrEnum):
2135
ABORTED = auto()
2236

2337

38+
class TaskData(BaseModel):
39+
status: str
40+
41+
2442
_TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}
2543

2644

45+
class TaskStore(Protocol):
46+
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...
47+
48+
async def task_exists(self, task_id: TaskID) -> bool: ...
49+
50+
async def set_task(self, task_id: TaskID, task_data: TaskData) -> None: ...
51+
52+
2753
class TaskStatus(BaseModel):
2854
task_uuid: TaskUUID
2955
task_state: TaskState

0 commit comments

Comments
 (0)