|
1 | 1 | import datetime |
2 | 2 | from enum import StrEnum |
3 | | -from typing import Annotated, Any, Final, Literal, Protocol, Self, TypeAlias, TypeVar |
| 3 | +from types import NoneType |
| 4 | +from typing import Annotated, Final, Literal, Protocol, Self, TypeAlias, TypeVar |
4 | 5 | from uuid import UUID |
5 | 6 |
|
| 7 | +import orjson |
| 8 | +from common_library.json_serialization import json_dumps, json_loads |
6 | 9 | from models_library.progress_bar import ProgressReport |
7 | 10 | from pydantic import BaseModel, ConfigDict, StringConstraints, model_validator |
8 | 11 | from pydantic.config import JsonDict |
|
17 | 20 | _TASK_ID_KEY_DELIMITATOR: Final[str] = ":" |
18 | 21 | _FORBIDDEN_KEYS = ("*", _TASK_ID_KEY_DELIMITATOR, "=") |
19 | 22 | _FORBIDDEN_VALUES = (_TASK_ID_KEY_DELIMITATOR, "=") |
20 | | -_VALID_VALUE_TYPES = (int, float, bool, str) |
| 23 | +AllowedTypes = ( |
| 24 | + int |
| 25 | + | float |
| 26 | + | bool |
| 27 | + | str |
| 28 | + | NoneType |
| 29 | + | list[str] |
| 30 | + | list[int] |
| 31 | + | list[float] |
| 32 | + | list[bool] |
| 33 | +) |
21 | 34 |
|
22 | 35 | Wildcard: TypeAlias = Literal["*"] |
23 | 36 |
|
@@ -55,50 +68,45 @@ def _check_valid_filters(self) -> Self: |
55 | 68 | # forbidden values |
56 | 69 | if any(x in f"{value}" for x in _FORBIDDEN_VALUES): |
57 | 70 | raise ValueError(f"Invalid filter value for key '{key}': '{value}'") |
58 | | - if not any(isinstance(value, type_) for type_ in _VALID_VALUE_TYPES): |
59 | | - # restrict value types to ensure smooth serialization/deserialization |
60 | | - raise ValueError(f"Invalid filter value for key '{key}': '{value}'") |
61 | | - return self |
62 | 71 |
|
63 | | - def _build_task_id_prefix(self) -> str: |
64 | | - filter_dict = self.model_dump() |
65 | | - return _TASK_ID_KEY_DELIMITATOR.join( |
66 | | - [f"{key}={filter_dict[key]}" for key in sorted(filter_dict)] |
67 | | - ) |
| 72 | + class _ValidationModel(BaseModel): |
| 73 | + filters: dict[str, AllowedTypes] |
| 74 | + |
| 75 | + _ValidationModel.model_validate({"filters": self.model_dump()}) |
| 76 | + return self |
68 | 77 |
|
69 | | - def create_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID: |
| 78 | + def model_dump_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID: |
| 79 | + data = self.model_dump(mode="json") |
| 80 | + data.update({"task_uuid": f"{task_uuid}"}) |
70 | 81 | return _TASK_ID_KEY_DELIMITATOR.join( |
71 | | - [ |
72 | | - self._build_task_id_prefix(), |
73 | | - f"task_uuid={task_uuid}", |
74 | | - ] |
| 82 | + [f"{k}={json_dumps(v)}" for k, v in sorted(data.items())] |
75 | 83 | ) |
76 | 84 |
|
77 | 85 | @classmethod |
78 | | - def validate_from_task_id(cls, task_id: TaskID) -> Self: |
79 | | - filter_dict = cls._recreate_data(task_id) |
80 | | - return cls.model_validate(filter_dict) |
| 86 | + def model_validate_task_id(cls, task_id: TaskID) -> Self: |
| 87 | + data = cls._deserialize_task_id(task_id) |
| 88 | + data.pop("task_uuid", None) |
| 89 | + return cls.model_validate(data) |
81 | 90 |
|
82 | 91 | @classmethod |
83 | | - def _recreate_data(cls, task_id: TaskID) -> dict[str, Any]: |
84 | | - """Recreates the filter data from a task_id string |
85 | | - WARNING: does not validate types. For that use `recreate_model` instead |
86 | | - """ |
| 92 | + def _deserialize_task_id(cls, task_id: TaskID) -> dict[str, AllowedTypes]: |
| 93 | + key_value_pairs = [ |
| 94 | + item.split("=") for item in task_id.split(_TASK_ID_KEY_DELIMITATOR) |
| 95 | + ] |
87 | 96 | try: |
88 | | - parts = task_id.split(_TASK_ID_KEY_DELIMITATOR) |
89 | | - return { |
90 | | - key: value |
91 | | - for part in parts[:-1] |
92 | | - if (key := part.split("=")[0]) and (value := part.split("=")[1]) |
93 | | - } |
94 | | - except (IndexError, ValueError) as err: |
| 97 | + return {key: json_loads(value) for key, value in key_value_pairs} |
| 98 | + except orjson.JSONDecodeError as err: |
95 | 99 | raise ValueError(f"Invalid task_id format: {task_id}") from err |
96 | 100 |
|
97 | 101 | @classmethod |
98 | 102 | def get_task_uuid(cls, task_id: TaskID) -> TaskUUID: |
| 103 | + data = cls._deserialize_task_id(task_id) |
99 | 104 | try: |
100 | | - return UUID(task_id.split(_TASK_ID_KEY_DELIMITATOR)[-1].split("=")[1]) |
101 | | - except (IndexError, ValueError) as err: |
| 105 | + uuid_string = data["task_uuid"] |
| 106 | + if not isinstance(uuid_string, str): |
| 107 | + raise ValueError(f"Invalid task_id format: {task_id}") |
| 108 | + return TaskUUID(uuid_string) |
| 109 | + except ValueError as err: |
102 | 110 | raise ValueError(f"Invalid task_id format: {task_id}") from err |
103 | 111 |
|
104 | 112 |
|
|
0 commit comments