
Commit 6dc3e0b

🎨 Store and retrieve task_name when listing Celery tasks (#7538)
Co-authored-by: Odei Maiz <[email protected]>
1 parent 2c92f02 · commit 6dc3e0b

File tree: 15 files changed, +247 −117 lines


packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py

Lines changed: 9 additions & 4 deletions
@@ -1,12 +1,16 @@
-from typing import Any, TypeAlias
+from typing import Annotated, Any, TypeAlias
 from uuid import UUID
 
-from models_library.users import UserID
-from pydantic import BaseModel
+from pydantic import BaseModel, StringConstraints
 
+from ..products import ProductName
 from ..progress_bar import ProgressReport
+from ..users import UserID
 
 AsyncJobId: TypeAlias = UUID
+AsyncJobName: TypeAlias = Annotated[
+    str, StringConstraints(strip_whitespace=True, min_length=1)
+]
 
 
 class AsyncJobStatus(BaseModel):
@@ -21,6 +25,7 @@ class AsyncJobResult(BaseModel):
 
 class AsyncJobGet(BaseModel):
     job_id: AsyncJobId
+    job_name: AsyncJobName
 
 
 class AsyncJobAbort(BaseModel):
@@ -31,5 +36,5 @@ class AsyncJobAbort(BaseModel):
 class AsyncJobNameData(BaseModel):
     """Data for controlling access to an async job"""
 
+    product_name: ProductName
     user_id: UserID
-    product_name: str
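
As a side note, the new AsyncJobName alias makes empty or whitespace-only job names fail validation. A minimal sketch of that behaviour, assuming pydantic v2 and re-declaring the two definitions above locally (this is not the installed models-library package):

from typing import Annotated, TypeAlias
from uuid import UUID, uuid4

from pydantic import BaseModel, StringConstraints, ValidationError

AsyncJobName: TypeAlias = Annotated[
    str, StringConstraints(strip_whitespace=True, min_length=1)
]


class AsyncJobGet(BaseModel):
    job_id: UUID
    job_name: AsyncJobName


# surrounding whitespace is stripped before validation
print(AsyncJobGet(job_id=uuid4(), job_name="  export_data ").job_name)  # -> "export_data"

try:
    AsyncJobGet(job_id=uuid4(), job_name="   ")  # strips to "", violating min_length=1
except ValidationError as exc:
    print(exc.error_count(), "validation error")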

services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ async def _task_progress_cb(
 ) -> None:
     worker = get_celery_worker(task.app)
     assert task.name  # nosec
-    await worker.set_progress(
+    await worker.set_task_progress(
         task_id=task_id,
         report=report,
     )
@@ -87,7 +87,7 @@ async def export_data(
 
     async def _progress_cb(report: ProgressReport) -> None:
         assert task.name  # nosec
-        await get_celery_worker(task.app).set_progress(task_id, report)
+        await get_celery_worker(task.app).set_task_progress(task_id, report)
         _logger.debug("'%s' progress %s", task_id, report.percent_value)
 
     async with ProgressBarData(

services/storage/src/simcore_service_storage/api/rest/_files.py

Lines changed: 5 additions & 3 deletions
@@ -35,7 +35,7 @@
     UploadLinks,
 )
 from ...modules.celery.client import CeleryTaskClient
-from ...modules.celery.models import TaskUUID
+from ...modules.celery.models import TaskMetadata, TaskUUID
 from ...simcore_s3_dsm import SimcoreS3DataManager
 from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file
 from .dependencies.celery import get_celery_client
@@ -284,8 +284,10 @@ async def complete_upload_file(
         user_id=query_params.user_id,
         product_name=_UNDEFINED_PRODUCT_NAME_FOR_WORKER_TASKS,  # NOTE: I would need to change the API here
     )
-    task_uuid = await celery_client.send_task(
-        remote_complete_upload_file.__name__,
+    task_uuid = await celery_client.submit_task(
+        TaskMetadata(
+            name=remote_complete_upload_file.__name__,
+        ),
         task_context=async_job_name_data.model_dump(),
         user_id=async_job_name_data.user_id,
         location_id=location_id,

services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py

Lines changed: 4 additions & 2 deletions
@@ -127,10 +127,12 @@ async def list_jobs(
     _ = filter_
     assert app  # nosec
     try:
-        task_uuids = await get_celery_client(app).get_task_uuids(
+        tasks = await get_celery_client(app).list_tasks(
             task_context=job_id_data.model_dump(),
         )
     except CeleryError as exc:
         raise JobSchedulerError(exc=f"{exc}") from exc
 
-    return [AsyncJobGet(job_id=task_uuid) for task_uuid in task_uuids]
+    return [
+        AsyncJobGet(job_id=task.uuid, job_name=task.metadata.name) for task in tasks
+    ]
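
list_tasks() now returns task objects that bundle the UUID with the stored metadata; only task.uuid and task.metadata.name are consumed here. A rough sketch of that container, inferred from the attributes used above (the real Task and TaskMetadata models live in modules/celery/models.py and may carry more fields):

from uuid import UUID, uuid4

from pydantic import BaseModel


class TaskMetadata(BaseModel):
    name: str


class Task(BaseModel):
    uuid: UUID
    metadata: TaskMetadata


# mirrors the comprehension in list_jobs(): every listed task maps to (job_id, job_name)
task = Task(uuid=uuid4(), metadata=TaskMetadata(name="export_data"))
print(task.uuid, task.metadata.name)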

services/storage/src/simcore_service_storage/api/rpc/_paths.py

Lines changed: 13 additions & 6 deletions
@@ -10,6 +10,7 @@
 from servicelib.rabbitmq import RPCRouter
 
 from ...modules.celery import get_celery_client
+from ...modules.celery.models import TaskMetadata
 from .._worker_tasks._paths import compute_path_size as remote_compute_path_size
 from .._worker_tasks._paths import delete_paths as remote_delete_paths
 
@@ -24,15 +25,18 @@ async def compute_path_size(
     location_id: LocationID,
     path: Path,
 ) -> AsyncJobGet:
-    task_uuid = await get_celery_client(app).send_task(
-        remote_compute_path_size.__name__,
+    task_name = remote_compute_path_size.__name__
+    task_uuid = await get_celery_client(app).submit_task(
+        task_metadata=TaskMetadata(
+            name=task_name,
+        ),
         task_context=job_id_data.model_dump(),
         user_id=job_id_data.user_id,
         location_id=location_id,
         path=path,
     )
 
-    return AsyncJobGet(job_id=task_uuid)
+    return AsyncJobGet(job_id=task_uuid, job_name=task_name)
 
 
 @router.expose(reraise_if_error_type=None)
@@ -42,11 +46,14 @@ async def delete_paths(
     location_id: LocationID,
     paths: set[Path],
 ) -> AsyncJobGet:
-    task_uuid = await get_celery_client(app).send_task(
-        remote_delete_paths.__name__,
+    task_name = remote_delete_paths.__name__
+    task_uuid = await get_celery_client(app).submit_task(
+        task_metadata=TaskMetadata(
+            name=task_name,
+        ),
         task_context=job_id_data.model_dump(),
         user_id=job_id_data.user_id,
         location_id=location_id,
         paths=paths,
     )
-    return AsyncJobGet(job_id=task_uuid)
+    return AsyncJobGet(job_id=task_uuid, job_name=task_name)

services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py

Lines changed: 11 additions & 7 deletions
@@ -20,28 +20,32 @@ async def copy_folders_from_project(
     job_id_data: AsyncJobNameData,
     body: FoldersBody,
 ) -> AsyncJobGet:
-    task_uuid = await get_celery_client(app).send_task(
-        deep_copy_files_from_project.__name__,
+    task_name = deep_copy_files_from_project.__name__
+    task_uuid = await get_celery_client(app).submit_task(
+        task_metadata=TaskMetadata(
+            name=task_name,
+        ),
         task_context=job_id_data.model_dump(),
         user_id=job_id_data.user_id,
         body=body,
     )
 
-    return AsyncJobGet(job_id=task_uuid)
+    return AsyncJobGet(job_id=task_uuid, job_name=task_name)
 
 
 @router.expose()
 async def start_export_data(
     app: FastAPI, job_id_data: AsyncJobNameData, paths_to_export: list[PathToExport]
 ) -> AsyncJobGet:
-    task_uuid = await get_celery_client(app).send_task(
-        export_data.__name__,
-        task_context=job_id_data.model_dump(),
+    task_name = export_data.__name__
+    task_uuid = await get_celery_client(app).submit_task(
         task_metadata=TaskMetadata(
+            name=task_name,
             ephemeral=False,
             queue=TasksQueue.CPU_BOUND,
         ),
+        task_context=job_id_data.model_dump(),
         user_id=job_id_data.user_id,
         paths_to_export=paths_to_export,
     )
-    return AsyncJobGet(job_id=task_uuid)
+    return AsyncJobGet(job_id=task_uuid, job_name=task_name)
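
All submit sites now route the task name and scheduling hints through a single TaskMetadata argument instead of a bare task-name string. A hypothetical sketch of such a metadata model, using only the fields visible in this diff (name, ephemeral, queue); the defaults and the TasksQueue members are assumptions, not the actual definitions:

from enum import StrEnum

from pydantic import BaseModel


class TasksQueue(StrEnum):  # assumed members; only CPU_BOUND appears in this diff
    DEFAULT = "default"
    CPU_BOUND = "cpu_bound"


class TaskMetadata(BaseModel):
    name: str
    ephemeral: bool = True  # assumed default: most tasks are short-lived
    queue: TasksQueue = TasksQueue.DEFAULT


# e.g. the export task opts out of the defaults and targets the CPU-bound queue
metadata = TaskMetadata(name="export_data", ephemeral=False, queue=TasksQueue.CPU_BOUND)
print(metadata.model_dump_json())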

services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py

Lines changed: 79 additions & 29 deletions
@@ -1,12 +1,20 @@
+import contextlib
 import logging
 from datetime import timedelta
 from typing import Final
 
-from celery.result import AsyncResult  # type: ignore[import-untyped]
 from models_library.progress_bar import ProgressReport
+from pydantic import ValidationError
 from servicelib.redis._client import RedisClientSDK
 
-from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix
+from ..models import (
+    Task,
+    TaskContext,
+    TaskID,
+    TaskMetadata,
+    TaskUUID,
+    build_task_id_prefix,
+)
 
 _CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
 _CELERY_TASK_ID_KEY_ENCODING = "utf-8"
@@ -26,26 +34,64 @@ class RedisTaskInfoStore:
     def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
         self._redis_client_sdk = redis_client_sdk
 
-    async def exists(self, task_id: TaskID) -> bool:
+    async def create_task(
+        self,
+        task_id: TaskID,
+        task_metadata: TaskMetadata,
+        expiry: timedelta,
+    ) -> None:
+        task_key = _build_key(task_id)
+        await self._redis_client_sdk.redis.hset(
+            name=task_key,
+            key=_CELERY_TASK_METADATA_KEY,
+            value=task_metadata.model_dump_json(),
+        )  # type: ignore
+        await self._redis_client_sdk.redis.expire(
+            task_key,
+            expiry,
+        )
+
+    async def exists_task(self, task_id: TaskID) -> bool:
         n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
         assert isinstance(n, int)  # nosec
         return n > 0
 
-    async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None:
-        result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY)  # type: ignore
-        return TaskMetadata.model_validate_json(result) if result else None
+    async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
+        raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY)  # type: ignore
+        if not raw_result:
+            return None
+
+        try:
+            return TaskMetadata.model_validate_json(raw_result)
+        except ValidationError as exc:
+            _logger.debug(
+                "Failed to deserialize task metadata for task %s: %s", task_id, f"{exc}"
+            )
+            return None
 
-    async def get_progress(self, task_id: TaskID) -> ProgressReport | None:
-        result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY)  # type: ignore
-        return ProgressReport.model_validate_json(result) if result else None
+    async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
+        raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY)  # type: ignore
+        if not raw_result:
+            return None
 
-    async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
+        try:
+            return ProgressReport.model_validate_json(raw_result)
+        except ValidationError as exc:
+            _logger.debug(
+                "Failed to deserialize task progress for task %s: %s", task_id, f"{exc}"
+            )
+            return None
+
+    async def list_tasks(self, task_context: TaskContext) -> list[Task]:
         search_key = (
             _CELERY_TASK_INFO_PREFIX
             + build_task_id_prefix(task_context)
            + _CELERY_TASK_ID_KEY_SEPARATOR
         )
-        keys = set()
+        search_key_len = len(search_key)
+
+        keys: list[str] = []
+        pipeline = self._redis_client_sdk.redis.pipeline()
         async for key in self._redis_client_sdk.redis.scan_iter(
             match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
         ):
@@ -55,27 +101,31 @@ async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
                 if isinstance(key, bytes)
                 else key
             )
-            keys.add(TaskUUID(_key.removeprefix(search_key)))
-        return keys
+            keys.append(_key)
+            pipeline.hget(_key, _CELERY_TASK_METADATA_KEY)
 
-    async def remove(self, task_id: TaskID) -> None:
-        await self._redis_client_sdk.redis.delete(_build_key(task_id))
-        AsyncResult(task_id).forget()
+        results = await pipeline.execute()
 
-    async def set_metadata(
-        self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
-    ) -> None:
-        await self._redis_client_sdk.redis.hset(
-            name=_build_key(task_id),
-            key=_CELERY_TASK_METADATA_KEY,
-            value=task_metadata.model_dump_json(),
-        )  # type: ignore
-        await self._redis_client_sdk.redis.expire(
-            _build_key(task_id),
-            expiry,
-        )
+        tasks = []
+        for key, raw_metadata in zip(keys, results, strict=True):
+            if raw_metadata is None:
+                continue
+
+            with contextlib.suppress(ValidationError):
+                task_metadata = TaskMetadata.model_validate_json(raw_metadata)
+                tasks.append(
+                    Task(
+                        uuid=TaskUUID(key[search_key_len:]),
+                        metadata=task_metadata,
+                    )
+                )
+
+        return tasks
+
+    async def remove_task(self, task_id: TaskID) -> None:
+        await self._redis_client_sdk.redis.delete(_build_key(task_id))
 
-    async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None:
+    async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
         await self._redis_client_sdk.redis.hset(
             name=_build_key(task_id),
             key=_CELERY_TASK_PROGRESS_KEY,
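
The rewritten list_tasks() replaces per-key round trips with one SCAN pass that queues an HGET per matching key and flushes them all through a single pipeline. A standalone sketch of that pattern written against redis.asyncio directly (redis-py 5.x assumed; the service reaches Redis through RedisClientSDK, and the key prefix and field name here are examples):

import asyncio

import redis.asyncio as aioredis


async def list_metadata(prefix: str) -> dict[str, str | None]:
    client = aioredis.Redis(decode_responses=True)
    keys: list[str] = []
    pipeline = client.pipeline()

    # SCAN collects matching keys and queues one HGET per key on the pipeline ...
    async for key in client.scan_iter(match=prefix + "*", count=100):
        keys.append(key)
        pipeline.hget(key, "metadata")

    # ... and execute() sends all queued HGETs as a single batch
    values = await pipeline.execute()
    await client.aclose()
    return dict(zip(keys, values, strict=True))


if __name__ == "__main__":
    print(asyncio.run(list_metadata("celery-task-info-")))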
