Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 84 additions & 177 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use std::borrow::Cow;
use crate::build_tools::py_schema_err;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::definitions::DefinitionsBuilder;
use crate::serializers::PydanticSerializationUnexpectedValue;
use crate::tools::{truncate_safe_repr, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
Expand Down Expand Up @@ -70,22 +70,23 @@ impl UnionSerializer {

impl_py_gc_traverse!(UnionSerializer { choices });

fn to_python(
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
fn union_serialize<S>(
// if this returns `Ok(Some(v))`, we picked a union variant to serialize,
// Or `Ok(None)` if we couldn't find a suitable variant to serialize
// Finally, `Err(err)` if we encountered errors while trying to serialize
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<PyObject> {
) -> PyResult<Option<S>> {
// try the serializers in left to right order with error_on fallback=true
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
match selector(comb_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(err) => errors.push(err),
}
}
Expand All @@ -94,8 +95,8 @@ fn to_python(
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
return Ok(v);
if let Ok(v) = selector(comb_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
Expand All @@ -113,94 +114,45 @@ fn to_python(
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}

infer_to_python(value, include, exclude, extra)
Ok(None)
}

fn json_key<'a>(
key: &'a Bound<'_, PyAny>,
fn tagged_union_serialize<S>(
discriminator_value: Option<Py<PyAny>>,
lookup: &HashMap<String, usize>,
// if this returns `Ok(v)`, we picked a union variant to serialize, where
// `S` is intermediate state which can be passed on to the finalizer
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<Cow<'a, str>> {
) -> PyResult<Option<S>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.json_key(key, &new_extra) {
Ok(v) => return Ok(v),
Err(err) => errors.push(err),
}
}

// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
return Ok(v);
}
}
}

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
}
// Otherwise, if we've encountered errors, return them to the parent union, which should take
// care of the formatting for us
else if !errors.is_empty() {
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}
infer_json_key(key, extra)
}

#[allow(clippy::too_many_arguments)]
fn serde_serialize<S: serde::ser::Serializer>(
value: &Bound<'_, PyAny>,
serializer: S,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> Result<S::Ok, S::Error> {
let py = value.py();
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
Err(err) => errors.push(err),
}
}

// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
return infer_serialize(v.bind(py), serializer, None, None, extra);
if let Some(tag) = discriminator_value {
let tag_str = tag.to_string();
if let Some(&serializer_index) = lookup.get(&tag_str) {
let selected_serializer = &choices[serializer_index];

match selector(selected_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(_) => {
if retry_with_lax_check {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selector(selected_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
}
}
}

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
} else {
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
// will have to be returned here
}

infer_serialize(value, serializer, include, exclude, extra)
// if we haven't returned at this point, we should fallback to the union serializer
// which preserves the historical expectation that we do our best with serialization
// even if that means we resort to inference
union_serialize(selector, extra, choices, retry_with_lax_check)
}

impl TypeSerializer for UnionSerializer {
Expand All @@ -211,18 +163,23 @@ impl TypeSerializer for UnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
to_python(
value,
include,
exclude,
union_serialize(
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
json_key(key, extra, &self.choices, self.retry_with_lax_check())
union_serialize(
|comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -233,15 +190,16 @@ impl TypeSerializer for UnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
serde_serialize(
value,
serializer,
include,
exclude,
match union_serialize(
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
}
}

fn get_name(&self) -> &str {
Expand Down Expand Up @@ -309,62 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(value, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let serializer = &self.choices[serializer_index];

match serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) {
return Ok(v);
}
}
}
}
}
}

to_python(
value,
include,
exclude,
tagged_union_serialize(
self.get_discriminator_value(value, extra),
&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(key, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let serializer = &self.choices[serializer_index];

match serializer.json_key(key, &new_extra) {
Ok(v) => return Ok(v),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = serializer.json_key(key, &new_extra) {
return Ok(v);
}
}
}
}
}
}

json_key(key, extra, &self.choices, self.retry_with_lax_check())
tagged_union_serialize(
self.get_discriminator_value(key, extra),
&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -375,38 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let py = value.py();
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(value, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let selected_serializer = &self.choices[serializer_index];

match selected_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) {
return infer_serialize(v.bind(py), serializer, None, None, extra);
}
}
}
}
}
}

serde_serialize(
value,
serializer,
include,
exclude,
match tagged_union_serialize(
None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sydney-runkle this is the source of the perf regression; accidentally switched off tagged union serialization optimization in the JSON case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😬 oops. Great find!

&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
}
}

fn get_name(&self) -> &str {
Expand Down
Loading