Skip to content

Commit d8e9e10

Browse files
committed
introduce TaslkFilterBase model
1 parent d0874a0 commit d8e9e10

File tree

6 files changed

+36
-30
lines changed

6 files changed

+36
-30
lines changed

packages/celery-library/src/celery_library/backends/_redis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic import ValidationError
88
from servicelib.celery.models import (
99
Task,
10-
TaskContext,
10+
TaskFilterBase,
1111
TaskID,
1212
TaskMetadata,
1313
TaskUUID,
@@ -82,7 +82,7 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
8282
)
8383
return None
8484

85-
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
85+
async def list_tasks(self, task_context: TaskFilterBase) -> list[Task]:
8686
search_key = (
8787
_CELERY_TASK_INFO_PREFIX
8888
+ build_task_id_prefix(task_context)

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from models_library.progress_bar import ProgressReport
1212
from servicelib.celery.models import (
1313
Task,
14-
TaskContext,
14+
TaskFilterBase,
1515
TaskID,
1616
TaskInfoStore,
1717
TaskMetadata,
@@ -41,7 +41,7 @@ async def submit_task(
4141
self,
4242
task_metadata: TaskMetadata,
4343
*,
44-
task_context: TaskContext,
44+
task_context: TaskFilterBase,
4545
**task_params,
4646
) -> TaskUUID:
4747
with log_context(
@@ -72,7 +72,9 @@ async def submit_task(
7272
def _abort_task(self, task_id: TaskID) -> None:
7373
AbortableAsyncResult(task_id, app=self._celery_app).abort()
7474

75-
async def cancel_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None:
75+
async def cancel_task(
76+
self, task_context: TaskFilterBase, task_uuid: TaskUUID
77+
) -> None:
7678
with log_context(
7779
_logger,
7880
logging.DEBUG,
@@ -88,7 +90,7 @@ def _forget_task(self, task_id: TaskID) -> None:
8890
AbortableAsyncResult(task_id, app=self._celery_app).forget()
8991

9092
async def get_task_result(
91-
self, task_context: TaskContext, task_uuid: TaskUUID
93+
self, task_context: TaskFilterBase, task_uuid: TaskUUID
9294
) -> Any:
9395
with log_context(
9496
_logger,
@@ -106,7 +108,7 @@ async def get_task_result(
106108
return result
107109

108110
async def _get_task_progress_report(
109-
self, task_context: TaskContext, task_uuid: TaskUUID, task_state: TaskState
111+
self, task_context: TaskFilterBase, task_uuid: TaskUUID, task_state: TaskState
110112
) -> ProgressReport:
111113
if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
112114
task_id = build_task_id(task_context, task_uuid)
@@ -131,7 +133,7 @@ def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
131133
return TaskState(self._celery_app.AsyncResult(task_id).state)
132134

133135
async def get_task_status(
134-
self, task_context: TaskContext, task_uuid: TaskUUID
136+
self, task_context: TaskFilterBase, task_uuid: TaskUUID
135137
) -> TaskStatus:
136138
with log_context(
137139
_logger,
@@ -148,7 +150,7 @@ async def get_task_status(
148150
),
149151
)
150152

151-
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
153+
async def list_tasks(self, task_context: TaskFilterBase) -> list[Task]:
152154
with log_context(
153155
_logger,
154156
logging.DEBUG,

packages/celery-library/src/celery_library/utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,19 @@
22

33
from celery import Celery # type: ignore[import-untyped]
44
from servicelib.celery.app_server import BaseAppServer
5-
from servicelib.celery.models import TaskContext, TaskID, TaskUUID
5+
from servicelib.celery.models import TaskFilterBase, TaskID, TaskUUID
66

77
_APP_SERVER_KEY = "app_server"
88

99
_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
1010

1111

12-
def build_task_id_prefix(task_context: TaskContext) -> str:
13-
return _TASK_ID_KEY_DELIMITATOR.join(
14-
[f"{task_context[key]}" for key in sorted(task_context)]
15-
)
12+
def build_task_id_prefix(task_context: TaskFilterBase) -> str:
13+
_dict = task_context.model_dump()
14+
return _TASK_ID_KEY_DELIMITATOR.join([f"{_dict[key]}" for key in sorted(_dict)])
1615

1716

18-
def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
17+
def build_task_id(task_context: TaskFilterBase, task_uuid: TaskUUID) -> TaskID:
1918
return _TASK_ID_KEY_DELIMITATOR.join(
2019
[build_task_id_prefix(task_context), f"{task_uuid}"]
2120
)

packages/celery-library/tests/unit/test_tasks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from common_library.errors_classes import OsparcErrorMixin
2121
from models_library.progress_bar import ProgressReport
2222
from servicelib.celery.models import (
23-
TaskContext,
23+
TaskFilterBase,
2424
TaskID,
2525
TaskMetadata,
2626
TaskState,
@@ -93,7 +93,7 @@ def _(celery_app: Celery) -> None:
9393
async def test_submitting_task_calling_async_function_results_with_success_state(
9494
celery_task_manager: CeleryTaskManager,
9595
):
96-
task_context = TaskContext(user_id=42)
96+
task_context = TaskFilterBase(user_id=42)
9797

9898
task_uuid = await celery_task_manager.submit_task(
9999
TaskMetadata(
@@ -123,7 +123,7 @@ async def test_submitting_task_calling_async_function_results_with_success_state
123123
async def test_submitting_task_with_failure_results_with_error(
124124
celery_task_manager: CeleryTaskManager,
125125
):
126-
task_context = TaskContext(user_id=42)
126+
task_context = TaskFilterBase(user_id=42)
127127

128128
task_uuid = await celery_task_manager.submit_task(
129129
TaskMetadata(
@@ -151,7 +151,7 @@ async def test_submitting_task_with_failure_results_with_error(
151151
async def test_cancelling_a_running_task_aborts_and_deletes(
152152
celery_task_manager: CeleryTaskManager,
153153
):
154-
task_context = TaskContext(user_id=42)
154+
task_context = TaskFilterBase(user_id=42)
155155

156156
task_uuid = await celery_task_manager.submit_task(
157157
TaskMetadata(
@@ -185,7 +185,7 @@ async def test_cancelling_a_running_task_aborts_and_deletes(
185185
async def test_listing_task_uuids_contains_submitted_task(
186186
celery_task_manager: CeleryTaskManager,
187187
):
188-
task_context = TaskContext(user_id=42)
188+
task_context = TaskFilterBase(user_id=42)
189189

190190
task_uuid = await celery_task_manager.submit_task(
191191
TaskMetadata(

packages/service-library/src/servicelib/celery/models.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
import datetime
22
from enum import StrEnum
3-
from typing import Annotated, Any, Protocol, TypeAlias
3+
from typing import Annotated, Protocol, TypeAlias
44
from uuid import UUID
55

66
from models_library.progress_bar import ProgressReport
7-
from pydantic import BaseModel, StringConstraints
7+
from pydantic import BaseModel, ConfigDict, StringConstraints
88

9-
TaskContext: TypeAlias = dict[str, Any]
109
TaskID: TypeAlias = str
1110
TaskName: TypeAlias = Annotated[
1211
str, StringConstraints(strip_whitespace=True, min_length=1)
1312
]
1413
TaskUUID: TypeAlias = UUID
1514

1615

16+
class TaskFilterBase(BaseModel):
17+
__root__: dict[str, str]
18+
19+
model_config = ConfigDict(extra="forbid")
20+
21+
1722
class TaskState(StrEnum):
1823
PENDING = "PENDING"
1924
STARTED = "STARTED"
@@ -56,7 +61,7 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: ...
5661

5762
async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: ...
5863

59-
async def list_tasks(self, task_context: TaskContext) -> list[Task]: ...
64+
async def list_tasks(self, task_context: TaskFilterBase) -> list[Task]: ...
6065

6166
async def remove_task(self, task_id: TaskID) -> None: ...
6267

packages/service-library/src/servicelib/celery/task_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ..celery.models import (
66
Task,
7-
TaskContext,
7+
TaskFilterBase,
88
TaskID,
99
TaskMetadata,
1010
TaskStatus,
@@ -14,22 +14,22 @@
1414

1515
class TaskManager(Protocol):
1616
async def submit_task(
17-
self, task_metadata: TaskMetadata, *, task_context: TaskContext, **task_param
17+
self, task_metadata: TaskMetadata, *, task_context: TaskFilterBase, **task_param
1818
) -> TaskUUID: ...
1919

2020
async def cancel_task(
21-
self, task_context: TaskContext, task_uuid: TaskUUID
21+
self, task_context: TaskFilterBase, task_uuid: TaskUUID
2222
) -> None: ...
2323

2424
async def get_task_result(
25-
self, task_context: TaskContext, task_uuid: TaskUUID
25+
self, task_context: TaskFilterBase, task_uuid: TaskUUID
2626
) -> Any: ...
2727

2828
async def get_task_status(
29-
self, task_context: TaskContext, task_uuid: TaskUUID
29+
self, task_context: TaskFilterBase, task_uuid: TaskUUID
3030
) -> TaskStatus: ...
3131

32-
async def list_tasks(self, task_context: TaskContext) -> list[Task]: ...
32+
async def list_tasks(self, task_context: TaskFilterBase) -> list[Task]: ...
3333

3434
async def set_task_progress(
3535
self, task_id: TaskID, report: ProgressReport

0 commit comments

Comments
 (0)