55# license information.
66# --------------------------------------------------------------------------
77import base64
8+ from functools import partial
89from json import JSONEncoder
9- from typing import Dict , List , Optional , Union , cast , Any
10+ from typing import Dict , List , Optional , Union , cast , Any , Type , Callable , Tuple
1011from datetime import datetime , date , time , timedelta
1112from datetime import timezone
1213
1314
14- __all__ = ["NULL" , "AzureJSONEncoder" , "is_generated_model" , "as_attribute_dict" , "attribute_list" ]
15+ __all__ = [
16+ "NULL" ,
17+ "AzureJSONEncoder" ,
18+ "is_generated_model" ,
19+ "as_attribute_dict" ,
20+ "attribute_list" ,
21+ "TypeHandlerRegistry" ,
22+ ]
1523TZ_UTC = timezone .utc
1624
1725
@@ -29,6 +37,164 @@ def __bool__(self) -> bool:
2937"""
3038
3139
40+ class TypeHandlerRegistry :
41+ """A registry for custom serializers and deserializers for specific types or conditions."""
42+
43+ def __init__ (self ) -> None :
44+ self ._serializer_types : Dict [Type , Callable ] = {}
45+ self ._deserializer_types : Dict [Type , Callable ] = {}
46+ self ._serializer_predicates : List [Tuple [Callable [[Any ], bool ], Callable ]] = []
47+ self ._deserializer_predicates : List [Tuple [Callable [[Any ], bool ], Callable ]] = []
48+
49+ self ._serializer_cache : Dict [Type , Optional [Callable ]] = {}
50+ self ._deserializer_cache : Dict [Type , Optional [Callable ]] = {}
51+
52+ def register_serializer (
53+ self , condition : Union [Type , Callable [[Any ], bool ]]
54+ ) -> Callable [[Callable [[Any ], Dict [str , Any ]]], Callable [[Any ], Dict [str , Any ]]]:
55+ """Decorator to register a serializer.
56+
57+ The handler function is expected to take a single argument, the object to serialize,
58+ and return a dictionary representation of that object.
59+
60+ Examples:
61+
62+ .. code-block:: python
63+
64+ @registry.register_serializer(CustomModel)
65+ def serialize_single_type(value: CustomModel) -> dict:
66+ return value.to_dict()
67+
68+ @registry.register_serializer(lambda x: isinstance(x, BaseModel))
69+ def serialize_with_condition(value: BaseModel) -> dict:
70+ return value.to_dict()
71+
72+ # Called manually for a specific type
73+ def custom_serializer(value: CustomModel) -> Dict[str, Any]:
74+ return {"custom": value.custom}
75+
76+ registry.register_serializer(CustomModel)(custom_serializer)
77+
78+ :param condition: A type or a callable predicate function that takes an object and returns a bool.
79+ :type condition: Union[Type, Callable[[Any], bool]]
80+ :return: A decorator that registers the handler function.
81+ :rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]
82+ :raises TypeError: If the condition is neither a type nor a callable.
83+ """
84+
85+ def decorator (handler_func : Callable [[Any ], Dict [str , Any ]]) -> Callable [[Any ], Dict [str , Any ]]:
86+ if isinstance (condition , type ):
87+ self ._serializer_types [condition ] = handler_func
88+ elif callable (condition ):
89+ self ._serializer_predicates .append ((condition , handler_func ))
90+ else :
91+ raise TypeError ("Condition must be a type or a callable predicate function." )
92+
93+ self ._serializer_cache .clear ()
94+ return handler_func
95+
96+ return decorator
97+
98+ def register_deserializer (
99+ self , condition : Union [Type , Callable [[Any ], bool ]]
100+ ) -> Callable [[Callable [[Type , Dict [str , Any ]], Any ]], Callable [[Type , Dict [str , Any ]], Any ]]:
101+ """Decorator to register a deserializer.
102+
103+ The handler function is expected to take two arguments: the target type and the data dictionary,
104+ and return an instance of the target type.
105+
106+ Examples:
107+
108+ .. code-block:: python
109+
110+ @registry.register_deserializer(CustomModel)
111+ def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel:
112+ return cls(**data)
113+
114+ @registry.register_deserializer(lambda t: issubclass(t, BaseModel))
115+ def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel:
116+ return cls(**data)
117+
118+ # Called manually for a specific type
119+ def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel:
120+ return cls(custom=data["custom"])
121+
122+ registry.register_deserializer(CustomModel)(custom_deserializer)
123+
124+ :param condition: A type or a callable predicate function that takes an object and returns a bool.
125+ :type condition: Union[Type, Callable[[Any], bool]]
126+ :return: A decorator that registers the handler function.
127+ :rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]
128+ :raises TypeError: If the condition is neither a type nor a callable.
129+ """
130+
131+ def decorator (handler_func : Callable [[Type , Dict [str , Any ]], Any ]) -> Callable [[Type , Dict [str , Any ]], Any ]:
132+ if isinstance (condition , type ):
133+ self ._deserializer_types [condition ] = handler_func
134+ elif callable (condition ):
135+ self ._deserializer_predicates .append ((condition , handler_func ))
136+ else :
137+ raise TypeError ("Condition must be a type or a callable predicate function." )
138+
139+ self ._deserializer_cache .clear ()
140+ return handler_func
141+
142+ return decorator
143+
144+ def get_serializer (self , obj : Any ) -> Optional [Callable ]:
145+ """Gets the appropriate serializer for an object.
146+
147+ It first checks the type dictionary for a direct type match.
148+ If no match is found, it iterates through the predicate list to find a match.
149+
150+ Results of the lookup are cached for performance based on the object's type.
151+
152+ :param obj: The object to serialize.
153+ :type obj: any
154+ :return: The serializer function if found, otherwise None.
155+ :rtype: Optional[Callable]
156+ """
157+ obj_type = type (obj )
158+ if obj_type in self ._serializer_cache :
159+ return self ._serializer_cache [obj_type ]
160+
161+ handler = self ._serializer_types .get (type (obj ))
162+ if not handler :
163+ for predicate , pred_handler in self ._serializer_predicates :
164+ if predicate (obj ):
165+ handler = pred_handler
166+ break
167+
168+ self ._serializer_cache [obj_type ] = handler
169+ return handler
170+
171+ def get_deserializer (self , cls : Type ) -> Optional [Callable ]:
172+ """Gets the appropriate deserializer for a class.
173+
174+ It first checks the type dictionary for a direct type match.
175+ If no match is found, it iterates through the predicate list to find a match.
176+
177+ Results of the lookup are cached for performance based on the class.
178+
179+ :param cls: The class to deserialize.
180+ :type cls: type
181+ :return: The deserializer function wrapped with the class if found, otherwise None.
182+ :rtype: Optional[Callable]
183+ """
184+ if cls in self ._deserializer_cache :
185+ return self ._deserializer_cache [cls ]
186+
187+ handler = self ._deserializer_types .get (cls )
188+ if not handler :
189+ for predicate , pred_handler in self ._deserializer_predicates :
190+ if predicate (cls ):
191+ handler = pred_handler
192+ break
193+
194+ self ._deserializer_cache [cls ] = partial (handler , cls ) if handler else None
195+ return self ._deserializer_cache [cls ]
196+
197+
32198def _timedelta_as_isostr (td : timedelta ) -> str :
33199 """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
34200
0 commit comments