Skip to content

Commit e23cc2e

Browse files
refactor: move task metadata
1 parent 2bda0d8 commit e23cc2e

File tree

11 files changed

+121
-128
lines changed

11 files changed

+121
-128
lines changed

packages/celery-library/src/celery_library/backends/_redis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
8282
)
8383
return None
8484

85-
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
85+
async def list_tasks(self, context: TaskContext) -> list[Task]:
8686
search_key = (
8787
_CELERY_TASK_INFO_PREFIX
88-
+ build_task_id_prefix(task_context)
88+
+ build_task_id_prefix(context)
8989
+ _CELERY_TASK_ID_KEY_SEPARATOR
9090
)
9191
search_key_len = len(search_key)

packages/celery-library/src/celery_library/rpc/_async_jobs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def cancel(
3838
assert job_id_data # nosec
3939
try:
4040
await task_manager.cancel_task(
41-
task_context=job_id_data.model_dump(),
41+
context=job_id_data.model_dump(),
4242
task_uuid=job_id,
4343
)
4444
except CeleryError as exc:
@@ -54,7 +54,7 @@ async def status(
5454

5555
try:
5656
task_status = await task_manager.get_task_status(
57-
task_context=job_id_data.model_dump(),
57+
context=job_id_data.model_dump(),
5858
task_uuid=job_id,
5959
)
6060
except CeleryError as exc:
@@ -84,13 +84,13 @@ async def result(
8484

8585
try:
8686
_status = await task_manager.get_task_status(
87-
task_context=job_id_data.model_dump(),
87+
context=job_id_data.model_dump(),
8888
task_uuid=job_id,
8989
)
9090
if not _status.is_done:
9191
raise JobNotDoneError(job_id=job_id)
9292
_result = await task_manager.get_task_result(
93-
task_context=job_id_data.model_dump(),
93+
context=job_id_data.model_dump(),
9494
task_uuid=job_id,
9595
)
9696
except CeleryError as exc:
@@ -129,7 +129,7 @@ async def list_jobs(
129129
assert task_manager # nosec
130130
try:
131131
tasks = await task_manager.list_tasks(
132-
task_context=job_id_data.model_dump(),
132+
context=job_id_data.model_dump(),
133133
)
134134
except CeleryError as exc:
135135
raise JobSchedulerError(exc=f"{exc}") from exc

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
from common_library.async_tools import make_async
1111
from models_library.progress_bar import ProgressReport
1212
from servicelib.celery.models import (
13+
TASK_QUEUE_DEFAULT,
1314
Task,
1415
TaskContext,
1516
TaskID,
1617
TaskInfoStore,
1718
TaskMetadata,
19+
TaskName,
20+
TaskQueue,
1821
TaskState,
1922
TaskStatus,
2023
TaskUUID,
@@ -39,63 +42,69 @@ class CeleryTaskManager:
3942

4043
async def send_task(
4144
self,
42-
task_metadata: TaskMetadata,
45+
name: TaskName,
46+
context: TaskContext,
4347
*,
44-
task_context: TaskContext,
45-
**task_params,
48+
is_ephemeral: bool = False,
49+
queue: TaskQueue = TASK_QUEUE_DEFAULT,
50+
**params,
4651
) -> TaskUUID:
4752
with log_context(
4853
_logger,
4954
logging.DEBUG,
50-
msg=f"Send {task_metadata.name=}: {task_context=} {task_params=}",
55+
msg=f"Send {name=}: {context=} {params=}",
5156
):
5257
task_uuid = uuid4()
53-
task_id = build_task_id(task_context, task_uuid)
58+
task_id = build_task_id(context, task_uuid)
5459
self._celery_app.send_task(
55-
task_metadata.name,
60+
name,
5661
task_id=task_id,
57-
kwargs={"task_id": task_id} | task_params,
58-
queue=task_metadata.queue,
62+
kwargs={"task_id": task_id} | params,
63+
queue=queue,
5964
)
6065

6166
expiry = (
6267
self._celery_settings.CELERY_EPHEMERAL_RESULT_EXPIRES
63-
if task_metadata.ephemeral
68+
if is_ephemeral
6469
else self._celery_settings.CELERY_RESULT_EXPIRES
6570
)
6671
await self._task_info_store.create_task(
67-
task_id, task_metadata, expiry=expiry
72+
task_id,
73+
TaskMetadata(
74+
name=name,
75+
ephemeral=is_ephemeral,
76+
queue=queue,
77+
),
78+
expiry=expiry,
6879
)
6980
return task_uuid
7081

7182
@make_async()
7283
def _abort_task(self, task_id: TaskID) -> None:
7384
AbortableAsyncResult(task_id, app=self._celery_app).abort()
7485

75-
async def cancel_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
86+
async def cancel_task(self, context: TaskContext, task_uuid: TaskUUID) -> None:
7687
with log_context(
7788
_logger,
7889
logging.DEBUG,
79-
msg=f"task cancellation: {task_context=} {task_uuid=}",
90+
msg=f"task cancellation: {context=} {task_uuid=}",
8091
):
81-
task_id = build_task_id(task_context, task_uuid)
82-
if not (await self.get_task_status(task_context, task_uuid)).is_done:
92+
task_id = build_task_id(context, task_uuid)
93+
if not (await self.get_task_status(context, task_uuid)).is_done:
8394
await self._abort_task(task_id)
8495
await self._task_info_store.remove_task(task_id)
8596

8697
@make_async()
8798
def _forget_task(self, task_id: TaskID) -> None:
8899
AbortableAsyncResult(task_id, app=self._celery_app).forget()
89100

90-
async def get_task_result(
91-
self, task_context: TaskContext, task_uuid: TaskUUID
92-
) -> Any:
101+
async def get_task_result(self, context: TaskContext, task_uuid: TaskUUID) -> Any:
93102
with log_context(
94103
_logger,
95104
logging.DEBUG,
96-
msg=f"Get task result: {task_context=} {task_uuid=}",
105+
msg=f"Get task result: {context=} {task_uuid=}",
97106
):
98-
task_id = build_task_id(task_context, task_uuid)
107+
task_id = build_task_id(context, task_uuid)
99108
async_result = self._celery_app.AsyncResult(task_id)
100109
result = async_result.result
101110
if async_result.ready():
@@ -105,15 +114,15 @@ async def get_task_result(
105114
await self._task_info_store.remove_task(task_id)
106115
return result
107116

108-
async def _get_task_progress_report(
109-
self, task_context: TaskContext, task_uuid: TaskUUID, task_state: TaskState
117+
async def _get_progress_report(
118+
self, context: TaskContext, task_uuid: TaskUUID, state: TaskState
110119
) -> ProgressReport:
111-
if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
112-
task_id = build_task_id(task_context, task_uuid)
120+
if state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
121+
task_id = build_task_id(context, task_uuid)
113122
progress = await self._task_info_store.get_task_progress(task_id)
114123
if progress is not None:
115124
return progress
116-
if task_state in (
125+
if state in (
117126
TaskState.SUCCESS,
118127
TaskState.FAILURE,
119128
):
@@ -131,30 +140,30 @@ def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
131140
return TaskState(self._celery_app.AsyncResult(task_id).state)
132141

133142
async def get_task_status(
134-
self, task_context: TaskContext, task_uuid: TaskUUID
143+
self, context: TaskContext, task_uuid: TaskUUID
135144
) -> TaskStatus:
136145
with log_context(
137146
_logger,
138147
logging.DEBUG,
139-
msg=f"Getting task status: {task_context=} {task_uuid=}",
148+
msg=f"Getting task status: {context=} {task_uuid=}",
140149
):
141-
task_id = build_task_id(task_context, task_uuid)
150+
task_id = build_task_id(context, task_uuid)
142151
task_state = await self._get_task_celery_state(task_id)
143152
return TaskStatus(
144153
task_uuid=task_uuid,
145154
task_state=task_state,
146-
progress_report=await self._get_task_progress_report(
147-
task_context, task_uuid, task_state
155+
progress_report=await self._get_progress_report(
156+
context, task_uuid, task_state
148157
),
149158
)
150159

151-
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
160+
async def list_tasks(self, context: TaskContext) -> list[Task]:
152161
with log_context(
153162
_logger,
154163
logging.DEBUG,
155-
msg=f"Listing tasks: {task_context=}",
164+
msg=f"Listing tasks: {context=}",
156165
):
157-
return await self._task_info_store.list_tasks(task_context)
166+
return await self._task_info_store.list_tasks(context)
158167

159168
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
160169
await self._task_info_store.set_task_progress(

packages/celery-library/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
120120

121121

122122
@pytest.fixture
123-
async def celery_task_manager(
123+
async def task_manager(
124124
celery_app: Celery,
125125
celery_settings: CelerySettings,
126126
with_celery_worker: TestWorkController,

packages/celery-library/tests/unit/test_async_jobs.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from models_library.rabbitmq_basic_types import RPCNamespace
2828
from models_library.users import UserID
2929
from pydantic import TypeAdapter
30-
from servicelib.celery.models import TaskID, TaskMetadata
30+
from servicelib.celery.models import TaskID
3131
from servicelib.celery.task_manager import TaskManager
3232
from servicelib.rabbitmq import RabbitMQRPCClient, RPCRouter
3333
from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs
@@ -83,7 +83,9 @@ async def rpc_sync_job(
8383
) -> AsyncJobGet:
8484
task_name = sync_job.__name__
8585
task_uuid = await task_manager.send_task(
86-
TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs
86+
name=task_name,
87+
context=job_id_data.model_dump(),
88+
**kwargs,
8789
)
8890

8991
return AsyncJobGet(job_id=task_uuid, job_name=task_name)
@@ -95,7 +97,9 @@ async def rpc_async_job(
9597
) -> AsyncJobGet:
9698
task_name = async_job.__name__
9799
task_uuid = await task_manager.send_task(
98-
TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs
100+
name=task_name,
101+
context=job_id_data.model_dump(),
102+
**kwargs,
99103
)
100104

101105
return AsyncJobGet(job_id=task_uuid, job_name=task_name)
@@ -139,13 +143,13 @@ async def async_job(task: Task, task_id: TaskID, action: Action, payload: Any) -
139143

140144
@pytest.fixture
141145
async def register_rpc_routes(
142-
async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, celery_task_manager: TaskManager
146+
async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, task_manager: TaskManager
143147
) -> None:
144148
await async_jobs_rabbitmq_rpc_client.register_router(
145-
_async_jobs.router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=celery_task_manager
149+
_async_jobs.router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=task_manager
146150
)
147151
await async_jobs_rabbitmq_rpc_client.register_router(
148-
router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=celery_task_manager
152+
router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=task_manager
149153
)
150154

151155

0 commit comments

Comments
 (0)