Skip to content

Commit ee9dd3d

Browse files
committed
continue refactor
1 parent f3a304a commit ee9dd3d

File tree

1 file changed

+38
-87
lines changed
  • src/serializers/type_serializers

1 file changed

+38
-87
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 38 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -70,25 +70,23 @@ impl UnionSerializer {
7070

7171
impl_py_gc_traverse!(UnionSerializer { choices });
7272

73-
fn union_serialize<S, R>(
74-
// if this returns `Ok(v)`, we picked a union variant to serialize, where
75-
// `S` is intermediate state which can be passed on to the finalizer
73+
fn union_serialize<S>(
74+
// if this returns `Ok(Some(v))`, we picked a union variant to serialize,
75+
// Or `Ok(None)` if we couldn't find a suitable variant to serialize
76+
// Finally, `Err(err)` if we encountered errors while trying to serialize
7677
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
77-
// if called with `Some(v)`, we have intermediate state to finish
78-
// if `None`, we need to just go to fallback
79-
finalizer: impl FnOnce(Option<S>) -> R,
8078
extra: &Extra,
8179
choices: &[CombinedSerializer],
8280
retry_with_lax_check: bool,
83-
) -> PyResult<R> {
81+
) -> PyResult<Option<S>> {
8482
// try the serializers in left to right order with error_on fallback=true
8583
let mut new_extra = extra.clone();
8684
new_extra.check = SerCheck::Strict;
8785
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
8886

8987
for comb_serializer in choices {
9088
match selector(comb_serializer, &new_extra) {
91-
Ok(v) => return Ok(finalizer(Some(v))),
89+
Ok(v) => return Ok(Some(v)),
9290
Err(err) => errors.push(err),
9391
}
9492
}
@@ -98,7 +96,7 @@ fn union_serialize<S, R>(
9896
new_extra.check = SerCheck::Lax;
9997
for comb_serializer in choices {
10098
if let Ok(v) = selector(comb_serializer, &new_extra) {
101-
return Ok(finalizer(Some(v)));
99+
return Ok(Some(v));
102100
}
103101
}
104102
}
@@ -116,7 +114,7 @@ fn union_serialize<S, R>(
116114
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
117115
}
118116

119-
Ok(finalizer(None))
117+
Ok(None)
120118
}
121119

122120
fn tagged_union_serialize<S>(
@@ -128,7 +126,7 @@ fn tagged_union_serialize<S>(
128126
extra: &Extra,
129127
choices: &[CombinedSerializer],
130128
retry_with_lax_check: bool,
131-
) -> Option<S> {
129+
) -> PyResult<Option<S>> {
132130
let mut new_extra = extra.clone();
133131
new_extra.check = SerCheck::Strict;
134132

@@ -138,20 +136,23 @@ fn tagged_union_serialize<S>(
138136
let selected_serializer = &choices[serializer_index];
139137

140138
match selector(selected_serializer, &new_extra) {
141-
Ok(v) => return Some(v),
139+
Ok(v) => return Ok(Some(v)),
142140
Err(_) => {
143141
if retry_with_lax_check {
144142
new_extra.check = SerCheck::Lax;
145143
if let Ok(v) = selector(selected_serializer, &new_extra) {
146-
return Some(v);
144+
return Ok(Some(v));
147145
}
148146
}
149147
}
150148
}
151149
}
152150
}
153151

154-
None
152+
// if we haven't returned at this point, we should fallback to the union serializer
153+
// which preserves the historical expectation that we do our best with serialization
154+
// even if that means we resort to inference
155+
union_serialize(selector, extra, choices, retry_with_lax_check)
155156
}
156157

157158
impl TypeSerializer for UnionSerializer {
@@ -164,21 +165,21 @@ impl TypeSerializer for UnionSerializer {
164165
) -> PyResult<PyObject> {
165166
union_serialize(
166167
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
167-
|v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok),
168168
extra,
169169
&self.choices,
170170
self.retry_with_lax_check(),
171171
)?
172+
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
172173
}
173174

174175
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
175176
union_serialize(
176177
|comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
177-
|v| v.map_or_else(|| infer_json_key(key, extra), Ok),
178178
extra,
179179
&self.choices,
180180
self.retry_with_lax_check(),
181181
)?
182+
.map_or_else(|| infer_json_key(key, extra), Ok)
182183
}
183184

184185
fn serde_serialize<S: serde::ser::Serializer>(
@@ -189,22 +190,16 @@ impl TypeSerializer for UnionSerializer {
189190
exclude: Option<&Bound<'_, PyAny>>,
190191
extra: &Extra,
191192
) -> Result<S::Ok, S::Error> {
192-
union_serialize(
193+
match union_serialize(
193194
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
194-
|v| {
195-
infer_serialize(
196-
v.as_ref().map_or(value, |v| v.bind(value.py())),
197-
serializer,
198-
None,
199-
None,
200-
extra,
201-
)
202-
},
203195
extra,
204196
&self.choices,
205197
self.retry_with_lax_check(),
206-
)
207-
.map_err(|err| serde::ser::Error::custom(err.to_string()))?
198+
) {
199+
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
200+
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
201+
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
202+
}
208203
}
209204

210205
fn get_name(&self) -> &str {
@@ -272,56 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
272267
exclude: Option<&Bound<'_, PyAny>>,
273268
extra: &Extra,
274269
) -> PyResult<PyObject> {
275-
let to_python_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| {
276-
comb_serializer.to_python(value, include, exclude, new_extra)
277-
};
278-
279270
tagged_union_serialize(
280271
self.get_discriminator_value(value, extra),
281272
&self.lookup,
282-
to_python_selector,
273+
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
274+
comb_serializer.to_python(value, include, exclude, new_extra)
275+
},
283276
extra,
284277
&self.choices,
285278
self.retry_with_lax_check(),
286-
)
287-
.map_or_else(
288-
|| {
289-
union_serialize(
290-
to_python_selector,
291-
|v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok),
292-
extra,
293-
&self.choices,
294-
self.retry_with_lax_check(),
295-
)?
296-
},
297-
Ok,
298-
)
279+
)?
280+
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
299281
}
300282

