Skip to content

Commit fe327f4

Browse files
Monitor Celery tasks cancellation
1 parent bb3b81c commit fe327f4

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

services/storage/src/simcore_service_storage/modules/celery/_task.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@
77
from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload
88

99
from celery import Celery # type: ignore[import-untyped]
10-
from celery.contrib.abortable import AbortableTask # type: ignore[import-untyped]
10+
from celery.contrib.abortable import ( # type: ignore[import-untyped]
11+
AbortableAsyncResult,
12+
AbortableTask,
13+
)
1114
from pydantic import NonNegativeInt
15+
from servicelib.async_utils import cancel_wait_task
1216

1317
from . import get_event_loop
1418
from .errors import encore_celery_transferrable_error
15-
from .models import TaskId
19+
from .models import TaskID, TaskId
1620
from .utils import get_fastapi_app
1721

1822
_logger = logging.getLogger(__name__)
1923

2024
_DEFAULT_TASK_TIMEOUT: Final[timedelta | None] = None
2125
_DEFAULT_MAX_RETRIES: Final[NonNegativeInt] = 3
2226
_DEFAULT_WAIT_BEFORE_RETRY: Final[timedelta] = timedelta(seconds=5)
23-
_DEFAULT_DONT_AUTORETRY_FOR: Final[tuple[type[Exception], ...]] = tuple()
24-
27+
_DEFAULT_DONT_AUTORETRY_FOR: Final[tuple[type[Exception], ...]] = ()
28+
_DEFAULT_ABORT_TASK_TIMEOUT: Final[timedelta] = timedelta(seconds=0.5)
2529

2630
T = TypeVar("T")
2731
P = ParamSpec("P")
@@ -43,8 +47,25 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
4347
_logger.debug("task id: %s", task.request.id)
4448
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
4549
assert task.request.id is not None # nosec
50+
51+
async def run_task(task_id: TaskID) -> R:
    """Run the wrapped coroutine while watching for an abort request.

    The coroutine is scheduled as an asyncio task. Between two checks of the
    abort flag this waits on the task itself, so the result is returned as
    soon as the task finishes instead of after a fixed sleep.

    Args:
        task_id: id of the Celery task, passed explicitly because
            ``task.request`` is thread local.

    Returns:
        Whatever the wrapped coroutine returns.

    Raises:
        asyncio.CancelledError: the task was flagged as aborted.
        Exception: anything raised by the wrapped coroutine is re-raised.
    """
    # TODO: promote to a module-level constant next to _DEFAULT_ABORT_TASK_TIMEOUT
    cancel_max_delay_s = 5
    poll_interval_s = _DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()

    running = asyncio.create_task(coro(task, task_id, *args, **kwargs))

    while True:
        # NOTE(review): is_aborted() looks like a synchronous result-backend
        # lookup executed on the event loop -- confirm this is acceptable.
        if AbortableAsyncResult(task_id).is_aborted():
            _logger.warning("Task %s was aborted by user.", task_id)
            await cancel_wait_task(running, max_delay=cancel_max_delay_s)
            raise asyncio.CancelledError

        if running.done():
            # result() re-raises the coroutine's exception, if any
            return running.result()

        # Wake up when the task finishes or the poll interval elapses;
        # asyncio.wait does not cancel the task on timeout.
        await asyncio.wait({running}, timeout=poll_interval_s)
66+
4667
return asyncio.run_coroutine_threadsafe(
47-
coro(task, task.request.id, *args, **kwargs),
68+
run_task(task.request.id),
4869
get_event_loop(fastapi_app),
4970
).result()
5071

services/storage/tests/unit/test_async_jobs.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ async def test_async_jobs_workflow(
240240
@pytest.mark.parametrize(
241241
"exposed_rpc_start",
242242
[
243-
rpc_sync_job.__name__,
244243
rpc_async_job.__name__,
245244
],
246245
)
@@ -259,7 +258,7 @@ async def test_async_jobs_cancel(
259258
user_id=user_id,
260259
product_name=product_name,
261260
action=Action.SLEEP,
262-
payload=10,
261+
payload=60 * 10, # test hangs if not cancelled properly
263262
)
264263

265264
await async_jobs.cancel(
@@ -283,6 +282,30 @@ async def test_async_jobs_cancel(
283282
job_id_data=job_id_data,
284283
)
285284

285+
async_job_get, job_id_data = await _start_task_via_rpc(
286+
storage_rabbitmq_rpc_client,
287+
rpc_task_name=exposed_rpc_start,
288+
user_id=user_id,
289+
product_name=product_name,
290+
action=Action.ECHO,
291+
payload="bla",
292+
)
293+
294+
await _wait_for_job(
295+
storage_rabbitmq_rpc_client,
296+
async_job_get=async_job_get,
297+
job_id_data=job_id_data,
298+
stop_after=timedelta(seconds=15),
299+
)
300+
301+
async_job_result = await async_jobs.result(
302+
storage_rabbitmq_rpc_client,
303+
rpc_namespace=STORAGE_RPC_NAMESPACE,
304+
job_id=async_job_get.job_id,
305+
job_id_data=job_id_data,
306+
)
307+
assert async_job_result.result == "bla"
308+
286309

287310
@pytest.mark.parametrize(
288311
"exposed_rpc_start",

0 commit comments

Comments
 (0)