Skip to content

Commit 5bbcc74

Browse files
🎨 Use async redis client (#7443)
Co-authored-by: Giancarlo Romeo <[email protected]>
1 parent c535784 commit 5bbcc74

File tree

5 files changed

+121
-88
lines changed

5 files changed

+121
-88
lines changed

services/storage/src/simcore_service_storage/modules/celery/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,35 @@
22
from asyncio import AbstractEventLoop
33

44
from fastapi import FastAPI
5+
from servicelib.redis._client import RedisClientSDK
6+
from settings_library.redis import RedisDatabase
57

8+
from ..._meta import APP_NAME
69
from ...core.settings import get_application_settings
710
from ._celery_types import register_celery_types
811
from ._common import create_app
12+
from .backends._redis import RedisTaskStore
913
from .client import CeleryTaskQueueClient
1014

1115
_logger = logging.getLogger(__name__)
1216

1317

1418
def setup_celery_client(app: FastAPI) -> None:
1519
async def on_startup() -> None:
16-
celery_settings = get_application_settings(app).STORAGE_CELERY
20+
application_settings = get_application_settings(app)
21+
celery_settings = application_settings.STORAGE_CELERY
1722
assert celery_settings # nosec
1823
celery_app = create_app(celery_settings)
19-
app.state.celery_client = CeleryTaskQueueClient(celery_app)
24+
redis_client_sdk = RedisClientSDK(
25+
celery_settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
26+
RedisDatabase.CELERY_TASKS
27+
),
28+
client_name=f"{APP_NAME}.celery_tasks",
29+
)
30+
31+
app.state.celery_client = CeleryTaskQueueClient(
32+
celery_app, RedisTaskStore(redis_client_sdk)
33+
)
2034

2135
register_celery_types()
2236

services/storage/src/simcore_service_storage/modules/celery/backends/__init__.py

Whitespace-only changes.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import Final
2+
3+
from servicelib.redis._client import RedisClientSDK
4+
5+
from ..models import TaskContext, TaskData, TaskID, TaskUUID, build_task_id_prefix
6+
7+
_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
8+
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
9+
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
10+
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
11+
12+
13+
class RedisTaskStore:
14+
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
15+
self._redis_client_sdk = redis_client_sdk
16+
17+
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
18+
search_key = build_task_id_prefix(task_context) + _CELERY_TASK_ID_KEY_SEPARATOR
19+
keys = set()
20+
async for key in self._redis_client_sdk.redis.scan_iter(
21+
match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
22+
):
23+
# fake redis (tests) returns bytes, real redis returns str
24+
_key = (
25+
key.decode(_CELERY_TASK_ID_KEY_ENCODING)
26+
if isinstance(key, bytes)
27+
else key
28+
)
29+
keys.add(TaskUUID(_key.removeprefix(search_key)))
30+
return keys
31+
32+
async def task_exists(self, task_id: TaskID) -> bool:
33+
n = await self._redis_client_sdk.redis.exists(task_id)
34+
assert isinstance(n, int) # nosec
35+
return n > 0
36+
37+
async def set_task(self, task_id: TaskID, task_data: TaskData) -> None:
38+
await self._redis_client_sdk.redis.set(
39+
task_id,
40+
task_data.model_dump_json(),
41+
)

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 37 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@
1313
from pydantic import ValidationError
1414
from servicelib.logging_utils import log_context
1515

16-
from ...exceptions.errors import ConfigurationError
17-
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
16+
from .models import (
17+
TaskContext,
18+
TaskData,
19+
TaskState,
20+
TaskStatus,
21+
TaskStore,
22+
TaskUUID,
23+
build_task_id,
24+
)
1825

1926
_logger = logging.getLogger(__name__)
2027

