11import asyncio
2+ import contextlib
23import inspect
34import logging
45from collections .abc import Callable , Coroutine
1314)
1415from celery .exceptions import Ignore # type: ignore[import-untyped]
1516from pydantic import NonNegativeInt
16- from servicelib .async_utils import cancel_wait_task
1717
1818from . import get_event_loop
1919from .errors import encore_celery_transferrable_error
@@ -48,7 +48,7 @@ def decorator(
4848 ) -> Callable [Concatenate [AbortableTask , P ], R ]:
4949 @wraps (coro )
5050 def wrapper (task : AbortableTask , * args : P .args , ** kwargs : P .kwargs ) -> R :
51- fastapi_app = get_fastapi_app (app )
51+ event_loop = get_event_loop ( get_fastapi_app (app ) )
5252 # NOTE: task.request is a thread local object, so we need to pass the id explicitly
5353 assert task .request .id is not None # nosec
5454
@@ -63,10 +63,17 @@ async def abort_monitor():
6363 abortable_result = AbortableAsyncResult (task_id , app = app )
6464 while not main_task .done ():
6565 if abortable_result .is_aborted ():
66- await cancel_wait_task (
67- main_task ,
68- max_delay = _DEFAULT_CANCEL_TASK_TIMEOUT .total_seconds (),
69- )
66+ main_task .cancel ()
67+
68+ with contextlib .suppress (
69+ asyncio .CancelledError , TimeoutError
70+ ):
71+ await asyncio .wait_for (
72+ asyncio .gather (
73+ * asyncio .all_tasks (loop = event_loop )
74+ ),
75+ timeout = _DEFAULT_CANCEL_TASK_TIMEOUT .total_seconds (),
76+ )
7077 AbortableAsyncResult (task_id , app = app ).forget ()
7178 raise TaskAbortedError
7279 await asyncio .sleep (
@@ -90,7 +97,7 @@ async def abort_monitor():
9097
9198 return asyncio .run_coroutine_threadsafe (
9299 run_task (task .request .id ),
93- get_event_loop ( fastapi_app ) ,
100+ event_loop ,
94101 ).result ()
95102
96103 return wrapper
0 commit comments