Skip to content

Commit 75545fe

Browse files
fix: remove abortable task
1 parent 7f0830f commit 75545fe

File tree

4 files changed

+32
-45
lines changed

4 files changed

+32
-45
lines changed

packages/celery-library/src/celery_library/task.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,7 @@
66
from functools import wraps
77
from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload
88

9-
from celery import Celery # type: ignore[import-untyped]
10-
from celery.contrib.abortable import ( # type: ignore[import-untyped]
11-
AbortableAsyncResult,
12-
AbortableTask,
13-
)
9+
from celery import Celery, Task # type: ignore[import-untyped]
1410
from celery.exceptions import Ignore # type: ignore[import-untyped]
1511
from common_library.async_tools import cancel_wait_task
1612
from pydantic import NonNegativeInt
@@ -39,14 +35,14 @@ class TaskAbortedError(Exception): ...
3935
def _async_task_wrapper(
4036
app: Celery,
4137
) -> Callable[
42-
[Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]]],
43-
Callable[Concatenate[AbortableTask, P], R],
38+
[Callable[Concatenate[Task, P], Coroutine[Any, Any, R]]],
39+
Callable[Concatenate[Task, P], R],
4440
]:
4541
def decorator(
46-
coro: Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]],
47-
) -> Callable[Concatenate[AbortableTask, P], R]:
42+
coro: Callable[Concatenate[Task, P], Coroutine[Any, Any, R]],
43+
) -> Callable[Concatenate[Task, P], R]:
4844
@wraps(coro)
49-
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
45+
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
5046
app_server = get_app_server(app)
5147
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
5248
assert task.request.id is not None # nosec
@@ -59,14 +55,14 @@ async def run_task(task_id: TaskID) -> R:
5955
)
6056

