66from functools import wraps
77from typing import Any , Concatenate , Final , ParamSpec , TypeVar , overload
88
9- from celery import Celery # type: ignore[import-untyped]
10- from celery .contrib .abortable import ( # type: ignore[import-untyped]
11- AbortableAsyncResult ,
12- AbortableTask ,
13- )
9+ from celery import Celery , Task # type: ignore[import-untyped]
1410from celery .exceptions import Ignore # type: ignore[import-untyped]
1511from common_library .async_tools import cancel_wait_task
1612from pydantic import NonNegativeInt
@@ -39,42 +35,42 @@ class TaskAbortedError(Exception): ...
3935def _async_task_wrapper (
4036 app : Celery ,
4137) -> Callable [
42- [Callable [Concatenate [AbortableTask , P ], Coroutine [Any , Any , R ]]],
43- Callable [Concatenate [AbortableTask , P ], R ],
38+ [Callable [Concatenate [Task , P ], Coroutine [Any , Any , R ]]],
39+ Callable [Concatenate [Task , P ], R ],
4440]:
4541 def decorator (
46- coro : Callable [Concatenate [AbortableTask , P ], Coroutine [Any , Any , R ]],
47- ) -> Callable [Concatenate [AbortableTask , P ], R ]:
42+ coro : Callable [Concatenate [Task , P ], Coroutine [Any , Any , R ]],
43+ ) -> Callable [Concatenate [Task , P ], R ]:
4844 @wraps (coro )
49- def wrapper (task : AbortableTask , * args : P .args , ** kwargs : P .kwargs ) -> R :
45+ def wrapper (task : Task , * args : P .args , ** kwargs : P .kwargs ) -> R :
5046 app_server = get_app_server (app )
5147 # NOTE: task.request is a thread local object, so we need to pass the id explicitly
5248 assert task .request .id is not None # nosec
5349
54- async def run_task (task_id : TaskID ) -> R :
50+ async def _run_task (task_id : TaskID ) -> R :
5551 try :
5652 async with asyncio .TaskGroup () as tg :
57- main_task = tg .create_task (
53+ async_io_task = tg .create_task (
5854 coro (task , * args , ** kwargs ),
5955 )
6056
61- async def abort_monitor ():
62- abortable_result = AbortableAsyncResult (task_id , app = app )
63- while not main_task .done ():
64- if abortable_result .is_aborted ():
57+ async def _abort_monitor ():
58+ while not async_io_task .done ():
59+ if not await app_server .task_manager .task_exists (
60+ task_id
61+ ):
6562 await cancel_wait_task (
66- main_task ,
63+ async_io_task ,
6764 max_delay = _DEFAULT_CANCEL_TASK_TIMEOUT .total_seconds (),
6865 )
69- AbortableAsyncResult (task_id , app = app ).forget ()
7066 raise TaskAbortedError
7167 await asyncio .sleep (
7268 _DEFAULT_ABORT_TASK_TIMEOUT .total_seconds ()
7369 )
7470
75- tg .create_task (abort_monitor ())
71+ tg .create_task (_abort_monitor ())
7672
77- return main_task .result ()
73+ return async_io_task .result ()
7874 except BaseExceptionGroup as eg :
7975 task_aborted_errors , other_errors = eg .split (TaskAbortedError )
8076
@@ -88,7 +84,7 @@ async def abort_monitor():
8884 raise other_errors .exceptions [0 ] from eg
8985
9086 return asyncio .run_coroutine_threadsafe (
91- run_task (task .request .id ),
87+ _run_task (task .request .id ),
9288 app_server .event_loop ,
9389 ).result ()
9490
@@ -102,14 +98,14 @@ def _error_handling(
10298 delay_between_retries : timedelta ,
10399 dont_autoretry_for : tuple [type [Exception ], ...],
104100) -> Callable [
105- [Callable [Concatenate [AbortableTask , P ], R ]],
106- Callable [Concatenate [AbortableTask , P ], R ],
101+ [Callable [Concatenate [Task , P ], R ]],
102+ Callable [Concatenate [Task , P ], R ],
107103]:
108104 def decorator (
109- func : Callable [Concatenate [AbortableTask , P ], R ],
110- ) -> Callable [Concatenate [AbortableTask , P ], R ]:
105+ func : Callable [Concatenate [Task , P ], R ],
106+ ) -> Callable [Concatenate [Task , P ], R ]:
111107 @wraps (func )
112- def wrapper (task : AbortableTask , * args : P .args , ** kwargs : P .kwargs ) -> R :
108+ def wrapper (task : Task , * args : P .args , ** kwargs : P .kwargs ) -> R :
113109 try :
114110 return func (task , * args , ** kwargs )
115111 except TaskAbortedError as exc :
@@ -144,7 +140,7 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
144140@overload
145141def register_task (
146142 app : Celery ,
147- fn : Callable [Concatenate [AbortableTask , TaskID , P ], Coroutine [Any , Any , R ]],
143+ fn : Callable [Concatenate [Task , TaskID , P ], Coroutine [Any , Any , R ]],
148144 task_name : str | None = None ,
149145 timeout : timedelta | None = _DEFAULT_TASK_TIMEOUT ,
150146 max_retries : NonNegativeInt = _DEFAULT_MAX_RETRIES ,
@@ -156,7 +152,7 @@ def register_task(
156152@overload
157153def register_task (
158154 app : Celery ,
159- fn : Callable [Concatenate [AbortableTask , P ], R ],
155+ fn : Callable [Concatenate [Task , P ], R ],
160156 task_name : str | None = None ,
161157 timeout : timedelta | None = _DEFAULT_TASK_TIMEOUT ,
162158 max_retries : NonNegativeInt = _DEFAULT_MAX_RETRIES ,
@@ -168,8 +164,8 @@ def register_task(
168164def register_task ( # type: ignore[misc]
169165 app : Celery ,
170166 fn : (
171- Callable [Concatenate [AbortableTask , TaskID , P ], Coroutine [Any , Any , R ]]
172- | Callable [Concatenate [AbortableTask , P ], R ]
167+ Callable [Concatenate [Task , TaskID , P ], Coroutine [Any , Any , R ]]
168+ | Callable [Concatenate [Task , P ], R ]
173169 ),
174170 task_name : str | None = None ,
175171 timeout : timedelta | None = _DEFAULT_TASK_TIMEOUT ,
@@ -186,7 +182,7 @@ def register_task( # type: ignore[misc]
186182 delay_between_retries -- dealy between each attempt in case of error (default: {_DEFAULT_WAIT_BEFORE_RETRY})
187183 dont_autoretry_for -- exceptions that should not be retried when raised by the task
188184 """
189- wrapped_fn : Callable [Concatenate [AbortableTask , P ], R ]
185+ wrapped_fn : Callable [Concatenate [Task , P ], R ]
190186 if asyncio .iscoroutinefunction (fn ):
191187 wrapped_fn = _async_task_wrapper (app )(fn )
192188 else :
@@ -202,7 +198,6 @@ def register_task( # type: ignore[misc]
202198 app .task (
203199 name = task_name or fn .__name__ ,
204200 bind = True ,
205- base = AbortableTask ,
206201 time_limit = None if timeout is None else timeout .total_seconds (),
207202 pydantic = True ,
208203 )(wrapped_fn )
0 commit comments