@@ -117,44 +117,6 @@ fn union_serialize<S>(
117117 Ok ( None )
118118}
119119
120- fn tagged_union_serialize < S > (
121- discriminator_value : Option < Py < PyAny > > ,
122- lookup : & HashMap < String , usize > ,
123- // if this returns `Ok(v)`, we picked a union variant to serialize, where
124- // `S` is intermediate state which can be passed on to the finalizer
125- mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
126- extra : & Extra ,
127- choices : & [ CombinedSerializer ] ,
128- retry_with_lax_check : bool ,
129- ) -> PyResult < Option < S > > {
130- let mut new_extra = extra. clone ( ) ;
131- new_extra. check = SerCheck :: Strict ;
132-
133- if let Some ( tag) = discriminator_value {
134- let tag_str = tag. to_string ( ) ;
135- if let Some ( & serializer_index) = lookup. get ( & tag_str) {
136- let selected_serializer = & choices[ serializer_index] ;
137-
138- match selector ( selected_serializer, & new_extra) {
139- Ok ( v) => return Ok ( Some ( v) ) ,
140- Err ( _) => {
141- if retry_with_lax_check {
142- new_extra. check = SerCheck :: Lax ;
143- if let Ok ( v) = selector ( selected_serializer, & new_extra) {
144- return Ok ( Some ( v) ) ;
145- }
146- }
147- }
148- }
149- }
150- }
151-
152- // if we haven't returned at this point, we should fallback to the union serializer
153- // which preserves the historical expectation that we do our best with serialization
154- // even if that means we resort to inference
155- union_serialize ( selector, extra, choices, retry_with_lax_check)
156- }
157-
158120impl TypeSerializer for UnionSerializer {
159121 fn to_python (
160122 & self ,
@@ -267,27 +229,21 @@ impl TypeSerializer for TaggedUnionSerializer {
267229 exclude : Option < & Bound < ' _ , PyAny > > ,
268230 extra : & Extra ,
269231 ) -> PyResult < PyObject > {
270- tagged_union_serialize (
271- self . get_discriminator_value ( value, extra) ,
272- & self . lookup ,
232+ self . tagged_union_serialize (
233+ value,
273234 |comb_serializer : & CombinedSerializer , new_extra : & Extra | {
274235 comb_serializer. to_python ( value, include, exclude, new_extra)
275236 } ,
276237 extra,
277- & self . choices ,
278- self . retry_with_lax_check ( ) ,
279238 ) ?
280239 . map_or_else ( || infer_to_python ( value, include, exclude, extra) , Ok )
281240 }
282241
283242 fn json_key < ' a > ( & self , key : & ' a Bound < ' _ , PyAny > , extra : & Extra ) -> PyResult < Cow < ' a , str > > {
284- tagged_union_serialize (
285- self . get_discriminator_value ( key, extra) ,
286- & self . lookup ,
243+ self . tagged_union_serialize (
244+ key,
287245 |comb_serializer : & CombinedSerializer , new_extra : & Extra | comb_serializer. json_key ( key, new_extra) ,
288246 extra,
289- & self . choices ,
290- self . retry_with_lax_check ( ) ,
291247 ) ?
292248 . map_or_else ( || infer_json_key ( key, extra) , Ok )
293249 }
@@ -300,15 +256,12 @@ impl TypeSerializer for TaggedUnionSerializer {
300256 exclude : Option < & Bound < ' _ , PyAny > > ,
301257 extra : & Extra ,
302258 ) -> Result < S :: Ok , S :: Error > {
303- match tagged_union_serialize (
304- None ,
305- & self . lookup ,
259+ match self . tagged_union_serialize (
260+ value,
306261 |comb_serializer : & CombinedSerializer , new_extra : & Extra | {
307262 comb_serializer. to_python ( value, include, exclude, new_extra)
308263 } ,
309264 extra,
310- & self . choices ,
311- self . retry_with_lax_check ( ) ,
312265 ) {
313266 Ok ( Some ( v) ) => return infer_serialize ( v. bind ( value. py ( ) ) , serializer, None , None , extra) ,
314267 Ok ( None ) => infer_serialize ( value, serializer, include, exclude, extra) ,
@@ -326,36 +279,66 @@ impl TypeSerializer for TaggedUnionSerializer {
326279}
327280
328281impl TaggedUnionSerializer {
329- fn get_discriminator_value ( & self , value : & Bound < ' _ , PyAny > , extra : & Extra ) -> Option < Py < PyAny > > {
282+ fn get_discriminator_value < ' py > ( & self , value : & Bound < ' py , PyAny > ) -> Option < Bound < ' py , PyAny > > {
330283 let py = value. py ( ) ;
331- let discriminator_value = match & self . discriminator {
284+ match & self . discriminator {
332285 Discriminator :: LookupKey ( lookup_key) => {
333286 // we're pretty lax here, we allow either dict[key] or object.key, as we very well could
334287 // be doing a discriminator lookup on a typed dict, and there's no good way to check that
335288 // at this point. we could be more strict and only do this in lax mode...
336- let getattr_result = match value. is_instance_of :: < PyDict > ( ) {
337- true => {
338- let value_dict = value. downcast :: < PyDict > ( ) . unwrap ( ) ;
339- lookup_key. py_get_dict_item ( value_dict) . ok ( )
340- }
341- false => lookup_key. simple_py_get_attr ( value) . ok ( ) ,
342- } ;
343- getattr_result. and_then ( |opt| opt. map ( |( _, bound) | bound. to_object ( py) ) )
289+ if let Ok ( value_dict) = value. downcast :: < PyDict > ( ) {
290+ lookup_key. py_get_dict_item ( value_dict) . ok ( ) . flatten ( )
291+ } else {
292+ lookup_key. simple_py_get_attr ( value) . ok ( ) . flatten ( )
293+ }
294+ . map ( |( _, tag) | tag)
344295 }
345- Discriminator :: Function ( func) => func. call1 ( py, ( value, ) ) . ok ( ) ,
346- } ;
347- if discriminator_value. is_none ( ) {
348- let value_str = truncate_safe_repr ( value, None ) ;
296+ Discriminator :: Function ( func) => func. bind ( py) . call1 ( ( value, ) ) . ok ( ) ,
297+ }
298+ }
349299
350- // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
351- if extra. check == SerCheck :: None {
352- extra. warnings . custom_warning (
353- format ! (
354- "Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
355- )
356- ) ;
300+ fn tagged_union_serialize < S > (
301+ & self ,
302+ value : & Bound < ' _ , PyAny > ,
303+ // if this returns `Ok(v)`, we picked a union variant to serialize, where
304+ // `S` is intermediate state which can be passed on to the finalizer
305+ mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
306+ extra : & Extra ,
307+ ) -> PyResult < Option < S > > {
308+ if let Some ( tag) = self . get_discriminator_value ( value) {
309+ let mut new_extra = extra. clone ( ) ;
310+ new_extra. check = SerCheck :: Strict ;
311+
312+ let tag_str = tag. to_string ( ) ;
313+ if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
314+ let selected_serializer = & self . choices [ serializer_index] ;
315+
316+ match selector ( selected_serializer, & new_extra) {
317+ Ok ( v) => return Ok ( Some ( v) ) ,
318+ Err ( _) => {
319+ if self . retry_with_lax_check ( ) {
320+ new_extra. check = SerCheck :: Lax ;
321+ if let Ok ( v) = selector ( selected_serializer, & new_extra) {
322+ return Ok ( Some ( v) ) ;
323+ }
324+ }
325+ }
326+ }
357327 }
328+ } else if extra. check == SerCheck :: None {
329+ // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise
330+ // this warning
331+ let value_str = truncate_safe_repr ( value, None ) ;
332+ extra. warnings . custom_warning (
333+ format ! (
334+ "Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
335+ )
336+ ) ;
358337 }
359- discriminator_value
338+
339+ // if we haven't returned at this point, we should fallback to the union serializer
340+ // which preserves the historical expectation that we do our best with serialization
341+ // even if that means we resort to inference
342+ union_serialize ( selector, extra, & self . choices , self . retry_with_lax_check ( ) )
360343 }
361344}
0 commit comments