
Commit 94bb1ae

use base model for job filter
1 parent d8e9e10 commit 94bb1ae

21 files changed: +168 -162 lines changed


packages/celery-library/src/celery_library/backends/_redis.py

Lines changed: 2 additions & 2 deletions
@@ -82,10 +82,10 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
         )
         return None

-    async def list_tasks(self, task_context: TaskFilterBase) -> list[Task]:
+    async def list_tasks(self, task_filter: TaskFilterBase) -> list[Task]:
         search_key = (
             _CELERY_TASK_INFO_PREFIX
-            + build_task_id_prefix(task_context)
+            + build_task_id_prefix(task_filter)
             + _CELERY_TASK_ID_KEY_SEPARATOR
         )
         search_key_len = len(search_key)
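
In short, the Redis backend now receives a `task_filter` (any `TaskFilterBase`) instead of the loosely named `task_context`, and the filter still only contributes the key prefix used for the scan. Below is a minimal sketch of that prefix construction, using a stand-in dataclass and assumed constant values; the real `build_task_id_prefix` and the prefix/separator constants live in the celery-library package and may differ.

```python
# Illustrative sketch only -- not the repository's implementation. It assumes a
# filter is a small model whose field values, joined in order, form the
# task-id prefix that the Redis key scan matches on.
from dataclasses import dataclass

_CELERY_TASK_INFO_PREFIX = "celery-task-info-"  # assumed constant value
_CELERY_TASK_ID_KEY_SEPARATOR = ":"             # assumed separator


@dataclass(frozen=True)
class DemoTaskFilter:
    """Stand-in for a TaskFilterBase subclass (hypothetical)."""
    user_id: int
    product_name: str


def build_demo_task_id_prefix(task_filter: DemoTaskFilter) -> str:
    # Join the filter's field values in declaration order, e.g. "42:osparc".
    return _CELERY_TASK_ID_KEY_SEPARATOR.join(
        str(value) for value in (task_filter.user_id, task_filter.product_name)
    )


def demo_search_key(task_filter: DemoTaskFilter) -> str:
    # Same shape as the backend's key: info prefix + filter-derived prefix + separator.
    return (
        _CELERY_TASK_INFO_PREFIX
        + build_demo_task_id_prefix(task_filter)
        + _CELERY_TASK_ID_KEY_SEPARATOR
    )


assert demo_search_key(DemoTaskFilter(user_id=42, product_name="osparc")) == (
    "celery-task-info-42:osparc:"
)
```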

packages/celery-library/src/celery_library/rpc/_async_jobs.py

Lines changed: 13 additions & 13 deletions
@@ -4,9 +4,9 @@

 from celery.exceptions import CeleryError  # type: ignore[import-untyped]
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
     AsyncJobId,
-    AsyncJobNameData,
     AsyncJobResult,
     AsyncJobStatus,
 )
@@ -32,13 +32,13 @@

 @router.expose(reraise_if_error_type=(JobSchedulerError,))
 async def cancel(
-    task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+    task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
 ):
     assert task_manager  # nosec
-    assert job_id_data  # nosec
+    assert job_filter  # nosec
     try:
         await task_manager.cancel_task(
-            task_context=job_id_data.model_dump(),
+            task_filter=job_filter,
             task_uuid=job_id,
         )
     except CeleryError as exc:
@@ -47,14 +47,14 @@ async def cancel(

 @router.expose(reraise_if_error_type=(JobSchedulerError,))
 async def status(
-    task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+    task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
 ) -> AsyncJobStatus:
     assert task_manager  # nosec
-    assert job_id_data  # nosec
+    assert job_filter  # nosec

     try:
         task_status = await task_manager.get_task_status(
-            task_context=job_id_data.model_dump(),
+            task_filter=job_filter,
             task_uuid=job_id,
         )
     except CeleryError as exc:
@@ -76,21 +76,21 @@ async def status(
     )
 )
 async def result(
-    task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
+    task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
 ) -> AsyncJobResult:
     assert task_manager  # nosec
     assert job_id  # nosec
-    assert job_id_data  # nosec
+    assert job_filter  # nosec

     try:
         _status = await task_manager.get_task_status(
-            task_context=job_id_data.model_dump(),
+            task_filter=job_filter,
             task_uuid=job_id,
         )
         if not _status.is_done:
             raise JobNotDoneError(job_id=job_id)
         _result = await task_manager.get_task_result(
-            task_context=job_id_data.model_dump(),
+            task_filter=job_filter,
             task_uuid=job_id,
         )
     except CeleryError as exc:
@@ -123,13 +123,13 @@ async def result(

 @router.expose(reraise_if_error_type=(JobSchedulerError,))
 async def list_jobs(
-    task_manager: TaskManager, filter_: str, job_id_data: AsyncJobNameData
+    task_manager: TaskManager, filter_: str, job_filter: AsyncJobFilter
 ) -> list[AsyncJobGet]:
     _ = filter_
     assert task_manager  # nosec
     try:
         tasks = await task_manager.list_tasks(
-            task_context=job_id_data.model_dump(),
+            task_filter=job_filter,
         )
     except CeleryError as exc:
         raise JobSchedulerError(exc=f"{exc}") from exc
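
The RPC handlers now accept the typed `AsyncJobFilter` and forward it as-is (`task_filter=job_filter`) instead of first dumping it to a dict via `model_dump()`. A minimal sketch of that calling pattern with stand-in classes follows; the names are hypothetical and not the repo's RPC router or `TaskManager`.

```python
# Sketch of the calling pattern only (hypothetical names, not the repo's RPC code).
# The point of the change: handlers pass the typed filter object straight through,
# with no .model_dump() round-trip between the RPC layer and the task manager.
import asyncio
from uuid import UUID, uuid4

from pydantic import BaseModel


class DemoJobFilter(BaseModel):
    """Stand-in for AsyncJobFilter."""
    user_id: int
    product_name: str
    client_name: str


class DemoTaskManager:
    """Stand-in for the Celery-backed task manager."""

    async def get_task_status(self, task_filter: DemoJobFilter, task_uuid: UUID) -> str:
        # A real manager would build a task id from the filter and query Celery.
        return f"status of {task_uuid} for client {task_filter.client_name}"


async def demo_status(
    task_manager: DemoTaskManager, job_id: UUID, job_filter: DemoJobFilter
) -> str:
    # The filter is forwarded unchanged, exactly like the updated handlers do.
    return await task_manager.get_task_status(task_filter=job_filter, task_uuid=job_id)


if __name__ == "__main__":
    filter_ = DemoJobFilter(user_id=1, product_name="osparc", client_name="demo")
    print(asyncio.run(demo_status(DemoTaskManager(), uuid4(), filter_)))
```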

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 25 additions & 20 deletions
@@ -1,6 +1,6 @@
 import logging
 from dataclasses import dataclass
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4

 from celery import Celery  # type: ignore[import-untyped]
@@ -19,6 +19,7 @@
     TaskStatus,
     TaskUUID,
 )
+from servicelib.celery.task_manager import TaskManager
 from servicelib.logging_utils import log_context
 from settings_library.celery import CelerySettings

@@ -41,16 +42,16 @@ async def submit_task(
         self,
         task_metadata: TaskMetadata,
         *,
-        task_context: TaskFilterBase,
+        task_filter: TaskFilterBase,
         **task_params,
     ) -> TaskUUID:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Submit {task_metadata.name=}: {task_context=} {task_params=}",
+            msg=f"Submit {task_metadata.name=}: {task_filter=} {task_params=}",
         ):
             task_uuid = uuid4()
-            task_id = build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_filter, task_uuid)
             self._celery_app.send_task(
                 task_metadata.name,
                 task_id=task_id,
@@ -73,15 +74,15 @@ def _abort_task(self, task_id: TaskID) -> None:
         AbortableAsyncResult(task_id, app=self._celery_app).abort()

     async def cancel_task(
-        self, task_context: TaskFilterBase, task_uuid: TaskUUID
+        self, task_filter: TaskFilterBase, task_uuid: TaskUUID
     ) -> None:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"task cancellation: {task_context=} {task_uuid=}",
+            msg=f"task cancellation: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_context, task_uuid)
-            if not (await self.get_task_status(task_context, task_uuid)).is_done:
+            task_id = build_task_id(task_filter, task_uuid)
+            if not (await self.get_task_status(task_filter, task_uuid)).is_done:
                 await self._abort_task(task_id)
             await self._task_info_store.remove_task(task_id)

@@ -90,14 +91,14 @@ def _forget_task(self, task_id: TaskID) -> None:
         AbortableAsyncResult(task_id, app=self._celery_app).forget()

     async def get_task_result(
-        self, task_context: TaskFilterBase, task_uuid: TaskUUID
+        self, task_filter: TaskFilterBase, task_uuid: TaskUUID
     ) -> Any:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Get task result: {task_context=} {task_uuid=}",
+            msg=f"Get task result: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_filter, task_uuid)
             async_result = self._celery_app.AsyncResult(task_id)
             result = async_result.result
             if async_result.ready():
@@ -108,10 +109,10 @@ async def get_task_result(
             return result

     async def _get_task_progress_report(
-        self, task_context: TaskFilterBase, task_uuid: TaskUUID, task_state: TaskState
+        self, task_filter: TaskFilterBase, task_uuid: TaskUUID, task_state: TaskState
     ) -> ProgressReport:
         if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
-            task_id = build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_filter, task_uuid)
             progress = await self._task_info_store.get_task_progress(task_id)
             if progress is not None:
                 return progress
@@ -133,33 +134,37 @@ def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
         return TaskState(self._celery_app.AsyncResult(task_id).state)

     async def get_task_status(
-        self, task_context: TaskFilterBase, task_uuid: TaskUUID
+        self, task_filter: TaskFilterBase, task_uuid: TaskUUID
     ) -> TaskStatus:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Getting task status: {task_context=} {task_uuid=}",
+            msg=f"Getting task status: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_context, task_uuid)
+            task_id = build_task_id(task_filter, task_uuid)
             task_state = await self._get_task_celery_state(task_id)
             return TaskStatus(
                 task_uuid=task_uuid,
                 task_state=task_state,
                 progress_report=await self._get_task_progress_report(
-                    task_context, task_uuid, task_state
+                    task_filter, task_uuid, task_state
                 ),
             )

-    async def list_tasks(self, task_context: TaskFilterBase) -> list[Task]:
+    async def list_tasks(self, task_filter: TaskFilterBase) -> list[Task]:
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Listing tasks: {task_context=}",
+            msg=f"Listing tasks: {task_filter=}",
         ):
-            return await self._task_info_store.list_tasks(task_context)
+            return await self._task_info_store.list_tasks(task_filter)

     async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
         await self._task_info_store.set_task_progress(
             task_id=task_id,
             report=report,
         )
+
+
+if TYPE_CHECKING:
+    _: type[TaskManager] = CeleryTaskManager
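
The new trailing block is a zero-runtime-cost conformance check: under `TYPE_CHECKING`, the class object is assigned to a name annotated `type[TaskManager]`, so mypy/pyright must prove that `CeleryTaskManager` structurally satisfies the `TaskManager` interface, while nothing extra executes at runtime. A self-contained sketch of the pattern, assuming the interface is a `typing.Protocol` (the names below are illustrative, not from the repo):

```python
# Generic illustration of the TYPE_CHECKING conformance-check pattern
# (hypothetical classes, not the repository's).
from typing import TYPE_CHECKING, Protocol


class Worker(Protocol):
    async def run(self, name: str) -> int: ...


class ConcreteWorker:
    async def run(self, name: str) -> int:
        return len(name)


if TYPE_CHECKING:
    # Never executed at runtime; a type checker must verify that ConcreteWorker
    # implements Worker, so any future signature drift fails static analysis.
    _: type[Worker] = ConcreteWorker
```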

packages/celery-library/tests/unit/test_async_jobs.py

Lines changed: 19 additions & 17 deletions
@@ -16,8 +16,8 @@
 from common_library.errors_classes import OsparcErrorMixin
 from faker import Faker
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
+    AsyncJobFilter,
     AsyncJobGet,
-    AsyncJobNameData,
 )
 from models_library.api_schemas_rpc_async_jobs.exceptions import (
     JobAbortedError,
@@ -79,23 +79,23 @@ def product_name(faker: Faker) -> ProductName:

 @router.expose()
 async def rpc_sync_job(
-    task_manager: TaskManager, *, job_id_data: AsyncJobNameData, **kwargs: Any
+    task_manager: TaskManager, *, job_filter: AsyncJobFilter, **kwargs: Any
 ) -> AsyncJobGet:
     task_name = sync_job.__name__
     task_uuid = await task_manager.submit_task(
-        TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs
+        TaskMetadata(name=task_name), task_filter=job_filter, **kwargs
     )

     return AsyncJobGet(job_id=task_uuid, job_name=task_name)


 @router.expose()
 async def rpc_async_job(
-    task_manager: TaskManager, *, job_id_data: AsyncJobNameData, **kwargs: Any
+    task_manager: TaskManager, *, job_filter: AsyncJobFilter, **kwargs: Any
 ) -> AsyncJobGet:
     task_name = async_job.__name__
     task_uuid = await task_manager.submit_task(
-        TaskMetadata(name=task_name), task_context=job_id_data.model_dump(), **kwargs
+        TaskMetadata(name=task_name), task_filter=job_filter, **kwargs
     )

     return AsyncJobGet(job_id=task_uuid, job_name=task_name)
@@ -156,16 +156,18 @@ async def _start_task_via_rpc(
     user_id: UserID,
     product_name: ProductName,
     **kwargs: Any,
-) -> tuple[AsyncJobGet, AsyncJobNameData]:
-    job_id_data = AsyncJobNameData(user_id=user_id, product_name=product_name)
+) -> tuple[AsyncJobGet, AsyncJobFilter]:
+    job_filter = AsyncJobFilter(
+        user_id=user_id, product_name=product_name, client_name="pytest_client"
+    )
     async_job_get = await async_jobs.submit(
         rabbitmq_rpc_client=client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         method_name=rpc_task_name,
-        job_id_data=job_id_data,
+        job_filter=job_filter,
         **kwargs,
     )
-    return async_job_get, job_id_data
+    return async_job_get, job_filter


 @pytest.fixture
@@ -193,7 +195,7 @@ async def _wait_for_job(
     rpc_client: RabbitMQRPCClient,
     *,
     async_job_get: AsyncJobGet,
-    job_id_data: AsyncJobNameData,
+    job_filter: AsyncJobFilter,
     stop_after: timedelta = timedelta(seconds=5),
 ) -> None:

@@ -208,7 +210,7 @@ async def _wait_for_job(
             rpc_client,
             rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
             job_id=async_job_get.job_id,
-            job_id_data=job_id_data,
+            job_filter=job_filter,
         )
         assert (
             result.done is True
@@ -255,14 +257,14 @@ async def test_async_jobs_workflow(
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         filter_="",  # currently not used
-        job_id_data=job_id_data,
+        job_filter=job_id_data,
     )
     assert len(jobs) > 0

     await _wait_for_job(
         async_jobs_rabbitmq_rpc_client,
         async_job_get=async_job_get,
-        job_id_data=job_id_data,
+        job_filter=job_id_data,
     )

     async_job_result = await async_jobs.result(
@@ -301,20 +303,20 @@ async def test_async_jobs_cancel(
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         job_id=async_job_get.job_id,
-        job_id_data=job_id_data,
+        job_filter=job_id_data,
     )

     await _wait_for_job(
         async_jobs_rabbitmq_rpc_client,
         async_job_get=async_job_get,
-        job_id_data=job_id_data,
+        job_filter=job_id_data,
     )

     jobs = await async_jobs.list_jobs(
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
         filter_="",  # currently not used
-        job_id_data=job_id_data,
+        job_filter=job_id_data,
     )
     assert async_job_get.job_id not in [job.job_id for job in jobs]
@@ -365,7 +367,7 @@ async def test_async_jobs_raises(
     await _wait_for_job(
         async_jobs_rabbitmq_rpc_client,
         async_job_get=async_job_get,
-        job_id_data=job_id_data,
+        job_filter=job_id_data,
         stop_after=timedelta(minutes=1),
     )
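
The tests now build the filter explicitly, including the newly required `client_name`, and pass it as `job_filter=...`; note that several call sites keep the local variable name `job_id_data` and only the keyword argument changes. Below is a hypothetical pytest-style sketch of the same setup with a stand-in model, not the repo's fixtures.

```python
# Hypothetical pytest fixture sketch (illustrative names, not the repo's).
# It mirrors what the updated tests do: build the filter once, including the
# required client_name, and hand it to every RPC helper as job_filter=...
import pytest
from pydantic import BaseModel


class DemoAsyncJobFilter(BaseModel):
    """Stand-in for AsyncJobFilter."""
    user_id: int
    product_name: str
    client_name: str


@pytest.fixture
def job_filter() -> DemoAsyncJobFilter:
    return DemoAsyncJobFilter(
        user_id=1,
        product_name="osparc",
        client_name="pytest_client",
    )


def test_filter_round_trips(job_filter: DemoAsyncJobFilter) -> None:
    # The filter is a plain model: it serializes and rebuilds losslessly,
    # which is what the RPC layer relies on when forwarding it.
    assert DemoAsyncJobFilter.model_validate(job_filter.model_dump()) == job_filter
```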

packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py

Lines changed: 6 additions & 1 deletion
@@ -2,6 +2,7 @@
 from uuid import UUID

 from pydantic import BaseModel, StringConstraints
+from servicelib.celery.models import TaskFilterBase

 from ..products import ProductName
 from ..progress_bar import ProgressReport
@@ -33,8 +34,12 @@ class AsyncJobAbort(BaseModel):
     job_id: AsyncJobId


-class AsyncJobNameData(BaseModel):
+class AsyncJobFilter(TaskFilterBase):
     """Data for controlling access to an async job"""

     product_name: ProductName
     user_id: UserID
+    client_name: Annotated[
+        str,
+        StringConstraints(min_length=1, pattern=r"^[^\s]+$"),
+    ]
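
`AsyncJobNameData` becomes `AsyncJobFilter`, now derives from `TaskFilterBase`, and gains a required `client_name` that must be non-empty and contain no whitespace. A standalone sketch of that constraint with a plain `BaseModel` stand-in (assumption: `TaskFilterBase` is pydantic-based, so validation behaves the same way):

```python
# Standalone sketch of the new field constraint. DemoJobFilter is a stand-in;
# the real AsyncJobFilter derives from servicelib's TaskFilterBase.
from typing import Annotated

from pydantic import BaseModel, StringConstraints, ValidationError


class DemoJobFilter(BaseModel):
    product_name: str
    user_id: int
    client_name: Annotated[
        str,
        # At least one character, no whitespace anywhere in the value.
        StringConstraints(min_length=1, pattern=r"^[^\s]+$"),
    ]


DemoJobFilter(product_name="osparc", user_id=1, client_name="api-server")  # valid

try:
    DemoJobFilter(product_name="osparc", user_id=1, client_name="bad name")
except ValidationError as err:
    print(err)  # the embedded space is rejected by the pattern
```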
