-
Notifications
You must be signed in to change notification settings - Fork 312
Simplify shared union serializer logic #1538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
f037904
0ceac1b
a5fbc2a
742a3e9
92f5bbc
57c3c6f
7ea0d64
f3a304a
ee9dd3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,8 +8,8 @@ use std::borrow::Cow; | |
| use crate::build_tools::py_schema_err; | ||
| use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; | ||
| use crate::definitions::DefinitionsBuilder; | ||
| use crate::serializers::PydanticSerializationUnexpectedValue; | ||
| use crate::tools::{truncate_safe_repr, SchemaDict}; | ||
| use crate::PydanticSerializationUnexpectedValue; | ||
|
|
||
| use super::{ | ||
| infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck, | ||
|
|
@@ -70,22 +70,25 @@ impl UnionSerializer { | |
|
|
||
| impl_py_gc_traverse!(UnionSerializer { choices }); | ||
|
|
||
| fn to_python( | ||
| value: &Bound<'_, PyAny>, | ||
| include: Option<&Bound<'_, PyAny>>, | ||
| exclude: Option<&Bound<'_, PyAny>>, | ||
| fn union_serialize<S, R>( | ||
| // if this returns `Ok(v)`, we picked a union variant to serialize, where | ||
| // `S` is intermediate state which can be passed on to the finalizer | ||
| mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>, | ||
| // if called with `Some(v)`, we have intermediate state to finish | ||
| // if `None`, we need to just go to fallback | ||
| finalizer: impl FnOnce(Option<S>) -> R, | ||
| extra: &Extra, | ||
| choices: &[CombinedSerializer], | ||
| retry_with_lax_check: bool, | ||
| ) -> PyResult<PyObject> { | ||
| ) -> PyResult<R> { | ||
| // try the serializers in left to right order with error_on fallback=true | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
| let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); | ||
|
|
||
| for comb_serializer in choices { | ||
| match comb_serializer.to_python(value, include, exclude, &new_extra) { | ||
| Ok(v) => return Ok(v), | ||
| match selector(comb_serializer, &new_extra) { | ||
| Ok(v) => return Ok(finalizer(Some(v))), | ||
| Err(err) => errors.push(err), | ||
| } | ||
| } | ||
|
|
@@ -94,8 +97,8 @@ fn to_python( | |
| if extra.check != SerCheck::Strict && retry_with_lax_check { | ||
| new_extra.check = SerCheck::Lax; | ||
| for comb_serializer in choices { | ||
| if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { | ||
| return Ok(v); | ||
| if let Ok(v) = selector(comb_serializer, &new_extra) { | ||
| return Ok(finalizer(Some(v))); | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -113,94 +116,42 @@ fn to_python( | |
| return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); | ||
| } | ||
|
|
||
| infer_to_python(value, include, exclude, extra) | ||
| Ok(finalizer(None)) | ||
| } | ||
|
|
||
| fn json_key<'a>( | ||
| key: &'a Bound<'_, PyAny>, | ||
| fn tagged_union_serialize<S>( | ||
| discriminator_value: Option<Py<PyAny>>, | ||
| lookup: &HashMap<String, usize>, | ||
| // if this returns `Ok(v)`, we picked a union variant to serialize, where | ||
| // `S` is intermediate state which can be passed on to the finalizer | ||
| mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>, | ||
| extra: &Extra, | ||
| choices: &[CombinedSerializer], | ||
| retry_with_lax_check: bool, | ||
| ) -> PyResult<Cow<'a, str>> { | ||
| ) -> Option<S> { | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
| let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); | ||
|
|
||
| for comb_serializer in choices { | ||
| match comb_serializer.json_key(key, &new_extra) { | ||
| Ok(v) => return Ok(v), | ||
| Err(err) => errors.push(err), | ||
| } | ||
| } | ||
|
|
||
| // If extra.check is SerCheck::Strict, we're in a nested union | ||
| if extra.check != SerCheck::Strict && retry_with_lax_check { | ||
| new_extra.check = SerCheck::Lax; | ||
| for comb_serializer in choices { | ||
| if let Ok(v) = comb_serializer.json_key(key, &new_extra) { | ||
| return Ok(v); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings | ||
| if extra.check == SerCheck::None { | ||
| for err in &errors { | ||
| extra.warnings.custom_warning(err.to_string()); | ||
| } | ||
| } | ||
| // Otherwise, if we've encountered errors, return them to the parent union, which should take | ||
| // care of the formatting for us | ||
| else if !errors.is_empty() { | ||
| let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n"); | ||
| return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); | ||
| } | ||
| infer_json_key(key, extra) | ||
| } | ||
|
|
||
| #[allow(clippy::too_many_arguments)] | ||
| fn serde_serialize<S: serde::ser::Serializer>( | ||
| value: &Bound<'_, PyAny>, | ||
| serializer: S, | ||
| include: Option<&Bound<'_, PyAny>>, | ||
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| choices: &[CombinedSerializer], | ||
| retry_with_lax_check: bool, | ||
| ) -> Result<S::Ok, S::Error> { | ||
| let py = value.py(); | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
| let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); | ||
|
|
||
| for comb_serializer in choices { | ||
| match comb_serializer.to_python(value, include, exclude, &new_extra) { | ||
| Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), | ||
| Err(err) => errors.push(err), | ||
| } | ||
| } | ||
|
|
||
| // If extra.check is SerCheck::Strict, we're in a nested union | ||
| if extra.check != SerCheck::Strict && retry_with_lax_check { | ||
| new_extra.check = SerCheck::Lax; | ||
| for comb_serializer in choices { | ||
| if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { | ||
| return infer_serialize(v.bind(py), serializer, None, None, extra); | ||
| if let Some(tag) = discriminator_value { | ||
| let tag_str = tag.to_string(); | ||
| if let Some(&serializer_index) = lookup.get(&tag_str) { | ||
| let selected_serializer = &choices[serializer_index]; | ||
|
|
||
| match selector(selected_serializer, &new_extra) { | ||
| Ok(v) => return Some(v), | ||
| Err(_) => { | ||
| if retry_with_lax_check { | ||
| new_extra.check = SerCheck::Lax; | ||
| if let Ok(v) = selector(selected_serializer, &new_extra) { | ||
| return Some(v); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
sydney-runkle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings | ||
| if extra.check == SerCheck::None { | ||
| for err in &errors { | ||
| extra.warnings.custom_warning(err.to_string()); | ||
| } | ||
| } else { | ||
| // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors | ||
| // will have to be returned here | ||
| } | ||
|
|
||
| infer_serialize(value, serializer, include, exclude, extra) | ||
| None | ||
| } | ||
|
|
||
| impl TypeSerializer for UnionSerializer { | ||
|
|
@@ -211,18 +162,23 @@ impl TypeSerializer for UnionSerializer { | |
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| ) -> PyResult<PyObject> { | ||
| to_python( | ||
| value, | ||
| include, | ||
| exclude, | ||
| union_serialize( | ||
| |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), | ||
| |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| )? | ||
| } | ||
|
|
||
| fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> { | ||
| json_key(key, extra, &self.choices, self.retry_with_lax_check()) | ||
| union_serialize( | ||
| |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra), | ||
| |v| v.map_or_else(|| infer_json_key(key, extra), Ok), | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| )? | ||
| } | ||
|
|
||
| fn serde_serialize<S: serde::ser::Serializer>( | ||
|
|
@@ -233,15 +189,22 @@ impl TypeSerializer for UnionSerializer { | |
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| ) -> Result<S::Ok, S::Error> { | ||
| serde_serialize( | ||
| value, | ||
| serializer, | ||
| include, | ||
| exclude, | ||
| union_serialize( | ||
| |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), | ||
| |v| { | ||
| infer_serialize( | ||
| v.as_ref().map_or(value, |v| v.bind(value.py())), | ||
| serializer, | ||
| None, | ||
| None, | ||
| extra, | ||
| ) | ||
| }, | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| .map_err(|err| serde::ser::Error::custom(err.to_string()))? | ||
| } | ||
|
|
||
| fn get_name(&self) -> &str { | ||
|
|
@@ -309,62 +272,56 @@ impl TypeSerializer for TaggedUnionSerializer { | |
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| ) -> PyResult<PyObject> { | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
|
|
||
| if let Some(tag) = self.get_discriminator_value(value, extra) { | ||
| let tag_str = tag.to_string(); | ||
| if let Some(&serializer_index) = self.lookup.get(&tag_str) { | ||
| let serializer = &self.choices[serializer_index]; | ||
|
|
||
| match serializer.to_python(value, include, exclude, &new_extra) { | ||
| Ok(v) => return Ok(v), | ||
| Err(_) => { | ||
| if self.retry_with_lax_check() { | ||
| new_extra.check = SerCheck::Lax; | ||
| if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) { | ||
| return Ok(v); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| let to_python_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| { | ||
| comb_serializer.to_python(value, include, exclude, new_extra) | ||
| }; | ||
|
|
||
| to_python( | ||
| value, | ||
| include, | ||
| exclude, | ||
| tagged_union_serialize( | ||
| self.get_discriminator_value(value, extra), | ||
| &self.lookup, | ||
| to_python_selector, | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| .map_or_else( | ||
| || { | ||
| union_serialize( | ||
| to_python_selector, | ||
| |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| )? | ||
| }, | ||
| Ok, | ||
| ) | ||
| } | ||
|
|
||
| fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> { | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
|
|
||
| if let Some(tag) = self.get_discriminator_value(key, extra) { | ||
| let tag_str = tag.to_string(); | ||
| if let Some(&serializer_index) = self.lookup.get(&tag_str) { | ||
| let serializer = &self.choices[serializer_index]; | ||
|
|
||
| match serializer.json_key(key, &new_extra) { | ||
| Ok(v) => return Ok(v), | ||
| Err(_) => { | ||
| if self.retry_with_lax_check() { | ||
| new_extra.check = SerCheck::Lax; | ||
| if let Ok(v) = serializer.json_key(key, &new_extra) { | ||
| return Ok(v); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| let json_key_selector = | ||
| |comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra); | ||
|
|
||
| json_key(key, extra, &self.choices, self.retry_with_lax_check()) | ||
| tagged_union_serialize( | ||
| self.get_discriminator_value(key, extra), | ||
| &self.lookup, | ||
| json_key_selector, | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| .map_or_else( | ||
| || { | ||
| union_serialize( | ||
| json_key_selector, | ||
| |v| v.map_or_else(|| infer_json_key(key, extra), Ok), | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| )? | ||
| }, | ||
| Ok, | ||
| ) | ||
| } | ||
|
|
||
| fn serde_serialize<S: serde::ser::Serializer>( | ||
|
|
@@ -375,38 +332,37 @@ impl TypeSerializer for TaggedUnionSerializer { | |
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| ) -> Result<S::Ok, S::Error> { | ||
| let py = value.py(); | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
|
|
||
| if let Some(tag) = self.get_discriminator_value(value, extra) { | ||
| let tag_str = tag.to_string(); | ||
| if let Some(&serializer_index) = self.lookup.get(&tag_str) { | ||
| let selected_serializer = &self.choices[serializer_index]; | ||
|
|
||
| match selected_serializer.to_python(value, include, exclude, &new_extra) { | ||
| Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), | ||
| Err(_) => { | ||
| if self.retry_with_lax_check() { | ||
| new_extra.check = SerCheck::Lax; | ||
| if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) { | ||
| return infer_serialize(v.bind(py), serializer, None, None, extra); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| let serde_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| { | ||
| comb_serializer.to_python(value, include, exclude, new_extra) | ||
| }; | ||
|
|
||
| if let Some(v) = tagged_union_serialize( | ||
| None, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @sydney-runkle this is the source of the perf regression: the tagged union serialization optimization was accidentally switched off in the JSON case. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 😬 Oops. Great find! |
||
| &self.lookup, | ||
| serde_selector, | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) { | ||
| return infer_serialize(v.bind(value.py()), serializer, None, None, extra); | ||
| } | ||
|
|
||
| serde_serialize( | ||
| value, | ||
| serializer, | ||
| include, | ||
| exclude, | ||
| union_serialize( | ||
| serde_selector, | ||
| |v| { | ||
| infer_serialize( | ||
| v.as_ref().map_or(value, |v| v.bind(value.py())), | ||
| serializer, | ||
| None, | ||
| None, | ||
| extra, | ||
| ) | ||
| }, | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| .map_err(|err| serde::ser::Error::custom(err.to_string()))? | ||
| } | ||
|
|
||
| fn get_name(&self) -> &str { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.