Skip to content

Commit 7e94b08

Browse files
committed
improve validation from task_id
1 parent 24ef2be commit 7e94b08

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

packages/service-library/src/servicelib/celery/models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
1818
_FORBIDDEN_KEYS = ("*", _TASK_ID_KEY_DELIMITATOR, "=")
1919
_FORBIDDEN_VALUES = (_TASK_ID_KEY_DELIMITATOR, "=")
20+
_VALID_VALUE_TYPES = (int, float, bool, str)
2021

2122
Wildcard: TypeAlias = Literal["*"]
2223

@@ -54,6 +55,9 @@ def _check_valid_filters(self) -> Self:
5455
# forbidden values
5556
if any(x in f"{value}" for x in _FORBIDDEN_VALUES):
5657
raise ValueError(f"Invalid filter value for key '{key}': '{value}'")
58+
if not any(isinstance(value, type_) for type_ in _VALID_VALUE_TYPES):
59+
# restrict value types to ensure smooth serialization/deserialization
60+
raise ValueError(f"Invalid filter value for key '{key}': '{value}'")
5761
return self
5862

5963
def _build_task_id_prefix(self) -> str:
@@ -71,9 +75,9 @@ def create_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID:
7175
)
7276

7377
@classmethod
74-
def recreate_as_model(cls, task_id: TaskID, schema: type[ModelType]) -> ModelType:
78+
def validate_from_task_id(cls, task_id: TaskID) -> Self:
7579
filter_dict = cls._recreate_data(task_id)
76-
return schema.model_validate(filter_dict)
80+
return cls.model_validate(filter_dict)
7781

7882
@classmethod
7983
def _recreate_data(cls, task_id: TaskID) -> dict[str, Any]:

packages/service-library/tests/test_celery.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import typing
12
from typing import Annotated
23

34
# pylint: disable=redefined-outer-name
@@ -6,7 +7,12 @@
67
import pytest
78
from faker import Faker
89
from pydantic import StringConstraints
9-
from servicelib.celery.models import OwnerMetadata, TaskUUID, Wildcard
10+
from servicelib.celery.models import (
11+
_VALID_VALUE_TYPES,
12+
OwnerMetadata,
13+
TaskUUID,
14+
Wildcard,
15+
)
1016

1117
_faker = Faker()
1218

@@ -68,17 +74,31 @@ async def test_task_filter_task_uuid(
6874
async def test_create_task_filter_from_task_id():
6975

7076
class MyModel(OwnerMetadata):
71-
_int: int
72-
_bool: bool
73-
_str: str
74-
_list: list[str]
75-
76-
mymodel = MyModel(
77-
_int=1, _bool=True, _str="test", _list=["a", "b"], owner="myowner"
78-
)
77+
int_: int
78+
bool_: bool
79+
str_: str
80+
float_: float
81+
82+
# Check that all elements in _VALID_VALUE_TYPES are represented in MyModel's field types
83+
mymodel_types = set()
84+
for field in MyModel.model_fields.values():
85+
field_type = field.annotation
86+
origin = typing.get_origin(field_type)
87+
if origin is typing.Union:
88+
types_to_check = typing.get_args(field_type)
89+
else:
90+
types_to_check = [field_type]
91+
for t in types_to_check:
92+
if t is not Wildcard:
93+
mymodel_types.add(t)
94+
for valid_type in _VALID_VALUE_TYPES:
95+
assert valid_type in mymodel_types, f"{valid_type} not represented in MyModel"
96+
97+
mymodel = MyModel(int_=1, bool_=True, str_="test", float_=1.0, owner="myowner")
7998
task_uuid = TaskUUID(_faker.uuid4())
8099
task_id = mymodel.create_task_id(task_uuid)
81-
assert OwnerMetadata.recreate_as_model(task_id=task_id, schema=MyModel) == mymodel
100+
mymodel_recreated = MyModel.validate_from_task_id(task_id=task_id)
101+
assert mymodel_recreated == mymodel
82102

83103

84104
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)