21-
_CELERY_INSPECT_TASK_STATUSES: Final[tuple[str, ...]] = (
22-
"active",
23-
"scheduled",
24-
"revoked",
25-
)
26-
_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
2728
_CELERY_STATES_MAPPING: Final[dict[str, TaskState]] = {
2829
"PENDING": TaskState.PENDING,
2930
"STARTED": TaskState.PENDING,
@@ -34,33 +35,17 @@
3435
"FAILURE": TaskState.ERROR,
3536
"ERROR": TaskState.ERROR,
3637
}
37-
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
38-
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
3938

4039
_MIN_PROGRESS_VALUE = 0.0
4140
_MAX_PROGRESS_VALUE = 1.0
4241

4342

44-
def _build_context_prefix(task_context: TaskContext) -> list[str]:
45-
return [f"{task_context[key]}" for key in sorted(task_context)]
46-
47-
48-
def _build_task_id_prefix(task_context: TaskContext) -> str:
49-
return _CELERY_TASK_ID_KEY_SEPARATOR.join(_build_context_prefix(task_context))
50-
51-
52-
def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
53-
return _CELERY_TASK_ID_KEY_SEPARATOR.join(
54-
[_build_task_id_prefix(task_context), f"{task_uuid}"]
55-
)
56-
57-
5843
@dataclass
5944
class CeleryTaskQueueClient:
6045
_celery_app: Celery
46+
_task_store: TaskStore
6147

