Skip to content

Commit e5556d7

Browse files
Merge remote-tracking branch 'upstream/master' into is8159/fix-redis-client-lifecycle
2 parents 6a62f5c + d0d210d commit e5556d7

File tree

14 files changed

+118
-134
lines changed

14 files changed

+118
-134
lines changed

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
2020
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
2121
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
22-
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
22+
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000
2323
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
2424
_CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"
2525

@@ -53,11 +53,6 @@ async def create_task(
5353
expiry,
5454
)
5555

56-
async def exists_task(self, task_id: TaskID) -> bool:
57-
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
58-
assert isinstance(n, int) # nosec
59-
return n > 0
60-
6156
async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
6257
raw_result = await handle_redis_returns_union_types(
6358
self._redis_client_sdk.redis.hget(
@@ -143,3 +138,8 @@ async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> No
143138
value=report.model_dump_json(),
144139
)
145140
)
141+
142+
async def task_exists(self, task_id: TaskID) -> bool:
143+
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
144+
assert isinstance(n, int) # nosec
145+
return n > 0

packages/celery-library/src/celery_library/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import base64
22
import pickle
33

4+
from common_library.errors_classes import OsparcErrorMixin
5+
46

57
class TransferrableCeleryError(Exception):
68
def __repr__(self) -> str:
@@ -22,3 +24,7 @@ def decode_celery_transferrable_error(error: TransferrableCeleryError) -> Except
2224
assert isinstance(error, TransferrableCeleryError) # nosec
2325
result: Exception = pickle.loads(base64.b64decode(error.args[0])) # noqa: S301
2426
return result
27+
28+
29+
class TaskNotFoundError(OsparcErrorMixin, Exception):
30+
msg_template = "Task with id '{task_id}' was not found"

packages/celery-library/src/celery_library/rpc/_async_jobs.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from models_library.api_schemas_rpc_async_jobs.exceptions import (
1414
JobAbortedError,
1515
JobError,
16+
JobMissingError,
1617
JobNotDoneError,
1718
JobSchedulerError,
1819
)
@@ -22,6 +23,7 @@
2223
from servicelib.rabbitmq import RPCRouter
2324

2425
from ..errors import (
26+
TaskNotFoundError,
2527
TransferrableCeleryError,
2628
decode_celery_transferrable_error,
2729
)
@@ -30,7 +32,7 @@
3032
router = RPCRouter()
3133

3234

33-
@router.expose(reraise_if_error_type=(JobSchedulerError,))
35+
@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
3436
async def cancel(
3537
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
3638
):
@@ -42,11 +44,13 @@ async def cancel(
4244
task_filter=task_filter,
4345
task_uuid=job_id,
4446
)
47+
except TaskNotFoundError as exc:
48+
raise JobMissingError(job_id=job_id) from exc
4549
except CeleryError as exc:
4650
raise JobSchedulerError(exc=f"{exc}") from exc
4751

4852

49-
@router.expose(reraise_if_error_type=(JobSchedulerError,))
53+
@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
5054
async def status(
5155
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
5256
) -> AsyncJobStatus:
@@ -59,6 +63,8 @@ async def status(
5963
task_filter=task_filter,
6064
task_uuid=job_id,
6165
)
66+
except TaskNotFoundError as exc:
67+
raise JobMissingError(job_id=job_id) from exc
6268
except CeleryError as exc:
6369
raise JobSchedulerError(exc=f"{exc}") from exc
6470

@@ -71,9 +77,10 @@ async def status(
7177

7278
@router.expose(
7379
reraise_if_error_type=(
80+
JobAbortedError,
7481
JobError,
82+
JobMissingError,
7583
JobNotDoneError,
76-
JobAbortedError,
7784
JobSchedulerError,
7885
)
7986
)
@@ -97,11 +104,11 @@ async def result(
97104
task_filter=task_filter,
98105
task_uuid=job_id,
99106
)
107+
except TaskNotFoundError as exc:
108+
raise JobMissingError(job_id=job_id) from exc
100109
except CeleryError as exc:
101110
raise JobSchedulerError(exc=f"{exc}") from exc
102111

103-
if _status.task_state == TaskState.ABORTED:
104-
raise JobAbortedError(job_id=job_id)
105112
if _status.task_state == TaskState.FAILURE:
106113
# fallback exception to report
107114
exc_type = type(_result).__name__

packages/celery-library/src/celery_library/task.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,7 @@
66
from functools import wraps
77
from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload
88

9-
from celery import Celery # type: ignore[import-untyped]
10-
from celery.contrib.abortable import ( # type: ignore[import-untyped]
11-
AbortableAsyncResult,
12-
AbortableTask,
13-
)
9+
from celery import Celery, Task # type: ignore[import-untyped]
1410
from celery.exceptions import Ignore # type: ignore[import-untyped]
1511
from common_library.async_tools import cancel_wait_task
1612
from pydantic import NonNegativeInt
@@ -39,42 +35,42 @@ class TaskAbortedError(Exception): ...
3935
def _async_task_wrapper(
4036
app: Celery,
4137
) -> Callable[
42-
[Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]]],
43-
Callable[Concatenate[AbortableTask, P], R],
38+
[Callable[Concatenate[Task, P], Coroutine[Any, Any, R]]],
39+
Callable[Concatenate[Task, P], R],
4440
]:
4541
def decorator(
46-
coro: Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]],
47-
) -> Callable[Concatenate[AbortableTask, P], R]:
42+
coro: Callable[Concatenate[Task, P], Coroutine[Any, Any, R]],
43+
) -> Callable[Concatenate[Task, P], R]:
4844
@wraps(coro)
49-
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
45+
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
5046
app_server = get_app_server(app)
5147
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
5248
assert task.request.id is not None # nosec
5349

