Skip to content

Commit 3e05de4

Browse files
🎨 introduce task filter class in celery (#8076)
1 parent 415442c commit 3e05de4

File tree

25 files changed

+319
-236
lines changed

25 files changed

+319
-236
lines changed

packages/celery-library/src/celery_library/backends/_redis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic import ValidationError
88
from servicelib.celery.models import (
99
Task,
10-
TaskContext,
10+
TaskFilter,
1111
TaskID,
1212
TaskMetadata,
1313
TaskUUID,
@@ -82,10 +82,10 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
8282
)
8383
return None
8484

85-
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
85+
async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
8686
search_key = (
8787
_CELERY_TASK_INFO_PREFIX
88-
+ build_task_id_prefix(task_context)
88+
+ build_task_id_prefix(task_filter)
8989
+ _CELERY_TASK_ID_KEY_SEPARATOR
9090
)
9191
search_key_len = len(search_key)

packages/celery-library/src/celery_library/rpc/_async_jobs.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from celery.exceptions import CeleryError # type: ignore[import-untyped]
66
from models_library.api_schemas_rpc_async_jobs.async_jobs import (
7+
AsyncJobFilter,
78
AsyncJobGet,
89
AsyncJobId,
9-
AsyncJobNameData,
1010
AsyncJobResult,
1111
AsyncJobStatus,
1212
)
@@ -16,7 +16,7 @@
1616
JobNotDoneError,
1717
JobSchedulerError,
1818
)
19-
from servicelib.celery.models import TaskState
19+
from servicelib.celery.models import TaskFilter, TaskState
2020
from servicelib.celery.task_manager import TaskManager
2121
from servicelib.logging_utils import log_catch
2222
from servicelib.rabbitmq import RPCRouter
@@ -32,13 +32,14 @@
3232

3333
@router.expose(reraise_if_error_type=(JobSchedulerError,))
3434
async def cancel(
35-
task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
35+
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
3636
):
3737
assert task_manager # nosec
38-
assert job_id_data # nosec
38+
assert job_filter # nosec
39+
task_filter = TaskFilter.model_validate(job_filter.model_dump())
3940
try:
4041
await task_manager.cancel_task(
41-
task_context=job_id_data.model_dump(),
42+
task_filter=task_filter,
4243
task_uuid=job_id,
4344
)
4445
except CeleryError as exc:
@@ -47,14 +48,15 @@ async def cancel(
4748

4849
@router.expose(reraise_if_error_type=(JobSchedulerError,))
4950
async def status(
50-
task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
51+
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
5152
) -> AsyncJobStatus:
5253
assert task_manager # nosec
53-
assert job_id_data # nosec
54+
assert job_filter # nosec
5455

56+
task_filter = TaskFilter.model_validate(job_filter.model_dump())
5557
try:
5658
task_status = await task_manager.get_task_status(
57-
task_context=job_id_data.model_dump(),
59+
task_filter=task_filter,
5860
task_uuid=job_id,
5961
)
6062
except CeleryError as exc:
@@ -76,21 +78,23 @@ async def status(
7678
)
7779
)
7880
async def result(
79-
task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
81+
task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
8082
) -> AsyncJobResult:
8183
assert task_manager # nosec
8284
assert job_id # nosec
83-
assert job_id_data # nosec
85+
assert job_filter # nosec
86+
87+
task_filter = TaskFilter.model_validate(job_filter.model_dump())
8488