62-
@make_async()
63-
def send_task(
48+
async def send_task(
6449
self, task_name: str, *, task_context: TaskContext, **task_params
6550
) -> TaskUUID:
6651
with log_context(
@@ -69,8 +54,11 @@ def send_task(
6954
msg=f"Submit {task_name=}: {task_context=} {task_params=}",
7055
):
7156
task_uuid = uuid4()
72-
task_id = _build_task_id(task_context, task_uuid)
57+
task_id = build_task_id(task_context, task_uuid)
7358
self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
59+
await self._task_store.set_task(
60+
task_id, TaskData(status=TaskState.PENDING.name)
61+
)
7462
return task_uuid
7563

7664
@staticmethod
@@ -79,25 +67,25 @@ def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
7967
with log_context(
8068
_logger,
8169
logging.DEBUG,
82-
msg=f"Abort task {task_uuid=}: {task_context=}",
70+
msg=f"Abort task: {task_context=} {task_uuid=}",
8371
):
84-
task_id = _build_task_id(task_context, task_uuid)
72+
task_id = build_task_id(task_context, task_uuid)
8573
AbortableAsyncResult(task_id).abort()
8674

8775
@make_async()
8876
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
8977
with log_context(
9078
_logger,
9179
logging.DEBUG,
92-
msg=f"Get task {task_uuid=}: {task_context=} result",
80+
msg=f"Get task result: {task_context=} {task_uuid=}",
9381
):
94-
task_id = _build_task_id(task_context, task_uuid)
82+
task_id = build_task_id(task_context, task_uuid)
9583
return self._celery_app.AsyncResult(task_id).result
9684

9785
def _get_progress_report(
9886
self, task_context: TaskContext, task_uuid: TaskUUID
9987
) -> ProgressReport:
100-
task_id = _build_task_id(task_context, task_uuid)
88+
task_id = build_task_id(task_context, task_uuid)
10189
result = self._celery_app.AsyncResult(task_id).result
10290
state = self._get_state(task_context, task_uuid)
10391
if result and state == TaskState.RUNNING:
@@ -117,64 +105,28 @@ def _get_progress_report(
117105
)
118106

119107
def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
120-
task_id = _build_task_id(task_context, task_uuid)
108+
task_id = build_task_id(task_context, task_uuid)
121109
return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]
122110

123111
@make_async()
124112
def get_task_status(
125113
self, task_context: TaskContext, task_uuid: TaskUUID
126114
) -> TaskStatus:
127-
return TaskStatus(
128-
task_uuid=task_uuid,
129-
task_state=self._get_state(task_context, task_uuid),
130-
progress_report=self._get_progress_report(task_context, task_uuid),
131-
)
132-
133-
def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
134-
search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
135-
backend_client = self._celery_app.backend.client
136-
if hasattr(backend_client, "keys"):
137-
if keys := backend_client.keys(f"{search_key}*"):
138-
return {
139-
TaskUUID(
140-
f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
141-
)
142-
for key in keys
143-
}
144-
return set()
145-
if hasattr(backend_client, "cache"):
146-
# NOTE: backend used in testing. It is a dict-like object
147-
found_keys = set()
148-
for key in backend_client.cache:
149-
str_key = key.decode(_CELERY_TASK_ID_KEY_ENCODING)
150-
if str_key.startswith(search_key):
151-
found_keys.add(
152-
TaskUUID(
153-
f"{str_key.removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
154-
)
155-
)
156-
return found_keys
157-
msg = f"Unsupported backend {self._celery_app.backend.__class__.__name__}"
158-
raise ConfigurationError(msg=msg)
159-
160-
@make_async()
161-
def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
162-
task_uuids = self._get_completed_task_uuids(task_context)
163-
164-
task_id_prefix = _build_task_id_prefix(task_context)
165-
inspect = self._celery_app.control.inspect()
166-
for task_inspect_status in _CELERY_INSPECT_TASK_STATUSES:
167-
tasks = getattr(inspect, task_inspect_status)() or {}
168-
169-
task_uuids.update(
170-
TaskUUID(
171-
task_info["id"].removeprefix(
172-
task_id_prefix + _CELERY_TASK_ID_KEY_SEPARATOR
173-
)
174-
)
175-
for tasks_per_worker in tasks.values()
176-
for task_info in tasks_per_worker
177-
if "id" in task_info
115+
with log_context(
116+
_logger,
117+
logging.DEBUG,
118+
msg=f"Getting task status: {task_context=} {task_uuid=}",
119+
):
120+
return TaskStatus(
121+
task_uuid=task_uuid,
122+
task_state=self._get_state(task_context, task_uuid),
123+
progress_report=self._get_progress_report(task_context, task_uuid),
178124
)
179125

180-
return task_uuids
126+
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
127+
with log_context(
128+
_logger,
129+
logging.DEBUG,
130+
msg=f"Getting task uuids: {task_context=}",
131+
):
132+
return await self._task_store.get_task_uuids(task_context)

services/storage/src/simcore_service_storage/modules/celery/models.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import StrEnum, auto
2-
from typing import Any, Self, TypeAlias
2+
from typing import Any, Final, Protocol, Self, TypeAlias
33
from uuid import UUID
44

55
from models_library.progress_bar import ProgressReport
@@ -9,6 +9,20 @@
99
TaskID: TypeAlias = str
1010
TaskUUID: TypeAlias = UUID
1111

12+
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
13+
14+
15+
def build_task_id_prefix(task_context: TaskContext) -> str:
16+
return _CELERY_TASK_ID_KEY_SEPARATOR.join(
17+
[f"{task_context[key]}" for key in sorted(task_context)]
18+
)
19+
20+
21+
def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
22+
return _CELERY_TASK_ID_KEY_SEPARATOR.join(
23+
[build_task_id_prefix(task_context), f"{task_uuid}"]
24+
)
25+
1226

1327
class TaskState(StrEnum):
1428
PENDING = auto()
@@ -18,9 +32,21 @@ class TaskState(StrEnum):
1832
ABORTED = auto()
1933

2034

35+
class TaskData(BaseModel):
36+
status: str
37+
38+
2139
_TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}
2240

2341

42+
class TaskStore(Protocol):
43+
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...
44+
45+
async def task_exists(self, task_id: TaskID) -> bool: ...
46+
47+
async def set_task(self, task_id: TaskID, task_data: TaskData) -> None: ...
48+
49+
2450
class TaskStatus(BaseModel):
2551
task_uuid: TaskUUID
2652
task_state: TaskState

0 commit comments

Comments
 (0)