Skip to content

Commit 8ba0717

Browse files
authored
[Core] Add type handling registry (Azure#43051)
This registry can be used by serialization and deserialization functions to enable custom type (de)serialization. Signed-off-by: Paul Van Eck <[email protected]>
1 parent c93f5e7 commit 8ba0717

File tree

4 files changed

+864
-5
lines changed

4 files changed

+864
-5
lines changed

sdk/core/azure-core/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- Added `TypeHandlerRegistry` to `azure.core.serialization` to allow developers to register custom serializers and deserializers for specific types or conditions. #43051
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/core/azure-core/azure/core/serialization.py

Lines changed: 168 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,21 @@
55
# license information.
66
# --------------------------------------------------------------------------
77
import base64
8+
from functools import partial
89
from json import JSONEncoder
9-
from typing import Dict, List, Optional, Union, cast, Any
10+
from typing import Dict, List, Optional, Union, cast, Any, Type, Callable, Tuple
1011
from datetime import datetime, date, time, timedelta
1112
from datetime import timezone
1213

1314

14-
__all__ = ["NULL", "AzureJSONEncoder", "is_generated_model", "as_attribute_dict", "attribute_list"]
15+
__all__ = [
16+
"NULL",
17+
"AzureJSONEncoder",
18+
"is_generated_model",
19+
"as_attribute_dict",
20+
"attribute_list",
21+
"TypeHandlerRegistry",
22+
]
1523
TZ_UTC = timezone.utc
1624

1725

@@ -29,6 +37,164 @@ def __bool__(self) -> bool:
2937
"""
3038

3139

40+
class TypeHandlerRegistry:
41+
"""A registry for custom serializers and deserializers for specific types or conditions."""
42+
43+
def __init__(self) -> None:
44+
self._serializer_types: Dict[Type, Callable] = {}
45+
self._deserializer_types: Dict[Type, Callable] = {}
46+
self._serializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []
47+
self._deserializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []
48+
49+
self._serializer_cache: Dict[Type, Optional[Callable]] = {}
50+
self._deserializer_cache: Dict[Type, Optional[Callable]] = {}
51+
52+
def register_serializer(
53+
self, condition: Union[Type, Callable[[Any], bool]]
54+
) -> Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]:
55+
"""Decorator to register a serializer.
56+
57+
The handler function is expected to take a single argument, the object to serialize,
58+
and return a dictionary representation of that object.
59+
60+
Examples:
61+
62+
.. code-block:: python
63+
64+
@registry.register_serializer(CustomModel)
65+
def serialize_single_type(value: CustomModel) -> dict:
66+
return value.to_dict()
67+
68+
@registry.register_serializer(lambda x: isinstance(x, BaseModel))
69+
def serialize_with_condition(value: BaseModel) -> dict:
70+
return value.to_dict()
71+
72+
# Called manually for a specific type
73+
def custom_serializer(value: CustomModel) -> Dict[str, Any]:
74+
return {"custom": value.custom}
75+
76+
registry.register_serializer(CustomModel)(custom_serializer)
77+
78+
:param condition: A type or a callable predicate function that takes an object and returns a bool.
79+
:type condition: Union[Type, Callable[[Any], bool]]
80+
:return: A decorator that registers the handler function.
81+
:rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]
82+
:raises TypeError: If the condition is neither a type nor a callable.
83+
"""
84+
85+
def decorator(handler_func: Callable[[Any], Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]:
86+
if isinstance(condition, type):
87+
self._serializer_types[condition] = handler_func
88+
elif callable(condition):
89+
self._serializer_predicates.append((condition, handler_func))
90+
else:
91+
raise TypeError("Condition must be a type or a callable predicate function.")
92+
93+
self._serializer_cache.clear()
94+
return handler_func
95+
96+
return decorator
97+
98+
def register_deserializer(
99+
self, condition: Union[Type, Callable[[Any], bool]]
100+
) -> Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]:
101+
"""Decorator to register a deserializer.
102+
103+
The handler function is expected to take two arguments: the target type and the data dictionary,
104+
and return an instance of the target type.
105+
106+
Examples:
107+
108+
.. code-block:: python
109+
110+
@registry.register_deserializer(CustomModel)
111+
def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel:
112+
return cls(**data)
113+
114+
@registry.register_deserializer(lambda t: issubclass(t, BaseModel))
115+
def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel:
116+
return cls(**data)
117+
118+
# Called manually for a specific type
119+
def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel:
120+
return cls(custom=data["custom"])
121+
122+
registry.register_deserializer(CustomModel)(custom_deserializer)
123+
124+
:param condition: A type or a callable predicate function that takes an object and returns a bool.
125+
:type condition: Union[Type, Callable[[Any], bool]]
126+
:return: A decorator that registers the handler function.
127+
:rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]
128+
:raises TypeError: If the condition is neither a type nor a callable.
129+
"""
130+
131+
def decorator(handler_func: Callable[[Type, Dict[str, Any]], Any]) -> Callable[[Type, Dict[str, Any]], Any]:
132+
if isinstance(condition, type):
133+
self._deserializer_types[condition] = handler_func
134+
elif callable(condition):
135+
self._deserializer_predicates.append((condition, handler_func))
136+
else:
137+
raise TypeError("Condition must be a type or a callable predicate function.")
138+
139+
self._deserializer_cache.clear()
140+
return handler_func
141+
142+
return decorator
143+
144+
def get_serializer(self, obj: Any) -> Optional[Callable]:
145+
"""Gets the appropriate serializer for an object.
146+
147+
It first checks the type dictionary for a direct type match.
148+
If no match is found, it iterates through the predicate list to find a match.
149+
150+
Results of the lookup are cached for performance based on the object's type.
151+
152+
:param obj: The object to serialize.
153+
:type obj: any
154+
:return: The serializer function if found, otherwise None.
155+
:rtype: Optional[Callable]
156+
"""
157+
obj_type = type(obj)
158+
if obj_type in self._serializer_cache:
159+
return self._serializer_cache[obj_type]
160+
161+
handler = self._serializer_types.get(type(obj))
162+
if not handler:
163+
for predicate, pred_handler in self._serializer_predicates:
164+
if predicate(obj):
165+
handler = pred_handler
166+
break
167+
168+
self._serializer_cache[obj_type] = handler
169+
return handler
170+
171+
def get_deserializer(self, cls: Type) -> Optional[Callable]:
172+
"""Gets the appropriate deserializer for a class.
173+
174+
It first checks the type dictionary for a direct type match.
175+
If no match is found, it iterates through the predicate list to find a match.
176+
177+
Results of the lookup are cached for performance based on the class.
178+
179+
:param cls: The class to deserialize.
180+
:type cls: type
181+
:return: The deserializer function wrapped with the class if found, otherwise None.
182+
:rtype: Optional[Callable]
183+
"""
184+
if cls in self._deserializer_cache:
185+
return self._deserializer_cache[cls]
186+
187+
handler = self._deserializer_types.get(cls)
188+
if not handler:
189+
for predicate, pred_handler in self._deserializer_predicates:
190+
if predicate(cls):
191+
handler = pred_handler
192+
break
193+
194+
self._deserializer_cache[cls] = partial(handler, cls) if handler else None
195+
return self._deserializer_cache[cls]
196+
197+
32198
def _timedelta_as_isostr(td: timedelta) -> str:
33199
"""Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
34200

