Skip to content

Commit 3f44899

Browse files
committed
cosmetic change
1 parent c25f45d commit 3f44899

File tree

4 files changed

+68
-62
lines changed

4 files changed

+68
-62
lines changed

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
8585
return None
8686

8787
async def list_tasks(self, task_filter: OwnerMetadata) -> list[Task]:
88-
search_key = _CELERY_TASK_INFO_PREFIX + task_filter.create_task_id(
88+
search_key = _CELERY_TASK_INFO_PREFIX + task_filter.model_dump_task_id(
8989
task_uuid="*"
9090
)
9191

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ async def submit_task(
5050
msg=f"Submit {execution_metadata.name=}: {owner_metadata=} {task_params=}",
5151
):
5252
task_uuid = uuid4()
53-
task_id = owner_metadata.create_task_id(task_uuid=task_uuid)
53+
task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
5454

5555
expiry = (
5656
self._celery_settings.CELERY_EPHEMERAL_RESULT_EXPIRES
@@ -93,7 +93,7 @@ async def cancel_task(
9393
logging.DEBUG,
9494
msg=f"task cancellation: {owner_metadata=} {task_uuid=}",
9595
):
96-
task_id = owner_metadata.create_task_id(task_uuid=task_uuid)
96+
task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
9797
if not await self.task_exists(task_id):
9898
raise TaskNotFoundError(task_id=task_id)
9999

@@ -115,7 +115,7 @@ async def get_task_result(
115115
logging.DEBUG,
116116
msg=f"Get task result: {owner_metadata=} {task_uuid=}",
117117
):
118-
task_id = owner_metadata.create_task_id(task_uuid=task_uuid)
118+
task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
119119
if not await self.task_exists(task_id):
120120
raise TaskNotFoundError(task_id=task_id)
121121

@@ -158,7 +158,7 @@ async def get_task_status(
158158
logging.DEBUG,
159159
msg=f"Getting task status: {owner_metadata=} {task_uuid=}",
160160
):
161-
task_id = owner_metadata.create_task_id(task_uuid=task_uuid)
161+
task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
162162
if not await self.task_exists(task_id):
163163
raise TaskNotFoundError(task_id=task_id)
164164

packages/service-library/src/servicelib/celery/models.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import datetime
22
from enum import StrEnum
3-
from typing import Annotated, Any, Final, Literal, Protocol, Self, TypeAlias, TypeVar
3+
from types import NoneType
4+
from typing import Annotated, Final, Literal, Protocol, Self, TypeAlias, TypeVar
45
from uuid import UUID
56

7+
import orjson
8+
from common_library.json_serialization import json_dumps, json_loads
69
from models_library.progress_bar import ProgressReport
710
from pydantic import BaseModel, ConfigDict, StringConstraints, model_validator
811
from pydantic.config import JsonDict
@@ -17,7 +20,17 @@
1720
_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
1821
_FORBIDDEN_KEYS = ("*", _TASK_ID_KEY_DELIMITATOR, "=")
1922
_FORBIDDEN_VALUES = (_TASK_ID_KEY_DELIMITATOR, "=")
20-
_VALID_VALUE_TYPES = (int, float, bool, str)
23+
AllowedTypes = (
24+
int
25+
| float
26+
| bool
27+
| str
28+
| NoneType
29+
| list[str]
30+
| list[int]
31+
| list[float]
32+
| list[bool]
33+
)
2134

2235
Wildcard: TypeAlias = Literal["*"]
2336

@@ -55,50 +68,45 @@ def _check_valid_filters(self) -> Self:
5568
# forbidden values
5669
if any(x in f"{value}" for x in _FORBIDDEN_VALUES):
5770
raise ValueError(f"Invalid filter value for key '{key}': '{value}'")
58-
if not any(isinstance(value, type_) for type_ in _VALID_VALUE_TYPES):
59-
# restrict value types to ensure smooth serialization/deserialization
60-
raise ValueError(f"Invalid filter value for key '{key}': '{value}'")
61-
return self
6271

63-
def _build_task_id_prefix(self) -> str:
64-
filter_dict = self.model_dump()
65-
return _TASK_ID_KEY_DELIMITATOR.join(
66-
[f"{key}={filter_dict[key]}" for key in sorted(filter_dict)]
67-
)
72+
class _ValidationModel(BaseModel):
73+
filters: dict[str, AllowedTypes]
74+
75+
_ValidationModel.model_validate({"filters": self.model_dump()})
76+
return self
6877

69-
def create_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID:
78+
def model_dump_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID:
79+
data = self.model_dump(mode="json")
80+
data.update({"task_uuid": f"{task_uuid}"})
7081
return _TASK_ID_KEY_DELIMITATOR.join(
71-
[
72-
self._build_task_id_prefix(),
73-
f"task_uuid={task_uuid}",
74-
]
82+
[f"{k}={json_dumps(v)}" for k, v in sorted(data.items())]
7583
)
7684

7785
@classmethod
78-
def validate_from_task_id(cls, task_id: TaskID) -> Self:
79-
filter_dict = cls._recreate_data(task_id)
80-
return cls.model_validate(filter_dict)
86+
def model_validate_task_id(cls, task_id: TaskID) -> Self:
87+
data = cls._deserialize_task_id(task_id)
88+
data.pop("task_uuid", None)
89+
return cls.model_validate(data)
8190

8291
@classmethod
83-
def _recreate_data(cls, task_id: TaskID) -> dict[str, Any]:
84-
"""Recreates the filter data from a task_id string
85-
WARNING: does not validate types. For that use `recreate_model` instead
86-
"""
92+
def _deserialize_task_id(cls, task_id: TaskID) -> dict[str, AllowedTypes]:
93+
key_value_pairs = [
94+
item.split("=") for item in task_id.split(_TASK_ID_KEY_DELIMITATOR)
95+
]
8796
try:
88-
parts = task_id.split(_TASK_ID_KEY_DELIMITATOR)
89-
return {
90-
key: value
91-
for part in parts[:-1]
92-
if (key := part.split("=")[0]) and (value := part.split("=")[1])
93-
}
94-
except (IndexError, ValueError) as err:
97+
return {key: json_loads(value) for key, value in key_value_pairs}
98+
except orjson.JSONDecodeError as err:
9599
raise ValueError(f"Invalid task_id format: {task_id}") from err
96100

97101
@classmethod
98102
def get_task_uuid(cls, task_id: TaskID) -> TaskUUID:
103+
data = cls._deserialize_task_id(task_id)
99104
try:
100-
return UUID(task_id.split(_TASK_ID_KEY_DELIMITATOR)[-1].split("=")[1])
101-
except (IndexError, ValueError) as err:
105+
uuid_string = data["task_uuid"]
106+
if not isinstance(uuid_string, str):
107+
raise ValueError(f"Invalid task_id format: {task_id}")
108+
return TaskUUID(uuid_string)
109+
except ValueError as err:
102110
raise ValueError(f"Invalid task_id format: {task_id}") from err
103111

104112

packages/service-library/tests/test_celery.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import typing
1+
from types import NoneType
22
from typing import Annotated
33

44
# pylint: disable=redefined-outer-name
@@ -8,7 +8,6 @@
88
from faker import Faker
99
from pydantic import StringConstraints
1010
from servicelib.celery.models import (
11-
_VALID_VALUE_TYPES,
1211
OwnerMetadata,
1312
TaskUUID,
1413
Wildcard,
@@ -23,7 +22,6 @@ class _TestOwnerMetadata(OwnerMetadata):
2322
bool_: bool
2423
none_: None
2524
uuid_: str
26-
list_: list[str]
2725

2826

2927
@pytest.fixture
@@ -34,7 +32,6 @@ def owner_metadata() -> dict[str, str | int | bool | None | list[str]]:
3432
"bool_": _faker.boolean(),
3533
"none_": None,
3634
"uuid_": _faker.uuid4(),
37-
"list_": [_faker.word() for _ in range(3)],
3835
"owner": _faker.word().lower(),
3936
}
4037
_TestOwnerMetadata.model_validate(data) # ensure it's valid
@@ -67,37 +64,38 @@ async def test_task_filter_task_uuid(
6764
):
6865
task_filter = _TestOwnerMetadata.model_validate(owner_metadata)
6966
task_uuid = TaskUUID(_faker.uuid4())
70-
task_id = task_filter.create_task_id(task_uuid)
67+
task_id = task_filter.model_dump_task_id(task_uuid)
7168
assert OwnerMetadata.get_task_uuid(task_id=task_id) == task_uuid
7269

7370

74-
async def test_create_task_filter_from_task_id():
71+
async def test_owner_metadata_task_id_dump_and_validate():
7572

7673
class MyModel(OwnerMetadata):
7774
int_: int
7875
bool_: bool
7976
str_: str
8077
float_: float
81-
82-
# Check that all elements in _VALID_VALUE_TYPES are represented in MyModel's field types
83-
mymodel_types = set()
84-
for field in MyModel.model_fields.values():
85-
field_type = field.annotation
86-
origin = typing.get_origin(field_type)
87-
if origin is typing.Union:
88-
types_to_check = typing.get_args(field_type)
89-
else:
90-
types_to_check = [field_type]
91-
for t in types_to_check:
92-
if t is not Wildcard:
93-
mymodel_types.add(t)
94-
for valid_type in _VALID_VALUE_TYPES:
95-
assert valid_type in mymodel_types, f"{valid_type} not represented in MyModel"
96-
97-
mymodel = MyModel(int_=1, bool_=True, str_="test", float_=1.0, owner="myowner")
78+
none_: NoneType
79+
list_s: list[str]
80+
list_i: list[int]
81+
list_f: list[float]
82+
list_b: list[bool]
83+
84+
mymodel = MyModel(
85+
int_=1,
86+
none_=None,
87+
bool_=True,
88+
str_="test",
89+
float_=1.0,
90+
owner="myowner",
91+
list_b=[True, False],
92+
list_f=[1.0, 2.0],
93+
list_i=[1, 2],
94+
list_s=["a", "b"],
95+
)
9896
task_uuid = TaskUUID(_faker.uuid4())
99-
task_id = mymodel.create_task_id(task_uuid)
100-
mymodel_recreated = MyModel.validate_from_task_id(task_id=task_id)
97+
task_id = mymodel.model_dump_task_id(task_uuid)
98+
mymodel_recreated = MyModel.model_validate_task_id(task_id=task_id)
10199
assert mymodel_recreated == mymodel
102100

103101

0 commit comments

Comments
 (0)