-
Notifications
You must be signed in to change notification settings - Fork 312
Simplify shared union serializer logic #1538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f037904
0ceac1b
a5fbc2a
742a3e9
92f5bbc
57c3c6f
7ea0d64
f3a304a
ee9dd3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,8 +8,8 @@ use std::borrow::Cow; | |
| use crate::build_tools::py_schema_err; | ||
| use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; | ||
| use crate::definitions::DefinitionsBuilder; | ||
| use crate::serializers::PydanticSerializationUnexpectedValue; | ||
| use crate::tools::{truncate_safe_repr, SchemaDict}; | ||
| use crate::PydanticSerializationUnexpectedValue; | ||
|
|
||
| use super::{ | ||
| infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck, | ||
|
|
@@ -70,22 +70,23 @@ impl UnionSerializer { | |
|
|
||
| impl_py_gc_traverse!(UnionSerializer { choices }); | ||
|
|
||
| fn to_python( | ||
| value: &Bound<'_, PyAny>, | ||
| include: Option<&Bound<'_, PyAny>>, | ||
| exclude: Option<&Bound<'_, PyAny>>, | ||
| fn union_serialize<S>( | ||
| // if this returns `Ok(Some(v))`, we picked a union variant to serialize, | ||
| // Or `Ok(None)` if we couldn't find a suitable variant to serialize | ||
| // Finally, `Err(err)` if we encountered errors while trying to serialize | ||
| mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>, | ||
| extra: &Extra, | ||
| choices: &[CombinedSerializer], | ||
| retry_with_lax_check: bool, | ||
| ) -> PyResult<PyObject> { | ||
| ) -> PyResult<Option<S>> { | ||
| // try the serializers in left to right order with error_on fallback=true | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
| let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); | ||
|
|
||
| for comb_serializer in choices { | ||
| match comb_serializer.to_python(value, include, exclude, &new_extra) { | ||
| Ok(v) => return Ok(v), | ||
| match selector(comb_serializer, &new_extra) { | ||
| Ok(v) => return Ok(Some(v)), | ||
| Err(err) => errors.push(err), | ||
| } | ||
| } | ||
|
|
@@ -94,8 +95,8 @@ fn to_python( | |
| if extra.check != SerCheck::Strict && retry_with_lax_check { | ||
| new_extra.check = SerCheck::Lax; | ||
| for comb_serializer in choices { | ||
| if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { | ||
| return Ok(v); | ||
| if let Ok(v) = selector(comb_serializer, &new_extra) { | ||
| return Ok(Some(v)); | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -113,94 +114,45 @@ fn to_python( | |
| return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); | ||
| } | ||
|
|
||
| infer_to_python(value, include, exclude, extra) | ||
| Ok(None) | ||
| } | ||
|
|
||
| fn json_key<'a>( | ||
| key: &'a Bound<'_, PyAny>, | ||
| fn tagged_union_serialize<S>( | ||
| discriminator_value: Option<Py<PyAny>>, | ||
| lookup: &HashMap<String, usize>, | ||
| // if this returns `Ok(v)`, we picked a union variant to serialize, where | ||
| // `S` is intermediate state which can be passed on to the finalizer | ||
| mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>, | ||
| extra: &Extra, | ||
| choices: &[CombinedSerializer], | ||
| retry_with_lax_check: bool, | ||
| ) -> PyResult<Cow<'a, str>> { | ||
| ) -> PyResult<Option<S>> { | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
| let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); | ||
|
|
||
| for comb_serializer in choices { | ||
| match comb_serializer.json_key(key, &new_extra) { | ||
| Ok(v) => return Ok(v), | ||
| Err(err) => errors.push(err), | ||
| } | ||
| } | ||
|
|
||
| // If extra.check is SerCheck::Strict, we're in a nested union | ||
| if extra.check != SerCheck::Strict && retry_with_lax_check { | ||
| new_extra.check = SerCheck::Lax; | ||
| for comb_serializer in choices { | ||
| if let Ok(v) = comb_serializer.json_key(key, &new_extra) { | ||
| return Ok(v); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings | ||
| if extra.check == SerCheck::None { | ||
| for err in &errors { | ||
| extra.warnings.custom_warning(err.to_string()); | ||
| } | ||
| } | ||
| // Otherwise, if we've encountered errors, return them to the parent union, which should take | ||
| // care of the formatting for us | ||
| else if !errors.is_empty() { | ||
| let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n"); | ||
| return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); | ||
| } | ||
| infer_json_key(key, extra) | ||
| } | ||
|
|
||
| #[allow(clippy::too_many_arguments)] | ||
| fn serde_serialize<S: serde::ser::Serializer>( | ||
| value: &Bound<'_, PyAny>, | ||
| serializer: S, | ||
| include: Option<&Bound<'_, PyAny>>, | ||
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| choices: &[CombinedSerializer], | ||
| retry_with_lax_check: bool, | ||
| ) -> Result<S::Ok, S::Error> { | ||
| let py = value.py(); | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
| let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); | ||
|
|
||
| for comb_serializer in choices { | ||
| match comb_serializer.to_python(value, include, exclude, &new_extra) { | ||
| Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), | ||
| Err(err) => errors.push(err), | ||
| } | ||
| } | ||
|
|
||
| // If extra.check is SerCheck::Strict, we're in a nested union | ||
| if extra.check != SerCheck::Strict && retry_with_lax_check { | ||
| new_extra.check = SerCheck::Lax; | ||
| for comb_serializer in choices { | ||
| if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { | ||
| return infer_serialize(v.bind(py), serializer, None, None, extra); | ||
| if let Some(tag) = discriminator_value { | ||
| let tag_str = tag.to_string(); | ||
| if let Some(&serializer_index) = lookup.get(&tag_str) { | ||
| let selected_serializer = &choices[serializer_index]; | ||
|
|
||
| match selector(selected_serializer, &new_extra) { | ||
| Ok(v) => return Ok(Some(v)), | ||
| Err(_) => { | ||
| if retry_with_lax_check { | ||
| new_extra.check = SerCheck::Lax; | ||
| if let Ok(v) = selector(selected_serializer, &new_extra) { | ||
| return Ok(Some(v)); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings | ||
| if extra.check == SerCheck::None { | ||
| for err in &errors { | ||
| extra.warnings.custom_warning(err.to_string()); | ||
| } | ||
| } else { | ||
| // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors | ||
| // will have to be returned here | ||
| } | ||
|
|
||
| infer_serialize(value, serializer, include, exclude, extra) | ||
| // if we haven't returned at this point, we should fallback to the union serializer | ||
| // which preserves the historical expectation that we do our best with serialization | ||
| // even if that means we resort to inference | ||
| union_serialize(selector, extra, choices, retry_with_lax_check) | ||
| } | ||
|
|
||
| impl TypeSerializer for UnionSerializer { | ||
|
|
@@ -211,18 +163,23 @@ impl TypeSerializer for UnionSerializer { | |
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| ) -> PyResult<PyObject> { | ||
| to_python( | ||
| value, | ||
| include, | ||
| exclude, | ||
| union_serialize( | ||
| |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| )? | ||
| .map_or_else(|| infer_to_python(value, include, exclude, extra), Ok) | ||
| } | ||
|
|
||
| fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> { | ||
| json_key(key, extra, &self.choices, self.retry_with_lax_check()) | ||
| union_serialize( | ||
| |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra), | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| )? | ||
| .map_or_else(|| infer_json_key(key, extra), Ok) | ||
| } | ||
|
|
||
| fn serde_serialize<S: serde::ser::Serializer>( | ||
|
|
@@ -233,15 +190,16 @@ impl TypeSerializer for UnionSerializer { | |
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| ) -> Result<S::Ok, S::Error> { | ||
| serde_serialize( | ||
| value, | ||
| serializer, | ||
| include, | ||
| exclude, | ||
| match union_serialize( | ||
| |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| ) { | ||
| Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra), | ||
| Ok(None) => infer_serialize(value, serializer, include, exclude, extra), | ||
| Err(err) => Err(serde::ser::Error::custom(err.to_string())), | ||
| } | ||
| } | ||
|
|
||
| fn get_name(&self) -> &str { | ||
|
|
@@ -309,62 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer { | |
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| ) -> PyResult<PyObject> { | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
|
|
||
| if let Some(tag) = self.get_discriminator_value(value, extra) { | ||
| let tag_str = tag.to_string(); | ||
| if let Some(&serializer_index) = self.lookup.get(&tag_str) { | ||
| let serializer = &self.choices[serializer_index]; | ||
|
|
||
| match serializer.to_python(value, include, exclude, &new_extra) { | ||
| Ok(v) => return Ok(v), | ||
| Err(_) => { | ||
| if self.retry_with_lax_check() { | ||
| new_extra.check = SerCheck::Lax; | ||
| if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) { | ||
| return Ok(v); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| to_python( | ||
| value, | ||
| include, | ||
| exclude, | ||
| tagged_union_serialize( | ||
| self.get_discriminator_value(value, extra), | ||
| &self.lookup, | ||
| |comb_serializer: &CombinedSerializer, new_extra: &Extra| { | ||
| comb_serializer.to_python(value, include, exclude, new_extra) | ||
| }, | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| )? | ||
| .map_or_else(|| infer_to_python(value, include, exclude, extra), Ok) | ||
| } | ||
|
|
||
| fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> { | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
|
|
||
| if let Some(tag) = self.get_discriminator_value(key, extra) { | ||
| let tag_str = tag.to_string(); | ||
| if let Some(&serializer_index) = self.lookup.get(&tag_str) { | ||
| let serializer = &self.choices[serializer_index]; | ||
|
|
||
| match serializer.json_key(key, &new_extra) { | ||
| Ok(v) => return Ok(v), | ||
| Err(_) => { | ||
| if self.retry_with_lax_check() { | ||
| new_extra.check = SerCheck::Lax; | ||
| if let Ok(v) = serializer.json_key(key, &new_extra) { | ||
| return Ok(v); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| json_key(key, extra, &self.choices, self.retry_with_lax_check()) | ||
| tagged_union_serialize( | ||
| self.get_discriminator_value(key, extra), | ||
| &self.lookup, | ||
| |comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra), | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| )? | ||
| .map_or_else(|| infer_json_key(key, extra), Ok) | ||
| } | ||
|
|
||
| fn serde_serialize<S: serde::ser::Serializer>( | ||
|
|
@@ -375,38 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer { | |
| exclude: Option<&Bound<'_, PyAny>>, | ||
| extra: &Extra, | ||
| ) -> Result<S::Ok, S::Error> { | ||
| let py = value.py(); | ||
| let mut new_extra = extra.clone(); | ||
| new_extra.check = SerCheck::Strict; | ||
|
|
||
| if let Some(tag) = self.get_discriminator_value(value, extra) { | ||
| let tag_str = tag.to_string(); | ||
| if let Some(&serializer_index) = self.lookup.get(&tag_str) { | ||
| let selected_serializer = &self.choices[serializer_index]; | ||
|
|
||
| match selected_serializer.to_python(value, include, exclude, &new_extra) { | ||
| Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), | ||
| Err(_) => { | ||
| if self.retry_with_lax_check() { | ||
| new_extra.check = SerCheck::Lax; | ||
| if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) { | ||
| return infer_serialize(v.bind(py), serializer, None, None, extra); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| serde_serialize( | ||
| value, | ||
| serializer, | ||
| include, | ||
| exclude, | ||
| match tagged_union_serialize( | ||
| None, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sydney-runkle this is the source of the perf regression; accidentally switched off tagged union serialization optimization in the JSON case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😬 oops. Great find! |
||
| &self.lookup, | ||
| |comb_serializer: &CombinedSerializer, new_extra: &Extra| { | ||
| comb_serializer.to_python(value, include, exclude, new_extra) | ||
| }, | ||
| extra, | ||
| &self.choices, | ||
| self.retry_with_lax_check(), | ||
| ) | ||
| ) { | ||
| Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra), | ||
| Ok(None) => infer_serialize(value, serializer, include, exclude, extra), | ||
| Err(err) => Err(serde::ser::Error::custom(err.to_string())), | ||
| } | ||
| } | ||
|
|
||
| fn get_name(&self) -> &str { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.