@@ -4,6 +4,7 @@ use pyo3::prelude::*;
44use pyo3:: types:: { PyDict , PyList , PyTuple } ;
55use smallvec:: SmallVec ;
66use std:: borrow:: Cow ;
7+ use std:: sync:: Arc ;
78
89use crate :: build_tools:: py_schema_err;
910use crate :: common:: union:: { Discriminator , SMALL_UNION_THRESHOLD } ;
@@ -119,6 +120,41 @@ fn union_serialize<S, R>(
119120 Ok ( finalizer ( None ) )
120121}
121122
123+ fn tagged_union_serialize < S > (
124+ discriminator_value : Option < Py < PyAny > > ,
125+ lookup : & HashMap < String , usize > ,
126+ // if this returns `Ok(v)`, we picked a union variant to serialize, where
127+ // `S` is intermediate state which can be passed on to the finalizer
128+ mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
129+ extra : & Extra ,
130+ choices : & Vec < CombinedSerializer > ,
131+ retry_with_lax_check : bool ,
132+ ) -> Option < S > {
133+ let mut new_extra = extra. clone ( ) ;
134+ new_extra. check = SerCheck :: Strict ;
135+
136+ if let Some ( tag) = discriminator_value {
137+ let tag_str = tag. to_string ( ) ;
138+ if let Some ( & serializer_index) = lookup. get ( & tag_str) {
139+ let selected_serializer = & choices[ serializer_index] ;
140+
141+ match selector ( & selected_serializer, & new_extra) {
142+ Ok ( v) => return Some ( v) ,
143+ Err ( _) => {
144+ if retry_with_lax_check {
145+ new_extra. check = SerCheck :: Lax ;
146+ if let Ok ( v) = selector ( & selected_serializer, & new_extra) {
147+ return Some ( v) ;
148+ }
149+ }
150+ }
151+ }
152+ }
153+ }
154+
155+ None
156+ }
157+
122158impl TypeSerializer for UnionSerializer {
123159 fn to_python (
124160 & self ,
@@ -237,67 +273,56 @@ impl TypeSerializer for TaggedUnionSerializer {
237273 exclude : Option < & Bound < ' _ , PyAny > > ,
238274 extra : & Extra ,
239275 ) -> PyResult < PyObject > {
240- let mut new_extra = extra. clone ( ) ;
241- new_extra. check = SerCheck :: Strict ;
242-
243- if let Some ( tag) = self . get_discriminator_value ( value, extra) {
244- let tag_str = tag. to_string ( ) ;
245- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
246- let serializer = & self . choices [ serializer_index] ;
247-
248- match serializer. to_python ( value, include, exclude, & new_extra) {
249- Ok ( v) => return Ok ( v) ,
250- Err ( _) => {
251- if self . retry_with_lax_check ( ) {
252- new_extra. check = SerCheck :: Lax ;
253- if let Ok ( v) = serializer. to_python ( value, include, exclude, & new_extra) {
254- return Ok ( v) ;
255- }
256- }
257- }
258- }
259- }
260- }
276+ let to_python_selector = |comb_serializer : & CombinedSerializer , new_extra : & Extra | {
277+ comb_serializer. to_python ( value, include, exclude, new_extra)
278+ } ;
261279
262- union_serialize (
263- |comb_serializer, new_extra| comb_serializer. to_python ( value, include, exclude, new_extra) ,
264- |v| v. map_or_else ( || infer_to_python ( value, include, exclude, extra) , Ok ) ,
280+ tagged_union_serialize (
281+ self . get_discriminator_value ( value, extra) ,
282+ & self . lookup ,
283+ to_python_selector,
265284 extra,
266285 & self . choices ,
267286 self . retry_with_lax_check ( ) ,
268- ) ?
287+ )
288+ . map_or_else (
289+ || {
290+ union_serialize (
291+ to_python_selector,
292+ |v| v. map_or_else ( || infer_to_python ( value, include, exclude, extra) , Ok ) ,
293+ extra,
294+ & self . choices ,
295+ self . retry_with_lax_check ( ) ,
296+ ) ?
297+ } ,
298+ Ok ,
299+ )
269300 }
270301
271302 fn json_key < ' a > ( & self , key : & ' a Bound < ' _ , PyAny > , extra : & Extra ) -> PyResult < Cow < ' a , str > > {
272- let mut new_extra = extra. clone ( ) ;
273- new_extra. check = SerCheck :: Strict ;
274-
275- if let Some ( tag) = self . get_discriminator_value ( key, extra) {
276- let tag_str = tag. to_string ( ) ;
277- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
278- let serializer = & self . choices [ serializer_index] ;
279-
280- match serializer. json_key ( key, & new_extra) {
281- Ok ( v) => return Ok ( v) ,
282- Err ( _) => {
283- if self . retry_with_lax_check ( ) {
284- new_extra. check = SerCheck :: Lax ;
285- if let Ok ( v) = serializer. json_key ( key, & new_extra) {
286- return Ok ( v) ;
287- }
288- }
289- }
290- }
291- }
292- }
303+ let json_key_selector =
304+ |comb_serializer : & CombinedSerializer , new_extra : & Extra | comb_serializer. json_key ( key, new_extra) ;
293305
294- union_serialize (
295- |comb_serializer, new_extra| comb_serializer. json_key ( key, new_extra) ,
296- |v| v. map_or_else ( || infer_json_key ( key, extra) , Ok ) ,
306+ tagged_union_serialize (
307+ self . get_discriminator_value ( key, extra) ,
308+ & self . lookup ,
309+ json_key_selector,
297310 extra,
298311 & self . choices ,
299312 self . retry_with_lax_check ( ) ,
300- ) ?
313+ )
314+ . map_or_else (
315+ || {
316+ union_serialize (
317+ json_key_selector,
318+ |v| v. map_or_else ( || infer_json_key ( key, extra) , Ok ) ,
319+ extra,
320+ & self . choices ,
321+ self . retry_with_lax_check ( ) ,
322+ ) ?
323+ } ,
324+ Ok ,
325+ )
301326 }
302327
303328 fn serde_serialize < S : serde:: ser:: Serializer > (
@@ -308,45 +333,39 @@ impl TypeSerializer for TaggedUnionSerializer {
308333 exclude : Option < & Bound < ' _ , PyAny > > ,
309334 extra : & Extra ,
310335 ) -> Result < S :: Ok , S :: Error > {
311- let py = value. py ( ) ;
312- let mut new_extra = extra. clone ( ) ;
313- new_extra. check = SerCheck :: Strict ;
314-
315- if let Some ( tag) = self . get_discriminator_value ( value, extra) {
316- let tag_str = tag. to_string ( ) ;
317- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
318- let selected_serializer = & self . choices [ serializer_index] ;
319-
320- match selected_serializer. to_python ( value, include, exclude, & new_extra) {
321- Ok ( v) => return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ,
322- Err ( _) => {
323- if self . retry_with_lax_check ( ) {
324- new_extra. check = SerCheck :: Lax ;
325- if let Ok ( v) = selected_serializer. to_python ( value, include, exclude, & new_extra) {
326- return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ;
327- }
328- }
329- }
330- }
331- }
332- }
336+ let serde_selector = |comb_serializer : & CombinedSerializer , new_extra : & Extra | {
337+ comb_serializer. to_python ( value, include, exclude, new_extra)
338+ } ;
333339
334- union_serialize (
335- |comb_serializer, new_extra| comb_serializer. to_python ( value, include, exclude, new_extra) ,
336- |v| {
337- infer_serialize (
338- v. as_ref ( ) . map_or ( value, |v| v. bind ( value. py ( ) ) ) ,
339- serializer,
340- None ,
341- None ,
342- extra,
343- )
344- } ,
340+ tagged_union_serialize (
341+ None ,
342+ & self . lookup ,
343+ serde_selector,
345344 extra,
346345 & self . choices ,
347346 self . retry_with_lax_check ( ) ,
348347 )
349- . map_err ( |err| serde:: ser:: Error :: custom ( err. to_string ( ) ) ) ?
348+ . map_or_else (
349+ || {
350+ union_serialize (
351+ serde_selector,
352+ |v| {
353+ infer_serialize (
354+ v. as_ref ( ) . map_or ( value, |v| v. bind ( value. py ( ) ) ) ,
355+ serializer,
356+ None ,
357+ None ,
358+ extra,
359+ )
360+ } ,
361+ extra,
362+ & self . choices ,
363+ self . retry_with_lax_check ( ) ,
364+ )
365+ . map_err ( |err| serde:: ser:: Error :: custom ( err. to_string ( ) ) ) ?
366+ } ,
367+ |v| infer_serialize ( v. bind ( value. py ( ) ) , serializer, None , None , extra) ,
368+ )
350369 }
351370
352371 fn get_name ( & self ) -> & str {
0 commit comments