77from typing import Any , Concatenate , Final , ParamSpec , TypeVar , overload
88
99from celery import Celery # type: ignore[import-untyped]
10- from celery .contrib .abortable import AbortableTask # type: ignore[import-untyped]
10+ from celery .contrib .abortable import ( # type: ignore[import-untyped]
11+ AbortableAsyncResult ,
12+ AbortableTask ,
13+ )
1114from pydantic import NonNegativeInt
15+ from servicelib .async_utils import cancel_wait_task
1216
1317from . import get_event_loop
1418from .errors import encore_celery_transferrable_error
15- from .models import TaskId
19+ from .models import TaskID , TaskId
1620from .utils import get_fastapi_app
1721
1822_logger = logging .getLogger (__name__ )
1923
2024_DEFAULT_TASK_TIMEOUT : Final [timedelta | None ] = None
2125_DEFAULT_MAX_RETRIES : Final [NonNegativeInt ] = 3
2226_DEFAULT_WAIT_BEFORE_RETRY : Final [timedelta ] = timedelta (seconds = 5 )
23- _DEFAULT_DONT_AUTORETRY_FOR : Final [tuple [type [Exception ], ...]] = tuple ()
24-
27+ _DEFAULT_DONT_AUTORETRY_FOR : Final [tuple [type [Exception ], ...]] = ()
28+ _DEFAULT_ABORT_TASK_TIMEOUT : Final [ timedelta ] = timedelta ( seconds = 0.5 )
2529
2630T = TypeVar ("T" )
2731P = ParamSpec ("P" )
@@ -43,8 +47,25 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
4347 _logger .debug ("task id: %s" , task .request .id )
4448 # NOTE: task.request is a thread local object, so we need to pass the id explicitly
4549 assert task .request .id is not None # nosec
50+
51+ async def run_task (task_id : TaskID ) -> R :
52+ task_coro = asyncio .create_task (coro (task , task_id , * args , ** kwargs ))
53+
54+ can_continue = True
55+ while can_continue :
56+ if AbortableAsyncResult (task_id ).is_aborted ():
57+ _logger .warning ("Task %s was aborted by user." , task_id )
58+ await cancel_wait_task (task_coro , max_delay = 5 ) # to constant
59+ raise asyncio .CancelledError
60+ if task_coro .done ():
61+ break
62+
63+ await asyncio .sleep (_DEFAULT_ABORT_TASK_TIMEOUT .total_seconds ())
64+
65+ return task_coro .result ()
66+
4667 return asyncio .run_coroutine_threadsafe (
47- coro (task , task .request .id , * args , ** kwargs ),
68+ run_task (task .request .id ),
4869 get_event_loop (fastapi_app ),
4970 ).result ()
5071
0 commit comments