Skip to content

Commit f037904

Browse files
committed
wip: simplify unions
1 parent cd270e4 commit f037904

File tree

1 file changed

+55
-90
lines changed
  • src/serializers/type_serializers

1 file changed

+55
-90
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 55 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -70,50 +70,60 @@ impl UnionSerializer {
7070

7171
impl_py_gc_traverse!(UnionSerializer { choices });
7272

73-
fn to_python(
74-
value: &Bound<'_, PyAny>,
75-
include: Option<&Bound<'_, PyAny>>,
76-
exclude: Option<&Bound<'_, PyAny>>,
73+
fn union_serialize<S, R>(
74+
// if this returns `Ok(v)`, we picked a union variant to serialize, where
75+
// `S` is intermediate state which can be passed on to the finalizer
76+
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
77+
// if called with `Some(v)`, we have intermediate state to finish
78+
// if `None`, we need to just go to fallback
79+
finalizer: impl FnOnce(Option<S>) -> R,
7780
extra: &Extra,
7881
choices: &[CombinedSerializer],
7982
retry_with_lax_check: bool,
80-
) -> PyResult<PyObject> {
83+
) -> R {
8184
// try the serializers in left to right order with error_on fallback=true
8285
let mut new_extra = extra.clone();
8386
new_extra.check = SerCheck::Strict;
8487
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
8588

8689
for comb_serializer in choices {
87-
match comb_serializer.to_python(value, include, exclude, &new_extra) {
88-
Ok(v) => return Ok(v),
90+
match selector(comb_serializer, &new_extra) {
91+
Ok(v) => return finalizer(Some(v)),
8992
Err(err) => errors.push(err),
9093
}
9194
}
9295

93-
// If extra.check is SerCheck::Strict, we're in a nested union
94-
if extra.check != SerCheck::Strict && retry_with_lax_check {
96+
if retry_with_lax_check {
9597
new_extra.check = SerCheck::Lax;
9698
for comb_serializer in choices {
97-
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
98-
return Ok(v);
99+
if let Ok(v) = selector(comb_serializer, &new_extra) {
100+
return finalizer(Some(v));
99101
}
100102
}
101103
}
102104

103-
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
104-
if extra.check == SerCheck::None {
105-
for err in &errors {
106-
extra.warnings.custom_warning(err.to_string());
107-
}
108-
}
109-
// Otherwise, if we've encountered errors, return them to the parent union, which should take
110-
// care of the formatting for us
111-
else if !errors.is_empty() {
112-
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
113-
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
105+
for err in &errors {
106+
extra.warnings.custom_warning(err.to_string());
114107
}
115108

116-
infer_to_python(value, include, exclude, extra)
109+
finalizer(None)
110+
}
111+
112+
fn to_python(
113+
value: &Bound<'_, PyAny>,
114+
include: Option<&Bound<'_, PyAny>>,
115+
exclude: Option<&Bound<'_, PyAny>>,
116+
extra: &Extra,
117+
choices: &[CombinedSerializer],
118+
retry_with_lax_check: bool,
119+
) -> PyResult<PyObject> {
120+
union_serialize(
121+
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
122+
|v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok),
123+
extra,
124+
choices,
125+
retry_with_lax_check,
126+
)
117127
}
118128

119129
fn json_key<'a>(
@@ -122,40 +132,13 @@ fn json_key<'a>(
122132
choices: &[CombinedSerializer],
123133
retry_with_lax_check: bool,
124134
) -> PyResult<Cow<'a, str>> {
125-
let mut new_extra = extra.clone();
126-
new_extra.check = SerCheck::Strict;
127-
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
128-
129-
for comb_serializer in choices {
130-
match comb_serializer.json_key(key, &new_extra) {
131-
Ok(v) => return Ok(v),
132-
Err(err) => errors.push(err),
133-
}
134-
}
135-
136-
// If extra.check is SerCheck::Strict, we're in a nested union
137-
if extra.check != SerCheck::Strict && retry_with_lax_check {
138-
new_extra.check = SerCheck::Lax;
139-
for comb_serializer in choices {
140-
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
141-
return Ok(v);
142-
}
143-
}
144-
}
145-
146-
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
147-
if extra.check == SerCheck::None {
148-
for err in &errors {
149-
extra.warnings.custom_warning(err.to_string());
150-
}
151-
}
152-
// Otherwise, if we've encountered errors, return them to the parent union, which should take
153-
// care of the formatting for us
154-
else if !errors.is_empty() {
155-
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
156-
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
157-
}
158-
infer_json_key(key, extra)
135+
union_serialize(
136+
|comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
137+
|v| v.map_or_else(|| infer_json_key(key, extra), Ok),
138+
extra,
139+
choices,
140+
retry_with_lax_check,
141+
)
159142
}
160143

161144
#[allow(clippy::too_many_arguments)]
@@ -168,39 +151,21 @@ fn serde_serialize<S: serde::ser::Serializer>(
168151
choices: &[CombinedSerializer],
169152
retry_with_lax_check: bool,
170153
) -> Result<S::Ok, S::Error> {
171-
let py = value.py();
172-
let mut new_extra = extra.clone();
173-
new_extra.check = SerCheck::Strict;
174-
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
175-
176-
for comb_serializer in choices {
177-
match comb_serializer.to_python(value, include, exclude, &new_extra) {
178-
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
179-
Err(err) => errors.push(err),
180-
}
181-
}
182-
183-
// If extra.check is SerCheck::Strict, we're in a nested union
184-
if extra.check != SerCheck::Strict && retry_with_lax_check {
185-
new_extra.check = SerCheck::Lax;
186-
for comb_serializer in choices {
187-
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
188-
return infer_serialize(v.bind(py), serializer, None, None, extra);
189-
}
190-
}
191-
}
192-
193-
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
194-
if extra.check == SerCheck::None {
195-
for err in &errors {
196-
extra.warnings.custom_warning(err.to_string());
197-
}
198-
} else {
199-
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
200-
// will have to be returned here
201-
}
202-
203-
infer_serialize(value, serializer, include, exclude, extra)
154+
union_serialize(
155+
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
156+
|v| {
157+
infer_serialize(
158+
v.as_ref().map_or(value, |v| v.bind(value.py())),
159+
serializer,
160+
None,
161+
None,
162+
extra,
163+
)
164+
},
165+
extra,
166+
choices,
167+
retry_with_lax_check,
168+
)
204169
}
205170

206171
impl TypeSerializer for UnionSerializer {

0 commit comments

Comments
 (0)