Skip to content

Commit fb26261

Browse files
✨ Enhance task cancellation (#7565)
1 parent 8c0a500 commit fb26261

File tree

11 files changed

+44
-29
lines changed

11 files changed

+44
-29
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ Makefile @pcrespov @sanderegg
4141
/services/static-webserver/ @GitHK
4242
/services/static-webserver/client/ @odeimaiz
4343
/services/storage/ @sanderegg
44+
/services/storage/modules/celery @giancarloromeo
4445
/services/web/server/ @pcrespov @sanderegg @GitHK @matusdrobuliak66
4546
/tests/e2e-frontend/ @odeimaiz
4647
/tests/e2e-playwright/ @matusdrobuliak66

api/specs/web-server/_long_running_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def get_async_job_status(
5858
responses=_export_data_responses,
5959
status_code=status.HTTP_204_NO_CONTENT,
6060
)
61-
def abort_async_job(
61+
def cancel_async_job(
6262
_path_params: Annotated[_PathParam, Depends()],
6363
): ...
6464

services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData
3636
assert app # nosec
3737
assert job_id_data # nosec
3838
try:
39-
await get_celery_client(app).abort_task(
39+
await get_celery_client(app).cancel_task(
4040
task_context=job_id_data.model_dump(),
4141
task_uuid=job_id,
4242
)

services/storage/src/simcore_service_storage/modules/celery/_task.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,7 @@ async def abort_monitor():
6767
main_task,
6868
max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
6969
)
70+
AbortableAsyncResult(task_id, app=app).forget()
7071
raise TaskAbortedError
7172
await asyncio.sleep(
7273
_DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 19 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@
3535
class CeleryTaskClient:
3636
_celery_app: Celery
3737
_celery_settings: CelerySettings
38-
_task_store: TaskInfoStore
38+
_task_info_store: TaskInfoStore
3939

4040
async def submit_task(
4141
self,
@@ -63,22 +63,25 @@ async def submit_task(
6363
if task_metadata.ephemeral
6464
else self._celery_settings.CELERY_RESULT_EXPIRES
6565
)
66-
await self._task_store.create_task(task_id, task_metadata, expiry=expiry)
66+
await self._task_info_store.create_task(
67+
task_id, task_metadata, expiry=expiry
68+
)
6769
return task_uuid
6870

6971
@make_async()
70-
def _abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
71-
AbortableAsyncResult(
72-
build_task_id(task_context, task_uuid), app=self._celery_app
73-
).abort()
72+
def _abort_task(self, task_id: TaskID) -> None:
73+
AbortableAsyncResult(task_id, app=self._celery_app).abort()
7474

75-
async def abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
75+
async def cancel_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
7676
with log_context(
7777
_logger,
7878
logging.DEBUG,
79-
msg=f"Abort task: {task_context=} {task_uuid=}",
79+
msg=f"task cancellation: {task_context=} {task_uuid=}",
8080
):
81-
await self._abort_task(task_context, task_uuid)
81+
task_id = build_task_id(task_context, task_uuid)
82+
if not (await self.get_task_status(task_context, task_uuid)).is_done:
83+
await self._abort_task(task_id)
84+
await self._task_info_store.remove_task(task_id)
8285

8386
@make_async()
8487
def _forget_task(self, task_id: TaskID) -> None:
@@ -96,18 +99,18 @@ async def get_task_result(
9699
async_result = self._celery_app.AsyncResult(task_id)
97100
result = async_result.result
98101
if async_result.ready():
99-
task_metadata = await self._task_store.get_task_metadata(task_id)
102+
task_metadata = await self._task_info_store.get_task_metadata(task_id)
100103
if task_metadata is not None and task_metadata.ephemeral:
101-
await self._task_store.remove_task(task_id)
102104
await self._forget_task(task_id)
105+
await self._task_info_store.remove_task(task_id)
103106
return result
104107

105108
async def _get_task_progress_report(
106109
self, task_context: TaskContext, task_uuid: TaskUUID, task_state: TaskState
107110
) -> ProgressReport:
108111
if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
109112
task_id = build_task_id(task_context, task_uuid)
110-
progress = await self._task_store.get_task_progress(task_id)
113+
progress = await self._task_info_store.get_task_progress(task_id)
111114
if progress is not None:
112115
return progress
113116
if task_state in (
@@ -124,10 +127,7 @@ async def _get_task_progress_report(
124127
)
125128

126129
@make_async()
127-
def _get_task_celery_state(
128-
self, task_context: TaskContext, task_uuid: TaskUUID
129-
) -> TaskState:
130-
task_id = build_task_id(task_context, task_uuid)
130+
def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
131131
return TaskState(self._celery_app.AsyncResult(task_id).state)
132132

133133
async def get_task_status(
@@ -138,7 +138,8 @@ async def get_task_status(
138138
logging.DEBUG,
139139
msg=f"Getting task status: {task_context=} {task_uuid=}",
140140
):
141-
task_state = await self._get_task_celery_state(task_context, task_uuid)
141+
task_id = build_task_id(task_context, task_uuid)
142+
task_state = await self._get_task_celery_state(task_id)
142143
return TaskStatus(
143144
task_uuid=task_uuid,
144145
task_state=task_state,
@@ -153,4 +154,4 @@ async def list_tasks(self, task_context: TaskContext) -> list[Task]:
153154
logging.DEBUG,
154155
msg=f"Listing tasks: {task_context=}",
155156
):
156-
return await self._task_store.list_tasks(task_context)
157+
return await self._task_info_store.list_tasks(task_context)

services/storage/tests/unit/test_async_jobs.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -277,6 +277,14 @@ async def test_async_jobs_cancel(
277277
job_id_data=job_id_data,
278278
)
279279

280+
jobs = await async_jobs.list_jobs(
281+
storage_rabbitmq_rpc_client,
282+
rpc_namespace=STORAGE_RPC_NAMESPACE,
283+
filter_="", # currently not used
284+
job_id_data=job_id_data,
285+
)
286+
assert async_job_get.job_id not in [job.job_id for job in jobs]
287+
280288
with pytest.raises(JobAbortedError):
281289
await async_jobs.result(
282290
storage_rabbitmq_rpc_client,

services/storage/tests/unit/test_modules_celery.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ async def test_submitting_task_with_failure_results_with_error(
166166
assert f"{raw_result}" == "Something strange happened: BOOM!"
167167

168168

169-
async def test_aborting_task_results_with_aborted_state(
169+
async def test_cancelling_a_running_task_aborts_and_deletes(
170170
celery_client: CeleryTaskClient,
171171
):
172172
task_context = TaskContext(user_id=42)
@@ -178,7 +178,7 @@ async def test_aborting_task_results_with_aborted_state(
178178
task_context=task_context,
179179
)
180180

181-
await celery_client.abort_task(task_context, task_uuid)
181+
await celery_client.cancel_task(task_context, task_uuid)
182182

183183
for attempt in Retrying(
184184
retry=retry_if_exception_type(AssertionError),
@@ -193,6 +193,8 @@ async def test_aborting_task_results_with_aborted_state(
193193
await celery_client.get_task_status(task_context, task_uuid)
194194
).task_state == TaskState.ABORTED
195195

196+
assert task_uuid not in await celery_client.list_tasks(task_context)
197+
196198

197199
async def test_listing_task_uuids_contains_submitted_task(
198200
celery_client: CeleryTaskClient,

services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3129,7 +3129,7 @@ paths:
31293129
- long-running-tasks
31303130
summary: Cancel And Delete Task
31313131
description: Cancels and deletes a task
3132-
operationId: abort_async_job
3132+
operationId: cancel_async_job
31333133
parameters:
31343134
- name: task_id
31353135
in: path

services/web/server/src/simcore_service_webserver/storage/_rest.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def _create_data_response_from_async_job(
185185
task_id=async_job_id,
186186
task_name=async_job_id,
187187
status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=async_job_id)))}",
188-
abort_href=f"{request.url.with_path(str(request.app.router['abort_async_job'].url_for(task_id=async_job_id)))}",
188+
abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=async_job_id)))}",
189189
result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=async_job_id)))}",
190190
),
191191
status=status.HTTP_202_ACCEPTED,
@@ -505,7 +505,7 @@ def allow_only_simcore(cls, v: int) -> int:
505505
task_id=_job_id,
506506
task_name=_job_id,
507507
status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=_job_id)))}",
508-
abort_href=f"{request.url.with_path(str(request.app.router['abort_async_job'].url_for(task_id=_job_id)))}",
508+
abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=_job_id)))}",
509509
result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=_job_id)))}",
510510
),
511511
status=status.HTTP_202_ACCEPTED,

