Skip to content

Commit 792de64

Browse files
update signature
1 parent 347e323 commit 792de64

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from .models import (
1616
TaskContext,
17-
TaskID,
1817
TaskInfoStore,
1918
TaskMetadata,
2019
TaskState,
@@ -46,12 +45,11 @@ async def send_task(
4645
with log_context(
4746
_logger,
4847
logging.DEBUG,
49-
msg=f"Submit {task_name=}: {task_context=} {task_params=}",
48+
msg=f"Submit {task_metadata.name=}: {task_context=} {task_params=}",
5049
):
5150
task_uuid = uuid4()
52-
task_metadata = task_metadata or TaskMetadata()
5351
self._celery_app.send_task(
54-
task_name,
52+
task_metadata.name,
5553
task_id=build_task_id(task_context, task_uuid),
5654
kwargs=task_params,
5755
queue=task_metadata.queue.value,
@@ -68,17 +66,18 @@ async def send_task(
6866
return task_uuid
6967

7068
@make_async()
71-
def _abort_task(self, task_id: TaskID) -> None:
72-
AbortableAsyncResult(task_id, app=self._celery_app).abort()
69+
def _abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
70+
AbortableAsyncResult(
71+
build_task_id(task_context, task_uuid), app=self._celery_app
72+
).abort()
7373

7474
async def abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
7575
with log_context(
7676
_logger,
7777
logging.DEBUG,
7878
msg=f"Abort task: {task_context=} {task_uuid=}",
7979
):
80-
task_id = build_task_id(task_context, task_uuid)
81-
await self._abort_task(task_id)
80+
await self._abort_task(task_context, task_uuid)
8281

8382
async def get_task_result(
8483
self, task_context: TaskContext, task_uuid: TaskUUID
@@ -92,16 +91,18 @@ async def get_task_result(
9291
async_result = self._celery_app.AsyncResult(task_id)
9392
result = async_result.result
9493
if async_result.ready():
95-
task_metadata = await self._task_store.get_metadata(task_id)
94+
task_metadata = await self._task_store.get_metadata(
95+
task_context, task_uuid
96+
)
9697
if task_metadata is not None and task_metadata.ephemeral:
97-
await self._task_store.remove(task_id)
98+
await self._task_store.remove(task_context, task_uuid)
9899
return result
99100

100101
async def _get_progress_report(
101-
self, task_id: TaskID, state: TaskState
102+
self, task_context: TaskContext, task_uuid: TaskUUID, state: TaskState
102103
) -> ProgressReport:
103104
if state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
104-
progress = await self._task_store.get_progress(task_id)
105+
progress = await self._task_store.get_progress(task_context, task_uuid)
105106
if progress is not None:
106107
return progress
107108
if state in (

services/storage/src/simcore_service_storage/modules/celery/models.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ class TaskMetadata(BaseModel):
5050

5151

5252
class TaskInfoStore(Protocol):
53-
async def exists(self, task_id: TaskID) -> bool: ...
54-
5553
async def create(
5654
self,
5755
task_context: TaskContext,
@@ -60,15 +58,23 @@ async def create(
6058
expiry: timedelta,
6159
) -> None: ...
6260

63-
async def get_progress(self, task_id: TaskID) -> ProgressReport | None: ...
61+
async def exists(self, task_context: TaskContext, task_uuid: TaskUUID) -> bool: ...
62+
63+
async def get_metadata(
64+
self, task_context: TaskContext, task_uuid: TaskUUID
65+
) -> TaskMetadata | None: ...
6466

65-
async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None: ...
67+
async def get_progress(
68+
self, task_context: TaskContext, task_uuid: TaskUUID
69+
) -> ProgressReport | None: ...
6670

6771
async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...
6872

69-
async def remove(self, task_id: TaskID) -> None: ...
73+
async def remove(self, task_context: TaskContext, task_uuid: TaskUUID) -> None: ...
7074

71-
async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None: ...
75+
async def set_progress(
76+
self, task_context: TaskContext, task_uuid: TaskUUID, report: ProgressReport
77+
) -> None: ...
7278

7379

7480
class TaskStatus(BaseModel):

0 commit comments

Comments
 (0)