54-
async def run_task(task_id: TaskID) -> R:
50+
async def _run_task(task_id: TaskID) -> R:
5551
try:
5652
async with asyncio.TaskGroup() as tg:
57-
main_task = tg.create_task(
53+
async_io_task = tg.create_task(
5854
coro(task, *args, **kwargs),
5955
)
6056

61-
async def abort_monitor():
62-
abortable_result = AbortableAsyncResult(task_id, app=app)
63-
while not main_task.done():
64-
if abortable_result.is_aborted():
57+
async def _abort_monitor():
58+
while not async_io_task.done():
59+
if not await app_server.task_manager.task_exists(
60+
task_id
61+
):
6562
await cancel_wait_task(
66-
main_task,
63+
async_io_task,
6764
max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
6865
)
69-
AbortableAsyncResult(task_id, app=app).forget()
7066
raise TaskAbortedError
7167
await asyncio.sleep(
7268
_DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()
7369
)
7470

75-
tg.create_task(abort_monitor())
71+
tg.create_task(_abort_monitor())
7672

77-
return main_task.result()
73+
return async_io_task.result()
7874
except BaseExceptionGroup as eg:
7975
task_aborted_errors, other_errors = eg.split(TaskAbortedError)
8076

@@ -88,7 +84,7 @@ async def abort_monitor():
8884
raise other_errors.exceptions[0] from eg
8985

9086
return asyncio.run_coroutine_threadsafe(
91-
run_task(task.request.id),
87+
_run_task(task.request.id),
9288
app_server.event_loop,
9389
).result()
9490

@@ -102,14 +98,14 @@ def _error_handling(
10298
delay_between_retries: timedelta,
10399
dont_autoretry_for: tuple[type[Exception], ...],
104100
) -> Callable[
105-
[Callable[Concatenate[AbortableTask, P], R]],
106-
Callable[Concatenate[AbortableTask, P], R],
101+
[Callable[Concatenate[Task, P], R]],
102+
Callable[Concatenate[Task, P], R],
107103
]:
108104
def decorator(
109-
func: Callable[Concatenate[AbortableTask, P], R],
110-
) -> Callable[Concatenate[AbortableTask, P], R]:
105+
func: Callable[Concatenate[Task, P], R],
106+
) -> Callable[Concatenate[Task, P], R]:
111107
@wraps(func)
112-
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
108+
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
113109
try:
114110
return func(task, *args, **kwargs)
115111
except TaskAbortedError as exc:
@@ -144,7 +140,7 @@ def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
144140
@overload
145141
def register_task(
146142
app: Celery,
147-
fn: Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]],
143+
fn: Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]],
148144
task_name: str | None = None,
149145
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
150146
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
@@ -156,7 +152,7 @@ def register_task(
156152
@overload
157153
def register_task(
158154
app: Celery,
159-
fn: Callable[Concatenate[AbortableTask, P], R],
155+
fn: Callable[Concatenate[Task, P], R],
160156
task_name: str | None = None,
161157
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
162158
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
@@ -168,8 +164,8 @@ def register_task(
168164
def register_task( # type: ignore[misc]
169165
app: Celery,
170166
fn: (
171-
Callable[Concatenate[AbortableTask, TaskID, P], Coroutine[Any, Any, R]]
172-
| Callable[Concatenate[AbortableTask, P], R]
167+
Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]]
168+
| Callable[Concatenate[Task, P], R]
173169
),
174170
task_name: str | None = None,
175171
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
@@ -186,7 +182,7 @@ def register_task( # type: ignore[misc]
186182
delay_between_retries -- delay between each attempt in case of error (default: {_DEFAULT_WAIT_BEFORE_RETRY})
187183
dont_autoretry_for -- exceptions that should not be retried when raised by the task
188184
"""
189-
wrapped_fn: Callable[Concatenate[AbortableTask, P], R]
185+
wrapped_fn: Callable[Concatenate[Task, P], R]
190186
if asyncio.iscoroutinefunction(fn):
191187
wrapped_fn = _async_task_wrapper(app)(fn)
192188
else:
@@ -202,7 +198,6 @@ def register_task( # type: ignore[misc]
202198
app.task(
203199
name=task_name or fn.__name__,
204200
bind=True,
205-
base=AbortableTask,
206201
time_limit=None if timeout is None else timeout.total_seconds(),
207202
pydantic=True,
208203
)(wrapped_fn)

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
from uuid import uuid4
55

66
from celery import Celery # type: ignore[import-untyped]
7-
from celery.contrib.abortable import ( # type: ignore[import-untyped]
8-
AbortableAsyncResult,
9-
)
107
from common_library.async_tools import make_async
118
from models_library.progress_bar import ProgressReport
129
from servicelib.celery.models import (
10+
TASK_DONE_STATES,
1311
Task,
1412
TaskFilter,
1513
TaskID,
@@ -23,6 +21,7 @@
2321
from servicelib.logging_utils import log_context
2422
from settings_library.celery import CelerySettings
2523

24+
from .errors import TaskNotFoundError
2625
from .utils import build_task_id
2726

2827
_logger = logging.getLogger(__name__)
@@ -69,24 +68,25 @@ async def submit_task(
6968
)
7069
return task_uuid
7170

72-
@make_async()
73-
def _abort_task(self, task_id: TaskID) -> None:
74-
AbortableAsyncResult(task_id, app=self._celery_app).abort()
75-
7671
async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None:
7772
with log_context(
7873
_logger,
7974
logging.DEBUG,
8075
msg=f"task cancellation: {task_filter=} {task_uuid=}",
8176
):
8277
task_id = build_task_id(task_filter, task_uuid)
83-
if not (await self.get_task_status(task_filter, task_uuid)).is_done:
84-
await self._abort_task(task_id)
78+
if not await self.task_exists(task_id):
79+
raise TaskNotFoundError(task_id=task_id)
80+
8581
await self._task_info_store.remove_task(task_id)
82+
await self._forget_task(task_id)
83+
84+
async def task_exists(self, task_id: TaskID) -> bool:
85+
return await self._task_info_store.task_exists(task_id)
8686

8787
@make_async()
8888
def _forget_task(self, task_id: TaskID) -> None:
89-
AbortableAsyncResult(task_id, app=self._celery_app).forget()
89+
self._celery_app.AsyncResult(task_id).forget()
9090

9191
async def get_task_result(
9292
self, task_filter: TaskFilter, task_uuid: TaskUUID
@@ -97,27 +97,27 @@ async def get_task_result(
9797
msg=f"Get task result: {task_filter=} {task_uuid=}",
9898
):
9999
task_id = build_task_id(task_filter, task_uuid)
100+
if not await self.task_exists(task_id):
101+
raise TaskNotFoundError(task_id=task_id)
102+
100103
async_result = self._celery_app.AsyncResult(task_id)
101104
result = async_result.result
102105
if async_result.ready():
103106
task_metadata = await self._task_info_store.get_task_metadata(task_id)
104107
if task_metadata is not None and task_metadata.ephemeral:
105-
await self._forget_task(task_id)
106108
await self._task_info_store.remove_task(task_id)
109+
await self._forget_task(task_id)
107110
return result
108111

109112
async def _get_task_progress_report(
110-
self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState
113+
self, task_id: TaskID, task_state: TaskState
111114
) -> ProgressReport:
112-
if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
113-
task_id = build_task_id(task_filter, task_uuid)
115+
if task_state in (TaskState.STARTED, TaskState.RETRY):
114116
progress = await self._task_info_store.get_task_progress(task_id)
115117
if progress is not None:
116118
return progress
117-
if task_state in (
118-
TaskState.SUCCESS,
119-
TaskState.FAILURE,
120-
):
119+
120+
if task_state in TASK_DONE_STATES:
121121
return ProgressReport(
122122
actual_value=_MAX_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE
123123
)
@@ -140,12 +140,15 @@ async def get_task_status(
140140
msg=f"Getting task status: {task_filter=} {task_uuid=}",
141141
):
142142
task_id = build_task_id(task_filter, task_uuid)
143+
if not await self.task_exists(task_id):
144+
raise TaskNotFoundError(task_id=task_id)
145+
143146
task_state = await self._get_task_celery_state(task_id)
144147
return TaskStatus(
145148
task_uuid=task_uuid,
146149
task_state=task_state,
147150
progress_report=await self._get_task_progress_report(
148-
task_filter, task_uuid, task_state
151+
task_id, task_state
149152
),
150153
)
151154

0 commit comments

Comments
 (0)