sdk/core/azure-core/tests/specs_sdk/modeltypes/modeltypes/_utils/model_base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import base64
1818
import re
1919
import typing
20+
import traceback
2021
import enum
2122
import email.utils
2223
from datetime import datetime, date, time, timedelta, timezone
@@ -28,7 +29,7 @@
2829
from azure.core.exceptions import DeserializationError
2930
from azure.core import CaseInsensitiveEnumMeta
3031
from azure.core.pipeline import PipelineResponse
31-
from azure.core.serialization import _Null
32+
from azure.core.serialization import _Null, TypeHandlerRegistry
3233

3334
_LOGGER = logging.getLogger(__name__)
3435

@@ -38,6 +39,9 @@
3839
_T = typing.TypeVar("_T")
3940

4041

42+
TYPE_HANDLER_REGISTRY = TypeHandlerRegistry()
43+
44+
4145
def _timedelta_as_isostr(td: timedelta) -> str:
4246
"""Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
4347
@@ -161,6 +165,10 @@ def default(self, o): # pylint: disable=too-many-return-statements
161165
except AttributeError:
162166
# This will be raised when it hits value.total_seconds in the method above
163167
pass
168+
169+
custom_serializer = TYPE_HANDLER_REGISTRY.get_serializer(o)
170+
if custom_serializer:
171+
return custom_serializer(o)
164172
return super(SdkJSONEncoder, self).default(o)
165173

166174

@@ -481,6 +489,7 @@ def _is_model(obj: typing.Any) -> bool:
481489

482490

483491
def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements
492+
484493
if isinstance(o, list):
485494
return [_serialize(x, format) for x in o]
486495
if isinstance(o, dict):
@@ -510,6 +519,12 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
510519
except AttributeError:
511520
# This will be raised when it hits value.total_seconds in the method above
512521
pass
522+
523+
# Check if there's a custom serializer for the type
524+
custom_serializer = TYPE_HANDLER_REGISTRY.get_serializer(o)
525+
if custom_serializer:
526+
return custom_serializer(o)
527+
513528
return o
514529

515530

@@ -886,6 +901,10 @@ def _deserialize_default(
886901
if get_deserializer(annotation, rf):
887902
return functools.partial(_deserialize_default, get_deserializer(annotation, rf))
888903

904+
deserializer = TYPE_HANDLER_REGISTRY.get_deserializer(annotation)
905+
if deserializer:
906+
return deserializer
907+
889908
return functools.partial(_deserialize_default, annotation)
890909

891910

0 commit comments

Comments
 (0)