Skip to content

Commit 2ebf1b7

Browse files
initial commit
1 parent 41523e7 commit 2ebf1b7

File tree

3 files changed

+51
-29
lines changed

3 files changed

+51
-29
lines changed

services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
from models_library.progress_bar import ProgressReport
77
from servicelib.redis._client import RedisClientSDK
88

9-
from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix
9+
from ..models import (
10+
TaskContext,
11+
TaskID,
12+
TaskMetadata,
13+
TaskUUID,
14+
build_task_id,
15+
build_task_id_prefix,
16+
)
1017

1118
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
1219
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
@@ -26,17 +33,39 @@ class RedisTaskInfoStore:
2633
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
2734
self._redis_client_sdk = redis_client_sdk
2835

29-
async def exists(self, task_id: TaskID) -> bool:
30-
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
36+
async def create(
37+
self,
38+
task_context: TaskContext,
39+
task_uuid: TaskUUID,
40+
task_metadata: TaskMetadata,
41+
expiry: timedelta,
42+
) -> None:
43+
task_key = _build_key(build_task_id(task_context, task_uuid))
44+
await self._redis_client_sdk.redis.hset(
45+
name=task_key,
46+
key=_CELERY_TASK_METADATA_KEY,
47+
value=task_metadata.model_dump_json(),
48+
) # type: ignore
49+
await self._redis_client_sdk.redis.expire(
50+
task_key,
51+
expiry,
52+
)
53+
54+
async def exists(self, task_context: TaskContext, task_uuid: TaskUUID) -> bool:
55+
n = await self._redis_client_sdk.redis.exists(_build_key(build_task_id(task_context, task_uuid))) # type: ignore
3156
assert isinstance(n, int) # nosec
3257
return n > 0
3358

34-
async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None:
35-
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore
59+
async def get_metadata(
60+
self, task_context: TaskContext, task_uuid: TaskUUID
61+
) -> TaskMetadata | None:
62+
result = await self._redis_client_sdk.redis.hget(_build_key(build_task_id(task_context, task_uuid)), _CELERY_TASK_METADATA_KEY) # type: ignore
3663
return TaskMetadata.model_validate_json(result) if result else None
3764

38-
async def get_progress(self, task_id: TaskID) -> ProgressReport | None:
39-
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore
65+
async def get_progress(
66+
self, task_context: TaskContext, task_uuid: TaskUUID
67+
) -> ProgressReport | None:
68+
result = await self._redis_client_sdk.redis.hget(_build_key(build_task_id(task_context, task_uuid)), _CELERY_TASK_PROGRESS_KEY) # type: ignore
4069
return ProgressReport.model_validate_json(result) if result else None
4170

4271
async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
@@ -62,19 +91,6 @@ async def remove(self, task_id: TaskID) -> None:
6291
await self._redis_client_sdk.redis.delete(_build_key(task_id))
6392
AsyncResult(task_id).forget()
6493

65-
async def set_metadata(
66-
self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
67-
) -> None:
68-
await self._redis_client_sdk.redis.hset(
69-
name=_build_key(task_id),
70-
key=_CELERY_TASK_METADATA_KEY,
71-
value=task_metadata.model_dump_json(),
72-
) # type: ignore
73-
await self._redis_client_sdk.redis.expire(
74-
_build_key(task_id),
75-
expiry,
76-
)
77-
7894
async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None:
7995
await self._redis_client_sdk.redis.hset(
8096
name=_build_key(task_id),

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ class CeleryTaskClient:
3838

3939
async def send_task(
4040
self,
41-
task_name: str,
41+
task_metadata: TaskMetadata,
4242
*,
4343
task_context: TaskContext,
44-
task_metadata: TaskMetadata | None = None,
4544
**task_params,
4645
) -> TaskUUID:
4746
with log_context(
@@ -50,11 +49,10 @@ async def send_task(
5049
msg=f"Submit {task_name=}: {task_context=} {task_params=}",
5150
):
5251
task_uuid = uuid4()
53-
task_id = build_task_id(task_context, task_uuid)
5452
task_metadata = task_metadata or TaskMetadata()
5553
self._celery_app.send_task(
5654
task_name,
57-
task_id=task_id,
55+
task_id=build_task_id(task_context, task_uuid),
5856
kwargs=task_params,
5957
queue=task_metadata.queue.value,
6058
)
@@ -64,7 +62,9 @@ async def send_task(
6462
if task_metadata.ephemeral
6563
else self._celery_settings.CELERY_RESULT_EXPIRES
6664
)
67-
await self._task_store.set_metadata(task_id, task_metadata, expiry=expiry)
65+
await self._task_store.create(
66+
task_context, task_uuid, task_metadata, expiry=expiry
67+
)
6868
return task_uuid
6969

7070
@make_async()

services/storage/src/simcore_service_storage/modules/celery/models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
TaskContext: TypeAlias = dict[str, Any]
1010
TaskID: TypeAlias = str
11+
TaskName: TypeAlias = str
1112
TaskUUID: TypeAlias = UUID
1213

1314
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
@@ -40,6 +41,7 @@ class TasksQueue(StrEnum):
4041

4142

4243
class TaskMetadata(BaseModel):
44+
name: TaskName
4345
ephemeral: bool = True
4446
queue: TasksQueue = TasksQueue.DEFAULT
4547

@@ -50,6 +52,14 @@ class TaskMetadata(BaseModel):
5052
class TaskInfoStore(Protocol):
5153
async def exists(self, task_id: TaskID) -> bool: ...
5254

55+
async def create(
56+
self,
57+
task_context: TaskContext,
58+
task_uuid: TaskUUID,
59+
task_metadata: TaskMetadata,
60+
expiry: timedelta,
61+
) -> None: ...
62+
5363
async def get_progress(self, task_id: TaskID) -> ProgressReport | None: ...
5464

5565
async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None: ...
@@ -58,10 +68,6 @@ async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...
5868

5969
async def remove(self, task_id: TaskID) -> None: ...
6070

61-
async def set_metadata(
62-
self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
63-
) -> None: ...
64-
6571
async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None: ...
6672

6773

0 commit comments

Comments
 (0)