Skip to content

Commit bf3e5b7

Browse files
continue
1 parent 792de64 commit bf3e5b7

File tree

12 files changed

+125
-96
lines changed

12 files changed

+125
-96
lines changed

packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from typing import Any, TypeAlias
22
from uuid import UUID
33

4-
from models_library.users import UserID
54
from pydantic import BaseModel
65

6+
from ..products import ProductName
77
from ..progress_bar import ProgressReport
8+
from ..users import UserID
89

910
AsyncJobId: TypeAlias = UUID
11+
AsyncJobName: TypeAlias = str
1012

1113

1214
class AsyncJobStatus(BaseModel):
@@ -21,6 +23,7 @@ class AsyncJobResult(BaseModel):
2123

2224
class AsyncJobGet(BaseModel):
2325
job_id: AsyncJobId
26+
job_name: AsyncJobName
2427

2528

2629
class AsyncJobAbort(BaseModel):
@@ -31,5 +34,5 @@ class AsyncJobAbort(BaseModel):
3134
class AsyncJobNameData(BaseModel):
3235
"""Data for controlling access to an async job"""
3336

37+
product_name: ProductName
3438
user_id: UserID
35-
product_name: str

services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def _task_progress_cb(
2626
) -> None:
2727
worker = get_celery_worker(task.app)
2828
assert task.name # nosec
29-
await worker.set_progress(
29+
await worker.set_task_progress(
3030
task_id=task_id,
3131
report=report,
3232
)
@@ -87,7 +87,7 @@ async def export_data(
8787

8888
async def _progress_cb(report: ProgressReport) -> None:
8989
assert task.name # nosec
90-
await get_celery_worker(task.app).set_progress(task_id, report)
90+
await get_celery_worker(task.app).set_task_progress(task_id, report)
9191
_logger.debug("'%s' progress %s", task_id, report.percent_value)
9292

9393
async with ProgressBarData(

services/storage/src/simcore_service_storage/api/rest/_files.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
UploadLinks,
3636
)
3737
from ...modules.celery.client import CeleryTaskClient
38-
from ...modules.celery.models import TaskUUID
38+
from ...modules.celery.models import TaskMetadata, TaskUUID
3939
from ...simcore_s3_dsm import SimcoreS3DataManager
4040
from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file
4141
from .dependencies.celery import get_celery_client
@@ -284,8 +284,10 @@ async def complete_upload_file(
284284
user_id=query_params.user_id,
285285
product_name=_UNDEFINED_PRODUCT_NAME_FOR_WORKER_TASKS, # NOTE: I would need to change the API here
286286
)
287-
task_uuid = await celery_client.send_task(
288-
remote_complete_upload_file.__name__,
287+
task_uuid = await celery_client.submit_task(
288+
TaskMetadata(
289+
name=remote_complete_upload_file.__name__,
290+
),
289291
task_context=async_job_name_data.model_dump(),
290292
user_id=async_job_name_data.user_id,
291293
location_id=location_id,

services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,12 @@ async def list_jobs(
127127
_ = filter_
128128
assert app # nosec
129129
try:
130-
task_uuids = await get_celery_client(app).get_task_uuids(
130+
tasks = await get_celery_client(app).list_tasks(
131131
task_context=job_id_data.model_dump(),
132132
)
133133
except CeleryError as exc:
134134
raise JobSchedulerError(exc=f"{exc}") from exc
135135

136-
return [AsyncJobGet(job_id=task_uuid) for task_uuid in task_uuids]
136+
return [
137+
AsyncJobGet(job_id=task.uuid, job_name=task.metadata.name) for task in tasks
138+
]

services/storage/src/simcore_service_storage/api/rpc/_paths.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from servicelib.rabbitmq import RPCRouter
1111

1212
from ...modules.celery import get_celery_client
13+
from ...modules.celery.models import TaskMetadata
1314
from .._worker_tasks._paths import compute_path_size as remote_compute_path_size
1415
from .._worker_tasks._paths import delete_paths as remote_delete_paths
1516

@@ -24,8 +25,10 @@ async def compute_path_size(
2425
location_id: LocationID,
2526
path: Path,
2627
) -> AsyncJobGet:
27-
task_uuid = await get_celery_client(app).send_task(
28-
remote_compute_path_size.__name__,
28+
task_uuid = await get_celery_client(app).submit_task(
29+
task_metadata=TaskMetadata(
30+
name=remote_compute_path_size.__name__,
31+
),
2932
task_context=job_id_data.model_dump(),
3033
user_id=job_id_data.user_id,
3134
location_id=location_id,
@@ -42,8 +45,10 @@ async def delete_paths(
4245
location_id: LocationID,
4346
paths: set[Path],
4447
) -> AsyncJobGet:
45-
task_uuid = await get_celery_client(app).send_task(
46-
remote_delete_paths.__name__,
48+
task_uuid = await get_celery_client(app).submit_task(
49+
task_metadata=TaskMetadata(
50+
name=remote_delete_paths.__name__,
51+
),
4752
task_context=job_id_data.model_dump(),
4853
user_id=job_id_data.user_id,
4954
location_id=location_id,

services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ async def copy_folders_from_project(
2020
job_id_data: AsyncJobNameData,
2121
body: FoldersBody,
2222
) -> AsyncJobGet:
23-
task_uuid = await get_celery_client(app).send_task(
24-
deep_copy_files_from_project.__name__,
23+
task_uuid = await get_celery_client(app).submit_task(
24+
task_metadata=TaskMetadata(
25+
name=deep_copy_files_from_project.__name__,
26+
),
2527
task_context=job_id_data.model_dump(),
2628
user_id=job_id_data.user_id,
2729
body=body,
@@ -34,13 +36,13 @@ async def copy_folders_from_project(
3436
async def start_export_data(
3537
app: FastAPI, job_id_data: AsyncJobNameData, paths_to_export: list[PathToExport]
3638
) -> AsyncJobGet:
37-
task_uuid = await get_celery_client(app).send_task(
38-
export_data.__name__,
39-
task_context=job_id_data.model_dump(),
39+
task_uuid = await get_celery_client(app).submit_task(
4040
task_metadata=TaskMetadata(
41+
name=export_data.__name__,
4142
ephemeral=False,
4243
queue=TasksQueue.CPU_BOUND,
4344
),
45+
task_context=job_id_data.model_dump(),
4446
user_id=job_id_data.user_id,
4547
paths_to_export=paths_to_export,
4648
)

services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from servicelib.redis._client import RedisClientSDK
88

99
from ..models import (
10+
Task,
1011
TaskContext,
12+
TaskID,
1113
TaskMetadata,
1214
TaskUUID,
13-
build_task_id,
1415
build_task_id_prefix,
1516
)
1617

@@ -24,24 +25,21 @@
2425
_logger = logging.getLogger(__name__)
2526

2627

27-
def _build_key(task_context: TaskContext, task_uuid: TaskUUID | None = None) -> str:
28-
if task_uuid is None:
29-
return _CELERY_TASK_INFO_PREFIX + build_task_id_prefix(task_context)
30-
return _CELERY_TASK_INFO_PREFIX + build_task_id(task_context, task_uuid)
28+
def _build_key(task_id: TaskID) -> str:
29+
return _CELERY_TASK_INFO_PREFIX + task_id
3130

3231

3332
class RedisTaskInfoStore:
3433
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
3534
self._redis_client_sdk = redis_client_sdk
3635

37-
async def create(
36+
async def create_task(
3837
self,
39-
task_context: TaskContext,
40-
task_uuid: TaskUUID,
38+
task_id: TaskID,
4139
task_metadata: TaskMetadata,
4240
expiry: timedelta,
4341
) -> None:
44-
task_key = _build_key(task_context, task_uuid)
42+
task_key = _build_key(task_id)
4543
await self._redis_client_sdk.redis.hset(
4644
name=task_key,
4745
key=_CELERY_TASK_METADATA_KEY,
@@ -52,26 +50,28 @@ async def create(
5250
expiry,
5351
)
5452

55-
async def exists(self, task_context: TaskContext, task_uuid: TaskUUID) -> bool:
56-
n = await self._redis_client_sdk.redis.exists(_build_key(task_context, task_uuid)) # type: ignore
53+
async def exists_task(self, task_id: TaskID) -> bool:
54+
n = await self._redis_client_sdk.redis.exists(_build_key(task_id)) # type: ignore
5755
assert isinstance(n, int) # nosec
5856
return n > 0
5957

60-
async def get_metadata(
61-
self, task_context: TaskContext, task_uuid: TaskUUID
62-
) -> TaskMetadata | None:
63-
result = await self._redis_client_sdk.redis.hget(_build_key(task_context, task_uuid), _CELERY_TASK_METADATA_KEY) # type: ignore
58+
async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
59+
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore
6460
return TaskMetadata.model_validate_json(result) if result else None
6561

66-
async def get_progress(
67-
self, task_context: TaskContext, task_uuid: TaskUUID
68-
) -> ProgressReport | None:
69-
result = await self._redis_client_sdk.redis.hget(_build_key(task_context, task_uuid), _CELERY_TASK_PROGRESS_KEY) # type: ignore
62+
async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
63+
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore
7064
return ProgressReport.model_validate_json(result) if result else None
7165

72-
async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
73-
search_key = _build_key(task_context) + _CELERY_TASK_ID_KEY_SEPARATOR
74-
keys = set()
66+
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
67+
search_key = (
68+
_CELERY_TASK_INFO_PREFIX
69+
+ build_task_id_prefix(task_context)
70+
+ _CELERY_TASK_ID_KEY_SEPARATOR
71+
)
72+
keys: list[str] = []
73+
tasks = []
74+
pipe = self._redis_client_sdk.redis.pipeline()
7575
async for key in self._redis_client_sdk.redis.scan_iter(
7676
match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
7777
):
@@ -81,18 +81,28 @@ async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
8181
if isinstance(key, bytes)
8282
else key
8383
)
84-
keys.add(TaskUUID(_key.removeprefix(search_key)))
85-
return keys
84+
keys.append(_key)
8685

87-
async def remove(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
88-
await self._redis_client_sdk.redis.delete(_build_key(task_context, task_uuid)) # type: ignore
89-
AsyncResult(build_task_id(task_context, task_uuid)).forget()
86+
for key in keys:
87+
pipe.hget(key, _CELERY_TASK_METADATA_KEY)
9088

91-
async def set_progress(
92-
self, task_context: TaskContext, task_uuid: TaskUUID, report: ProgressReport
93-
) -> None:
89+
results = await pipe.execute()
90+
for key, task_metadata in zip(keys, results, strict=False):
91+
tasks.append(
92+
Task(
93+
uuid=TaskUUID(key.removeprefix(search_key)),
94+
metadata=TaskMetadata.model_validate_json(task_metadata),
95+
)
96+
)
97+
return tasks
98+
99+
async def remove_task(self, task_id: TaskID) -> None:
100+
await self._redis_client_sdk.redis.delete(_build_key(task_id)) # type: ignore
101+
AsyncResult(task_id).forget()
102+
103+
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
94104
await self._redis_client_sdk.redis.hset(
95-
name=_build_key(task_context, task_uuid),
105+
name=_build_key(task_id),
96106
key=_CELERY_TASK_PROGRESS_KEY,
97107
value=report.model_dump_json(),
98108
) # type: ignore

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from settings_library.celery import CelerySettings
1414

1515
from .models import (
16+
Task,
1617
TaskContext,
1718
TaskInfoStore,
1819
TaskMetadata,
@@ -35,7 +36,7 @@ class CeleryTaskClient:
3536
_celery_settings: CelerySettings
3637
_task_store: TaskInfoStore
3738

38-
async def send_task(
39+
async def submit_task(
3940
self,
4041
task_metadata: TaskMetadata,
4142
*,
@@ -48,9 +49,10 @@ async def send_task(
4849
msg=f"Submit {task_metadata.name=}: {task_context=} {task_params=}",
4950
):
5051
task_uuid = uuid4()
52+
task_id = build_task_id(task_context, task_uuid)
5153
self._celery_app.send_task(
5254
task_metadata.name,
53-
task_id=build_task_id(task_context, task_uuid),
55+
task_id=task_id,
5456
kwargs=task_params,
5557
queue=task_metadata.queue.value,
5658
)
@@ -60,9 +62,7 @@ async def send_task(
6062
if task_metadata.ephemeral
6163
else self._celery_settings.CELERY_RESULT_EXPIRES
6264
)
63-
await self._task_store.create(
64-
task_context, task_uuid, task_metadata, expiry=expiry
65-
)
65+
await self._task_store.create_task(task_id, task_metadata, expiry=expiry)
6666
return task_uuid
6767

6868
@make_async()
@@ -91,18 +91,17 @@ async def get_task_result(
9191
async_result = self._celery_app.AsyncResult(task_id)
9292
result = async_result.result
9393
if async_result.ready():
94-
task_metadata = await self._task_store.get_metadata(
95-
task_context, task_uuid
96-
)
94+
task_metadata = await self._task_store.get_task_metadata(task_id)
9795
if task_metadata is not None and task_metadata.ephemeral:
98-
await self._task_store.remove(task_context, task_uuid)
96+
await self._task_store.remove_task(task_id)
9997
return result
10098

101-
async def _get_progress_report(
99+
async def _get_task_progress_report(
102100
self, task_context: TaskContext, task_uuid: TaskUUID, state: TaskState
103101
) -> ProgressReport:
104102
if state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
105-
progress = await self._task_store.get_progress(task_context, task_uuid)
103+
task_id = build_task_id(task_context, task_uuid)
104+
progress = await self._task_store.get_task_progress(task_id)
106105
if progress is not None:
107106
return progress
108107
if state in (
@@ -119,7 +118,9 @@ async def _get_progress_report(
119118
)
120119

121120
@make_async()
122-
def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
121+
def _get_task_celery_state(
122+
self, task_context: TaskContext, task_uuid: TaskUUID
123+
) -> TaskState:
123124
task_id = build_task_id(task_context, task_uuid)
124125
return TaskState(self._celery_app.AsyncResult(task_id).state)
125126

@@ -131,18 +132,19 @@ async def get_task_status(
131132
logging.DEBUG,
132133
msg=f"Getting task status: {task_context=} {task_uuid=}",
133134
):
134-
task_state = await self._get_state(task_context, task_uuid)
135-
task_id = build_task_id(task_context, task_uuid)
135+
task_state = await self._get_task_celery_state(task_context, task_uuid)
136136
return TaskStatus(
137137
task_uuid=task_uuid,
138138
task_state=task_state,
139-
progress_report=await self._get_progress_report(task_id, task_state),
139+
progress_report=await self._get_task_progress_report(
140+
task_context, task_uuid, task_state
141+
),
140142
)
141143

142-
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
144+
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
143145
with log_context(
144146
_logger,
145147
logging.DEBUG,
146-
msg=f"Getting task uuids: {task_context=}",
148+
msg=f"Listing tasks: {task_context=}",
147149
):
148-
return await self._task_store.get_uuids(task_context)
150+
return await self._task_store.list_tasks(task_context)

0 commit comments

Comments
 (0)