6157
async def abort_monitor():
62-
abortable_result = AbortableAsyncResult(task_id, app=app)
6358
while not main_task.done():
64-
if abortable_result.is_aborted():
59+
if not await app_server.task_manager.exists_task(
60+
task_id
61+
):
6562
await cancel_wait_task(
6663
main_task,
6764
max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
6865
)
69-
AbortableAsyncResult(task_id, app=app).forget()
7066
raise TaskAbortedError
7167
await asyncio.sleep(
7268
_DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()
@@ -102,14 +98,14 @@ def _error_handling(
10298
delay_between_retries: timedelta,
10399
dont_autoretry_for: tuple[type[Exception], ...],
104100
) -> Callable[
105-
[Callable[Concatenate[AbortableTask, P], R]],
106-
Callable[Concatenate[AbortableTask, P], R],
101+
[Callable[Concatenate[Task, P], R]],
102+
Callable[Concatenate[Task, P], R],
107103
]:
108104
def decorator(
109-
func: Callable[Concatenate[AbortableTask, P], R],
110-
) -> Callable[Concatenate[AbortableTask, P], R]:
105+
func: Callable[Concatenate[Task, P], R],
106+
) -> Callable[Concatenate[Task, P], R]:
111107
@wraps(func)
112-
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
108+
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
113109
try:
114110
return func(task, *args, **kwargs)
115111
except TaskAbortedError as exc:
@@ -144,7 +140,7 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
144140
@overload
145141
def register_task(
146142
app: Celery,
147-
fn: Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]],
143+
fn: Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]],
148144
task_name: str | None = None,
149145
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
150146
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
@@ -156,7 +152,7 @@ def register_task(
156152
@overload
157153
def register_task(
158154
app: Celery,
159-
fn: Callable[Concatenate[AbortableTask, P], R],
155+
fn: Callable[Concatenate[Task, P], R],
160156
task_name: str | None = None,
161157
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
162158
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
@@ -168,8 +164,8 @@ def register_task(
168164
def register_task( # type: ignore[misc]
169165
app: Celery,
170166
fn: (
171-
Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]]
172-
| Callable[Concatenate[AbortableTask, P], R]
167+
Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]]
168+
| Callable[Concatenate[Task, P], R]
173169
),
174170
task_name: str | None = None,
175171
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
@@ -186,7 +182,7 @@ def register_task( # type: ignore[misc]
186182
delay_between_retries -- dealy between each attempt in case of error (default: {_DEFAULT_WAIT_BEFORE_RETRY})
187183
dont_autoretry_for -- exceptions that should not be retried when raised by the task
188184
"""
189-
wrapped_fn: Callable[Concatenate[AbortableTask, P], R]
185+
wrapped_fn: Callable[Concatenate[Task, P], R]
190186
if asyncio.iscoroutinefunction(fn):
191187
wrapped_fn = _async_task_wrapper(app)(fn)
192188
else:
@@ -202,7 +198,6 @@ def register_task( # type: ignore[misc]
202198
app.task(
203199
name=task_name or fn.__name__,
204200
bind=True,
205-
base=AbortableTask,
206201
time_limit=None if timeout is None else timeout.total_seconds(),
207202
pydantic=True,
208203
)(wrapped_fn)

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
from uuid import uuid4
55

66
from celery import Celery # type: ignore[import-untyped]
7-
from celery.contrib.abortable import ( # type: ignore[import-untyped]
8-
AbortableAsyncResult,
9-
)
107
from common_library.async_tools import make_async
118
from models_library.progress_bar import ProgressReport
129
from servicelib.celery.models import (
@@ -69,24 +66,22 @@ async def submit_task(
6966
)
7067
return task_uuid
7168

72-
@make_async()
73-
def _abort_task(self, task_id: TaskID) -> None:
74-
AbortableAsyncResult(task_id, app=self._celery_app).abort()
75-
7669
async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None:
7770
with log_context(
7871
_logger,
7972
logging.DEBUG,
8073
msg=f"task cancellation: {task_filter=} {task_uuid=}",
8174
):
8275
task_id = build_task_id(task_filter, task_uuid)
83-
if not (await self.get_task_status(task_filter, task_uuid)).is_done:
84-
await self._abort_task(task_id)
76+
await self._forget_task(task_id)
8577
await self._task_info_store.remove_task(task_id)
8678

79+
async def exists_task(self, task_id: TaskID) -> bool:
80+
return await self._task_info_store.exists_task(task_id)
81+
8782
@make_async()
8883
def _forget_task(self, task_id: TaskID) -> None:
89-
AbortableAsyncResult(task_id, app=self._celery_app).forget()
84+
self._celery_app.AsyncResult(task_id).forget()
9085

9186
async def get_task_result(
9287
self, task_filter: TaskFilter, task_uuid: TaskUUID
@@ -109,18 +104,11 @@ async def get_task_result(
109104
async def _get_task_progress_report(
110105
self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState
111106
) -> ProgressReport:
112-
if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
107+
if task_state in (TaskState.STARTED, TaskState.RETRY):
113108
task_id = build_task_id(task_filter, task_uuid)
114109
progress = await self._task_info_store.get_task_progress(task_id)
115110
if progress is not None:
116111
return progress
117-
if task_state in (
118-
TaskState.SUCCESS,
119-
TaskState.FAILURE,
120-
):
121-
return ProgressReport(
122-
actual_value=_MAX_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE
123-
)
124112

125113
# task is pending
126114
return ProgressReport(
@@ -140,7 +128,10 @@ async def get_task_status(
140128
msg=f"Getting task status: {task_filter=} {task_uuid=}",
141129
):
142130
task_id = build_task_id(task_filter, task_uuid)
143-
task_state = await self._get_task_celery_state(task_id)
131+
if not await self.exists_task(task_id):
132+
task_state = TaskState.ABORTED
133+
else:
134+
task_state = await self._get_task_celery_state(task_id)
144135
return TaskStatus(
145136
task_uuid=task_uuid,
146137
task_state=task_state,

packages/celery-library/tests/unit/test_tasks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import pytest
1414
from celery import Celery, Task # pylint: disable=no-name-in-module
15-
from celery.contrib.abortable import AbortableTask # pylint: disable=no-name-in-module
1615
from celery_library.errors import TransferrableCeleryError
1716
from celery_library.task import register_task
1817
from celery_library.task_manager import CeleryTaskManager
@@ -72,7 +71,7 @@ def failure_task(task: Task, task_id: TaskID) -> None:
7271
raise MyError(msg=msg)
7372

7473

75-
async def dreamer_task(task: AbortableTask, task_id: TaskID) -> list[int]:
74+
async def dreamer_task(task: Task, task_id: TaskID) -> list[int]:
7675
numbers = []
7776
for _ in range(30):
7877
numbers.append(randint(1, 90)) # noqa: S311

packages/service-library/src/servicelib/celery/task_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ async def cancel_task(
2121
self, task_filter: TaskFilter, task_uuid: TaskUUID
2222
) -> None: ...
2323

24+
async def exists_task(self, task_id: TaskID) -> bool: ...
25+
2426
async def get_task_result(
2527
self, task_filter: TaskFilter, task_uuid: TaskUUID
2628
) -> Any: ...

0 commit comments

Comments
 (0)