8589
try:
8690
_status = await task_manager.get_task_status(
87-
task_context=job_id_data.model_dump(),
91+
task_filter=task_filter,
8892
task_uuid=job_id,
8993
)
9094
if not _status.is_done:
9195
raise JobNotDoneError(job_id=job_id)
9296
_result = await task_manager.get_task_result(
93-
task_context=job_id_data.model_dump(),
97+
task_filter=task_filter,
9498
task_uuid=job_id,
9599
)
96100
except CeleryError as exc:
@@ -123,13 +127,14 @@ async def result(
123127

124128
@router.expose(reraise_if_error_type=(JobSchedulerError,))
125129
async def list_jobs(
126-
task_manager: TaskManager, filter_: str, job_id_data: AsyncJobNameData
130+
task_manager: TaskManager, filter_: str, job_filter: AsyncJobFilter
127131
) -> list[AsyncJobGet]:
128132
_ = filter_
129133
assert task_manager # nosec
134+
task_filter = TaskFilter.model_validate(job_filter.model_dump())
130135
try:
131136
tasks = await task_manager.list_tasks(
132-
task_context=job_id_data.model_dump(),
137+
task_filter=task_filter,
133138
)
134139
except CeleryError as exc:
135140
raise JobSchedulerError(exc=f"{exc}") from exc

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from dataclasses import dataclass
3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
44
from uuid import uuid4
55

66
from celery import Celery # type: ignore[import-untyped]
@@ -11,14 +11,15 @@
1111
from models_library.progress_bar import ProgressReport
1212
from servicelib.celery.models import (
1313
Task,
14-
TaskContext,
14+
TaskFilter,
1515
TaskID,
1616
TaskInfoStore,
1717
TaskMetadata,
1818
TaskState,
1919
TaskStatus,
2020
TaskUUID,
2121
)
22+
from servicelib.celery.task_manager import TaskManager
2223
from servicelib.logging_utils import log_context
2324
from settings_library.celery import CelerySettings
2425

@@ -41,16 +42,16 @@ async def submit_task(
4142
self,
4243
task_metadata: TaskMetadata,
4344
*,
44-
task_context: TaskContext,
45+
task_filter: TaskFilter,
4546
**task_params,
4647
) -> TaskUUID:
4748
with log_context(
4849
_logger,
4950
logging.DEBUG,
50-
msg=f"Submit {task_metadata.name=}: {task_context=} {task_params=}",
51+
msg=f"Submit {task_metadata.name=}: {task_filter=} {task_params=}",
5152
):
5253
task_uuid = uuid4()
53-
task_id = build_task_id(task_context, task_uuid)
54+
task_id = build_task_id(task_filter, task_uuid)
5455
self._celery_app.send_task(
5556
task_metadata.name,
5657
task_id=task_id,
@@ -72,14 +73,14 @@ async def submit_task(
7273
def _abort_task(self, task_id: TaskID) -> None:
7374
AbortableAsyncResult(task_id, app=self._celery_app).abort()
7475

75-
async def cancel_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
76+
async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None:
7677
with log_context(
7778
_logger,
7879
logging.DEBUG,
79-
msg=f"task cancellation: {task_context=} {task_uuid=}",
80+
msg=f"task cancellation: {task_filter=} {task_uuid=}",
8081
):
81-
task_id = build_task_id(task_context, task_uuid)
82-
if not (await self.get_task_status(task_context, task_uuid)).is_done:
82+
task_id = build_task_id(task_filter, task_uuid)
83+
if not (await self.get_task_status(task_filter, task_uuid)).is_done:
8384
await self._abort_task(task_id)
8485
await self._task_info_store.remove_task(task_id)
8586

@@ -88,14 +89,14 @@ def _forget_task(self, task_id: TaskID) -> None:
8889
AbortableAsyncResult(task_id, app=self._celery_app).forget()
8990

9091
async def get_task_result(
91-
self, task_context: TaskContext, task_uuid: TaskUUID
92+
self, task_filter: TaskFilter, task_uuid: TaskUUID
9293
) -> Any:
9394
with log_context(
9495
_logger,
9596
logging.DEBUG,
96-
msg=f"Get task result: {task_context=} {task_uuid=}",
97+
msg=f"Get task result: {task_filter=} {task_uuid=}",
9798
):
98-
task_id = build_task_id(task_context, task_uuid)
99+
task_id = build_task_id(task_filter, task_uuid)
99100
async_result = self._celery_app.AsyncResult(task_id)
100101
result = async_result.result
101102
if async_result.ready():
@@ -106,10 +107,10 @@ async def get_task_result(
106107
return result
107108

108109
async def _get_task_progress_report(
109-
self, task_context: TaskContext, task_uuid: TaskUUID, task_state: TaskState
110+
self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState
110111
) -> ProgressReport:
111112
if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
112-
task_id = build_task_id(task_context, task_uuid)
113+
task_id = build_task_id(task_filter, task_uuid)
113114
progress = await self._task_info_store.get_task_progress(task_id)
114115
if progress is not None:
115116
return progress
@@ -131,33 +132,37 @@ def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
131132
return TaskState(self._celery_app.AsyncResult(task_id).state)
132133

133134
async def get_task_status(
134-
self, task_context: TaskContext, task_uuid: TaskUUID
135+
self, task_filter: TaskFilter, task_uuid: TaskUUID
135136
) -> TaskStatus:
136137
with log_context(
137138
_logger,
138139
logging.DEBUG,
139-
msg=f"Getting task status: {task_context=} {task_uuid=}",
140+
msg=f"Getting task status: {task_filter=} {task_uuid=}",
140141
):
141-
task_id = build_task_id(task_context, task_uuid)
142+
task_id = build_task_id(task_filter, task_uuid)
142143
task_state = await self._get_task_celery_state(task_id)
143144
return TaskStatus(
144145
task_uuid=task_uuid,
145146
task_state=task_state,
146147
progress_report=await self._get_task_progress_report(
147-
task_context, task_uuid, task_state
148+
task_filter, task_uuid, task_state
148149
),
149150
)
150151

151-
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
152+
async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
152153
with log_context(
153154
_logger,
154155
logging.DEBUG,
155-
msg=f"Listing tasks: {task_context=}",
156+
msg=f"Listing tasks: {task_filter=}",
156157
):
157-
return await self._task_info_store.list_tasks(task_context)
158+
return await self._task_info_store.list_tasks(task_filter)
158159

159160
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
160161
await self._task_info_store.set_task_progress(
161162
task_id=task_id,
162163
report=report,
163164
)
165+
166+
167+
if TYPE_CHECKING:
168+
_: type[TaskManager] = CeleryTaskManager

packages/celery-library/src/celery_library/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,23 @@
22

33
from celery import Celery # type: ignore[import-untyped]
44
from servicelib.celery.app_server import BaseAppServer
5-
from servicelib.celery.models import TaskContext, TaskID, TaskUUID
5+
from servicelib.celery.models import TaskFilter, TaskID, TaskUUID
66

77
_APP_SERVER_KEY = "app_server"
88

99
_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
1010

1111

12-
def build_task_id_prefix(task_context: TaskContext) -> str:
12+
def build_task_id_prefix(task_filter: TaskFilter) -> str:
13+
filter_dict = task_filter.model_dump()
1314
return _TASK_ID_KEY_DELIMITATOR.join(
14-
[f"{task_context[key]}" for key in sorted(task_context)]
15+
[f"{filter_dict[key]}" for key in sorted(filter_dict)]
1516
)
1617

1718

18-
def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
19+
def build_task_id(task_filter: TaskFilter, task_uuid: TaskUUID) -> TaskID:
1920
return _TASK_ID_KEY_DELIMITATOR.join(
20-
[build_task_id_prefix(task_context), f"{task_uuid}"]
21+
[build_task_id_prefix(task_filter), f"{task_uuid}"]
2122
)
2223

2324

0 commit comments

Comments
 (0)