66from functools import wraps
77from typing import Any , Concatenate , Final , ParamSpec , TypeVar , overload
88
9- from celery import Celery # type: ignore[import-untyped]
10- from celery .contrib .abortable import ( # type: ignore[import-untyped]
11- AbortableAsyncResult ,
12- AbortableTask ,
13- )
9+ from celery import Celery , Task # type: ignore[import-untyped]
1410from celery .exceptions import Ignore # type: ignore[import-untyped]
1511from common_library .async_tools import cancel_wait_task
1612from pydantic import NonNegativeInt
@@ -39,14 +35,14 @@ class TaskAbortedError(Exception): ...
3935def _async_task_wrapper (
4036 app : Celery ,
4137) -> Callable [
42- [Callable [Concatenate [AbortableTask , P ], Coroutine [Any , Any , R ]]],
43- Callable [Concatenate [AbortableTask , P ], R ],
38+ [Callable [Concatenate [Task , P ], Coroutine [Any , Any , R ]]],
39+ Callable [Concatenate [Task , P ], R ],
4440]:
4541 def decorator (
46- coro : Callable [Concatenate [AbortableTask , P ], Coroutine [Any , Any , R ]],
47- ) -> Callable [Concatenate [AbortableTask , P ], R ]:
42+ coro : Callable [Concatenate [Task , P ], Coroutine [Any , Any , R ]],
43+ ) -> Callable [Concatenate [Task , P ], R ]:
4844 @wraps (coro )
49- def wrapper (task : AbortableTask , * args : P .args , ** kwargs : P .kwargs ) -> R :
45+ def wrapper (task : Task , * args : P .args , ** kwargs : P .kwargs ) -> R :
5046 app_server = get_app_server (app )
5147 # NOTE: task.request is a thread local object, so we need to pass the id explicitly
5248 assert task .request .id is not None # nosec
@@ -59,14 +55,14 @@ async def run_task(task_id: TaskID) -> R:
5955 )
6056
6157 async def abort_monitor ():
62- abortable_result = AbortableAsyncResult (task_id , app = app )
6358 while not main_task .done ():
64- if abortable_result .is_aborted ():
59+ if not await app_server .task_manager .exists_task (
60+ task_id
61+ ):
6562 await cancel_wait_task (
6663 main_task ,
6764 max_delay = _DEFAULT_CANCEL_TASK_TIMEOUT .total_seconds (),
6865 )
69- AbortableAsyncResult (task_id , app = app ).forget ()
7066 raise TaskAbortedError
7167 await asyncio .sleep (
7268 _DEFAULT_ABORT_TASK_TIMEOUT .total_seconds ()
@@ -102,14 +98,14 @@ def _error_handling(
10298 delay_between_retries : timedelta ,
10399 dont_autoretry_for : tuple [type [Exception ], ...],
104100) -> Callable [
105- [Callable [Concatenate [AbortableTask , P ], R ]],
106- Callable [Concatenate [AbortableTask , P ], R ],
101+ [Callable [Concatenate [Task , P ], R ]],
102+ Callable [Concatenate [Task , P ], R ],
107103]:
108104 def decorator (
109- func : Callable [Concatenate [AbortableTask , P ], R ],
110- ) -> Callable [Concatenate [AbortableTask , P ], R ]:
105+ func : Callable [Concatenate [Task , P ], R ],
106+ ) -> Callable [Concatenate [Task , P ], R ]:
111107 @wraps (func )
112- def wrapper (task : AbortableTask , * args : P .args , ** kwargs : P .kwargs ) -> R :
108+ def wrapper (task : Task , * args : P .args , ** kwargs : P .kwargs ) -> R :
113109 try :
114110 return func (task , * args , ** kwargs )
115111 except TaskAbortedError as exc :
@@ -144,7 +140,7 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
144140@overload
145141def register_task (
146142 app : Celery ,
147- fn : Callable [Concatenate [AbortableTask , TaskID , P ], Coroutine [Any , Any , R ]],
143+ fn : Callable [Concatenate [Task , TaskID , P ], Coroutine [Any , Any , R ]],
148144 task_name : str | None = None ,
149145 timeout : timedelta | None = _DEFAULT_TASK_TIMEOUT ,
150146 max_retries : NonNegativeInt = _DEFAULT_MAX_RETRIES ,
@@ -156,7 +152,7 @@ def register_task(
156152@overload
157153def register_task (
158154 app : Celery ,
159- fn : Callable [Concatenate [AbortableTask , P ], R ],
155+ fn : Callable [Concatenate [Task , P ], R ],
160156 task_name : str | None = None ,
161157 timeout : timedelta | None = _DEFAULT_TASK_TIMEOUT ,
162158 max_retries : NonNegativeInt = _DEFAULT_MAX_RETRIES ,
@@ -168,8 +164,8 @@ def register_task(
168164def register_task ( # type: ignore[misc]
169165 app : Celery ,
170166 fn : (
171- Callable [Concatenate [AbortableTask , TaskID , P ], Coroutine [Any , Any , R ]]
172- | Callable [Concatenate [AbortableTask , P ], R ]
167+ Callable [Concatenate [Task , TaskID , P ], Coroutine [Any , Any , R ]]
168+ | Callable [Concatenate [Task , P ], R ]
173169 ),
174170 task_name : str | None = None ,
175171 timeout : timedelta | None = _DEFAULT_TASK_TIMEOUT ,
@@ -186,7 +182,7 @@ def register_task( # type: ignore[misc]
186182 delay_between_retries -- dealy between each attempt in case of error (default: {_DEFAULT_WAIT_BEFORE_RETRY})
187183 dont_autoretry_for -- exceptions that should not be retried when raised by the task
188184 """
189- wrapped_fn : Callable [Concatenate [AbortableTask , P ], R ]
185+ wrapped_fn : Callable [Concatenate [Task , P ], R ]
190186 if asyncio .iscoroutinefunction (fn ):
191187 wrapped_fn = _async_task_wrapper (app )(fn )
192188 else :
@@ -202,7 +198,6 @@ def register_task( # type: ignore[misc]
202198 app .task (
203199 name = task_name or fn .__name__ ,
204200 bind = True ,
205- base = AbortableTask ,
206201 time_limit = None if timeout is None else timeout .total_seconds (),
207202 pydantic = True ,
208203 )(wrapped_fn )
0 commit comments