services/web/server/src/simcore_service_webserver/tasks/_rest.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ async def get_async_jobs(request: web.Request) -> web.Response:
8787
task_id=f"{job.job_id}",
8888
task_name=job.job_name,
8989
status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=str(job.job_id))))}",
90-
abort_href=f"{request.url.with_path(str(request.app.router['abort_async_job'].url_for(task_id=str(job.job_id))))}",
90+
abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=str(job.job_id))))}",
9191
result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=str(job.job_id))))}",
9292
)
9393
for job in user_async_jobs
@@ -136,17 +136,18 @@ async def get_async_job_status(request: web.Request) -> web.Response:
136136

137137
@routes.delete(
138138
_task_prefix + "/{task_id}",
139-
name="abort_async_job",
139+
name="cancel_async_job",
140140
)
141141
@login_required
142142
@permission_required("storage.files.*")
143143
@handle_export_data_exceptions
144-
async def abort_async_job(request: web.Request) -> web.Response:
144+
async def cancel_async_job(request: web.Request) -> web.Response:
145145

146146
_req_ctx = RequestContext.model_validate(request)
147147

148148
rabbitmq_rpc_client = get_rabbitmq_rpc_client(request.app)
149149
async_job_get = parse_request_path_parameters_as(_StorageAsyncJobId, request)
150+
150151
await async_jobs.cancel(
151152
rabbitmq_rpc_client=rabbitmq_rpc_client,
152153
rpc_namespace=STORAGE_RPC_NAMESPACE,
@@ -155,6 +156,7 @@ async def abort_async_job(request: web.Request) -> web.Response:
155156
user_id=_req_ctx.user_id, product_name=_req_ctx.product_name
156157
),
157158
)
159+
158160
return web.Response(status=status.HTTP_204_NO_CONTENT)
159161

160162

0 commit comments

Comments
 (0)