Skip to content

Commit ba50b78

Browse files
wait for asyncio tasks
1 parent fb26261 commit ba50b78

File tree

1 file changed

+14
-7
lines changed
  • services/storage/src/simcore_service_storage/modules/celery

1 file changed

+14
-7
lines changed

services/storage/src/simcore_service_storage/modules/celery/_task.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import inspect
34
import logging
45
from collections.abc import Callable, Coroutine
@@ -13,7 +14,6 @@
1314
)
1415
from celery.exceptions import Ignore # type: ignore[import-untyped]
1516
from pydantic import NonNegativeInt
16-
from servicelib.async_utils import cancel_wait_task
1717

1818
from . import get_event_loop
1919
from .errors import encore_celery_transferrable_error
@@ -48,7 +48,7 @@ def decorator(
4848
) -> Callable[Concatenate[AbortableTask, P], R]:
4949
@wraps(coro)
5050
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
51-
fastapi_app = get_fastapi_app(app)
51+
event_loop = get_event_loop(get_fastapi_app(app))
5252
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
5353
assert task.request.id is not None # nosec
5454

@@ -63,10 +63,17 @@ async def abort_monitor():
6363
abortable_result = AbortableAsyncResult(task_id, app=app)
6464
while not main_task.done():
6565
if abortable_result.is_aborted():
66-
await cancel_wait_task(
67-
main_task,
68-
max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
69-
)
66+
main_task.cancel()
67+
68+
with contextlib.suppress(
69+
asyncio.CancelledError, TimeoutError
70+
):
71+
await asyncio.wait_for(
72+
asyncio.gather(
73+
*asyncio.all_tasks(loop=event_loop)
74+
),
75+
timeout=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
76+
)
7077
AbortableAsyncResult(task_id, app=app).forget()
7178
raise TaskAbortedError
7279
await asyncio.sleep(
@@ -90,7 +97,7 @@ async def abort_monitor():
9097

9198
return asyncio.run_coroutine_threadsafe(
9299
run_task(task.request.id),
93-
get_event_loop(fastapi_app),
100+
event_loop,
94101
).result()
95102

96103
return wrapper

0 commit comments

Comments
 (0)