301283
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
302-
let json_key_selector =
303-
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra);
304-
305284
tagged_union_serialize(
306285
self.get_discriminator_value(key, extra),
307286
&self.lookup,
308-
json_key_selector,
287+
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
309288
extra,
310289
&self.choices,
311290
self.retry_with_lax_check(),
312-
)
313-
.map_or_else(
314-
|| {
315-
union_serialize(
316-
json_key_selector,
317-
|v| v.map_or_else(|| infer_json_key(key, extra), Ok),
318-
extra,
319-
&self.choices,
320-
self.retry_with_lax_check(),
321-
)?
322-
},
323-
Ok,
324-
)
291+
)?
292+
.map_or_else(|| infer_json_key(key, extra), Ok)
325293
}
326294

327295
fn serde_serialize<S: serde::ser::Serializer>(
@@ -332,37 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
332300
exclude: Option<&Bound<'_, PyAny>>,
333301
extra: &Extra,
334302
) -> Result<S::Ok, S::Error> {
335-
let serde_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| {
336-
comb_serializer.to_python(value, include, exclude, new_extra)
337-
};
338-
339-
if let Some(v) = tagged_union_serialize(
303+
match tagged_union_serialize(
340304
None,
341305
&self.lookup,
342-
serde_selector,
306+
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
307+
comb_serializer.to_python(value, include, exclude, new_extra)
308+
},
343309
extra,
344310
&self.choices,
345311
self.retry_with_lax_check(),
346312
) {
347-
return infer_serialize(v.bind(value.py()), serializer, None, None, extra);
313+
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
314+
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
315+
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
348316
}
349-
350-
union_serialize(
351-
serde_selector,
352-
|v| {
353-
infer_serialize(
354-
v.as_ref().map_or(value, |v| v.bind(value.py())),
355-
serializer,
356-
None,
357-
None,
358-
extra,
359-
)
360-
},
361-
extra,
362-
&self.choices,
363-
self.retry_with_lax_check(),
364-
)
365-
.map_err(|err| serde::ser::Error::custom(err.to_string()))?
366317
}
367318

368319
fn get_name(&self) -> &str {

0 commit comments

Comments
 (0)