@@ -70,50 +70,60 @@ impl UnionSerializer {
7070
7171impl_py_gc_traverse ! ( UnionSerializer  {  choices } ) ; 
7272
73- fn  to_python ( 
74-     value :  & Bound < ' _ ,  PyAny > , 
75-     include :  Option < & Bound < ' _ ,  PyAny > > , 
76-     exclude :  Option < & Bound < ' _ ,  PyAny > > , 
73+ fn  union_serialize < S ,  R > ( 
74+     // if this returns `Ok(v)`, we picked a union variant to serialize, where 
75+     // `S` is intermediate state which can be passed on to the finalizer 
76+     mut  selector :  impl  FnMut ( & CombinedSerializer ,  & Extra )  -> PyResult < S > , 
77+     // if called with `Some(v)`, we have intermediate state to finish 
78+     // if `None`, we need to just go to fallback 
79+     finalizer :  impl  FnOnce ( Option < S > )  -> R , 
7780    extra :  & Extra , 
7881    choices :  & [ CombinedSerializer ] , 
7982    retry_with_lax_check :  bool , 
80- )  -> PyResult < PyObject >  { 
83+ )  -> R  { 
8184    // try the serializers in left to right order with error_on fallback=true 
8285    let  mut  new_extra = extra. clone ( ) ; 
8386    new_extra. check  = SerCheck :: Strict ; 
8487    let  mut  errors:  SmallVec < [ PyErr ;  SMALL_UNION_THRESHOLD ] >  = SmallVec :: new ( ) ; 
8588
8689    for  comb_serializer in  choices { 
87-         match  comb_serializer . to_python ( value ,  include ,  exclude ,  & new_extra)  { 
88-             Ok ( v)  => return  Ok ( v ) , 
90+         match  selector ( comb_serializer ,  & new_extra)  { 
91+             Ok ( v)  => return  finalizer ( Some ( v ) ) , 
8992            Err ( err)  => errors. push ( err) , 
9093        } 
9194    } 
9295
93-     // If extra.check is SerCheck::Strict, we're in a nested union 
94-     if  extra. check  != SerCheck :: Strict  && retry_with_lax_check { 
96+     if  retry_with_lax_check { 
9597        new_extra. check  = SerCheck :: Lax ; 
9698        for  comb_serializer in  choices { 
97-             if  let  Ok ( v)  = comb_serializer . to_python ( value ,  include ,  exclude ,  & new_extra)  { 
98-                 return  Ok ( v ) ; 
99+             if  let  Ok ( v)  = selector ( comb_serializer ,  & new_extra)  { 
100+                 return  finalizer ( Some ( v ) ) ; 
99101            } 
100102        } 
101103    } 
102104
103-     // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings 
104-     if  extra. check  == SerCheck :: None  { 
105-         for  err in  & errors { 
106-             extra. warnings . custom_warning ( err. to_string ( ) ) ; 
107-         } 
108-     } 
109-     // Otherwise, if we've encountered errors, return them to the parent union, which should take 
110-     // care of the formatting for us 
111-     else  if  !errors. is_empty ( )  { 
112-         let  message = errors. iter ( ) . map ( ToString :: to_string) . collect :: < Vec < _ > > ( ) . join ( "\n " ) ; 
113-         return  Err ( PydanticSerializationUnexpectedValue :: new_err ( Some ( message) ) ) ; 
105+     for  err in  & errors { 
106+         extra. warnings . custom_warning ( err. to_string ( ) ) ; 
114107    } 
115108
116-     infer_to_python ( value,  include,  exclude,  extra) 
109+     finalizer ( None ) 
110+ } 
111+ 
112+ fn  to_python ( 
113+     value :  & Bound < ' _ ,  PyAny > , 
114+     include :  Option < & Bound < ' _ ,  PyAny > > , 
115+     exclude :  Option < & Bound < ' _ ,  PyAny > > , 
116+     extra :  & Extra , 
117+     choices :  & [ CombinedSerializer ] , 
118+     retry_with_lax_check :  bool , 
119+ )  -> PyResult < PyObject >  { 
120+     union_serialize ( 
121+         |comb_serializer,  new_extra| comb_serializer. to_python ( value,  include,  exclude,  new_extra) , 
122+         |v| v. map_or_else ( || infer_to_python ( value,  include,  exclude,  extra) ,  Ok ) , 
123+         extra, 
124+         choices, 
125+         retry_with_lax_check, 
126+     ) 
117127} 
118128
119129fn  json_key < ' a > ( 
@@ -122,40 +132,13 @@ fn json_key<'a>(
122132    choices :  & [ CombinedSerializer ] , 
123133    retry_with_lax_check :  bool , 
124134)  -> PyResult < Cow < ' a ,  str > >  { 
125-     let  mut  new_extra = extra. clone ( ) ; 
126-     new_extra. check  = SerCheck :: Strict ; 
127-     let  mut  errors:  SmallVec < [ PyErr ;  SMALL_UNION_THRESHOLD ] >  = SmallVec :: new ( ) ; 
128- 
129-     for  comb_serializer in  choices { 
130-         match  comb_serializer. json_key ( key,  & new_extra)  { 
131-             Ok ( v)  => return  Ok ( v) , 
132-             Err ( err)  => errors. push ( err) , 
133-         } 
134-     } 
135- 
136-     // If extra.check is SerCheck::Strict, we're in a nested union 
137-     if  extra. check  != SerCheck :: Strict  && retry_with_lax_check { 
138-         new_extra. check  = SerCheck :: Lax ; 
139-         for  comb_serializer in  choices { 
140-             if  let  Ok ( v)  = comb_serializer. json_key ( key,  & new_extra)  { 
141-                 return  Ok ( v) ; 
142-             } 
143-         } 
144-     } 
145- 
146-     // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings 
147-     if  extra. check  == SerCheck :: None  { 
148-         for  err in  & errors { 
149-             extra. warnings . custom_warning ( err. to_string ( ) ) ; 
150-         } 
151-     } 
152-     // Otherwise, if we've encountered errors, return them to the parent union, which should take 
153-     // care of the formatting for us 
154-     else  if  !errors. is_empty ( )  { 
155-         let  message = errors. iter ( ) . map ( ToString :: to_string) . collect :: < Vec < _ > > ( ) . join ( "\n " ) ; 
156-         return  Err ( PydanticSerializationUnexpectedValue :: new_err ( Some ( message) ) ) ; 
157-     } 
158-     infer_json_key ( key,  extra) 
135+     union_serialize ( 
136+         |comb_serializer,  new_extra| comb_serializer. json_key ( key,  new_extra) , 
137+         |v| v. map_or_else ( || infer_json_key ( key,  extra) ,  Ok ) , 
138+         extra, 
139+         choices, 
140+         retry_with_lax_check, 
141+     ) 
159142} 
160143
161144#[ allow( clippy:: too_many_arguments) ]  
@@ -168,39 +151,21 @@ fn serde_serialize<S: serde::ser::Serializer>(
168151    choices :  & [ CombinedSerializer ] , 
169152    retry_with_lax_check :  bool , 
170153)  -> Result < S :: Ok ,  S :: Error >  { 
171-     let  py = value. py ( ) ; 
172-     let  mut  new_extra = extra. clone ( ) ; 
173-     new_extra. check  = SerCheck :: Strict ; 
174-     let  mut  errors:  SmallVec < [ PyErr ;  SMALL_UNION_THRESHOLD ] >  = SmallVec :: new ( ) ; 
175- 
176-     for  comb_serializer in  choices { 
177-         match  comb_serializer. to_python ( value,  include,  exclude,  & new_extra)  { 
178-             Ok ( v)  => return  infer_serialize ( v. bind ( py) ,  serializer,  None ,  None ,  extra) , 
179-             Err ( err)  => errors. push ( err) , 
180-         } 
181-     } 
182- 
183-     // If extra.check is SerCheck::Strict, we're in a nested union 
184-     if  extra. check  != SerCheck :: Strict  && retry_with_lax_check { 
185-         new_extra. check  = SerCheck :: Lax ; 
186-         for  comb_serializer in  choices { 
187-             if  let  Ok ( v)  = comb_serializer. to_python ( value,  include,  exclude,  & new_extra)  { 
188-                 return  infer_serialize ( v. bind ( py) ,  serializer,  None ,  None ,  extra) ; 
189-             } 
190-         } 
191-     } 
192- 
193-     // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings 
194-     if  extra. check  == SerCheck :: None  { 
195-         for  err in  & errors { 
196-             extra. warnings . custom_warning ( err. to_string ( ) ) ; 
197-         } 
198-     }  else  { 
199-         // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors 
200-         // will have to be returned here 
201-     } 
202- 
203-     infer_serialize ( value,  serializer,  include,  exclude,  extra) 
154+     union_serialize ( 
155+         |comb_serializer,  new_extra| comb_serializer. to_python ( value,  include,  exclude,  new_extra) , 
156+         |v| { 
157+             infer_serialize ( 
158+                 v. as_ref ( ) . map_or ( value,  |v| v. bind ( value. py ( ) ) ) , 
159+                 serializer, 
160+                 None , 
161+                 None , 
162+                 extra, 
163+             ) 
164+         } , 
165+         extra, 
166+         choices, 
167+         retry_with_lax_check, 
168+     ) 
204169} 
205170
206171impl  TypeSerializer  for  UnionSerializer  { 
0 commit comments