Skip to content

Commit bbaf2a5

Browse files
committed
add test of field_sorting_key
1 parent 35402ee commit bbaf2a5

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

packages/service-library/src/servicelib/celery/models.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,40 @@
11
import datetime
2+
from collections.abc import Callable
23
from enum import StrEnum
3-
from typing import Annotated, Protocol, TypeAlias
4+
from typing import Annotated, Final, Protocol, TypeAlias
45
from uuid import UUID
56

67
from models_library.progress_bar import ProgressReport
7-
from pydantic import BaseModel, ConfigDict, StringConstraints
8+
from pydantic import BaseModel, ConfigDict, Field, StringConstraints
89
from pydantic.config import JsonDict
910

1011
TaskID: TypeAlias = str
1112
TaskName: TypeAlias = Annotated[
1213
str, StringConstraints(strip_whitespace=True, min_length=1)
1314
]
1415
TaskUUID: TypeAlias = UUID
16+
_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
1517

1618

1719
class TaskFilter(BaseModel):
1820
model_config = ConfigDict(extra="allow")
21+
field_sorting_key: Annotated[Callable[[str], int] | None, Field(exclude=True)] = (
22+
None
23+
)
24+
25+
def _build_task_id_prefix(self) -> str:
26+
filter_dict = self.model_dump()
27+
return _TASK_ID_KEY_DELIMITATOR.join(
28+
[
29+
f"{key}={filter_dict[key]}"
30+
for key in sorted(filter_dict, key=self.field_sorting_key)
31+
]
32+
)
33+
34+
def task_id(self, task_uuid: TaskUUID) -> TaskID:
35+
return _TASK_ID_KEY_DELIMITATOR.join(
36+
[self._build_task_id_prefix(), f"task_uuid={task_uuid}"]
37+
)
1938

2039

2140
class TaskState(StrEnum):

packages/service-library/tests/test_celery.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from typing import Literal
2+
3+
import pytest
14
from faker import Faker
25
from servicelib.celery.models import TaskFilter
36

@@ -15,3 +18,23 @@ async def test_task_filter_serialization():
1518
}
1619
task_filter = TaskFilter.model_validate(_dict)
1720
assert task_filter.model_dump() == _dict
21+
22+
23+
@pytest.mark.parametrize("key_direction", ["plus", "minus"])
24+
async def test_task_filter_sorting_key_not_serialized(
25+
key_direction: Literal["plus", "minus"],
26+
):
27+
28+
keys = ["a", "aa"]
29+
key = lambda s: len(s) if key_direction == "plus" else -len(s)
30+
31+
task_filter = TaskFilter(
32+
a=_faker.random_int(),
33+
aa=_faker.word(),
34+
field_sorting_key=key,
35+
)
36+
expected_key = ":".join(
37+
[f"{k}={getattr(task_filter, k)}" for k in sorted(keys, key=key)]
38+
)
39+
assert task_filter._build_task_id_prefix() == expected_key
40+
assert "field_sorting_key" not in task_filter.model_dump()

0 commit comments

Comments
 (0)