diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 16e89eb68..aa07ef7b7 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -18,7 +18,7 @@ use crate::build_tools::py_schema_error_type; use crate::errors::LocItem; use crate::get_pydantic_version; use crate::input::InputType; -use crate::serializers::{DuckTypingSerMode, Extra, SerMode, SerializationState}; +use crate::serializers::{Extra, SerMode, SerializationState}; use crate::tools::{safe_repr, write_truncated_to_limited_bytes, SchemaDict}; use super::line_error::ValLineError; @@ -341,17 +341,7 @@ impl ValidationError { include_input: bool, ) -> PyResult> { let state = SerializationState::new("iso8601", "utf8", "constants")?; - let extra = state.extra( - py, - &SerMode::Json, - None, - false, - false, - true, - None, - DuckTypingSerMode::SchemaBased, - None, - ); + let extra = state.extra(py, &SerMode::Json, None, false, false, true, None, false, None); let serializer = ValidationErrorSerializer { py, line_errors: &self.line_errors, diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 6e6786d73..7a574093c 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -2,13 +2,12 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; use serde::ser::SerializeMap; -use serde::Serialize; use crate::build_tools::py_schema_error_type; use crate::definitions::DefinitionsBuilder; use crate::py_gc::PyGcTraverse; use crate::serializers::filter::SchemaFilter; -use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer}; +use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSerializer}; use crate::tools::SchemaDict; use super::errors::py_err_se_err; @@ -48,18 +47,31 @@ impl ComputedFields { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> PyResult<()> { - if extra.round_trip { - // Do not serialize computed fields - return Ok(()); - } - for computed_field in &self.0 { - let field_extra = Extra { - field_name: Some(computed_field.property_name.as_str()), - ..*extra - }; - computed_field.to_python(model, output_dict, filter, include, exclude, &field_extra)?; - } - Ok(()) + self.serialize_fields( + model, + filter, + include, + exclude, + extra, + |e| e, + |ComputedFieldToSerialize { + computed_field, + value, + include, + exclude, + field_extra, + }| { + let key = match field_extra.serialize_by_alias_or(computed_field.serialize_by_alias) { + true => computed_field.alias_py.bind(model.py()), + false => computed_field.property_name_py.bind(model.py()), + }; + let value = + computed_field + .serializer + .to_python(&value, include.as_ref(), exclude.as_ref(), &field_extra)?; + output_dict.set_item(key, value) + }, + ) } pub fn serde_serialize( @@ -71,6 +83,49 @@ impl ComputedFields { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result<(), S::Error> { + self.serialize_fields( + model, + filter, + include, + exclude, + extra, + py_err_se_err, + |ComputedFieldToSerialize { + computed_field, + value, + include, + exclude, + field_extra, + }| { + let key = match field_extra.serialize_by_alias_or(computed_field.serialize_by_alias) { + true => &computed_field.alias, + false => &computed_field.property_name, + }; + let s = PydanticSerializer::new( + &value, + &computed_field.serializer, + include.as_ref(), + exclude.as_ref(), + &field_extra, + ); + map.serialize_entry(key, &s) + }, + ) + } + + /// Iterate each field for serialization, filtering on + /// `include` and `exclude` if provided. + #[allow(clippy::too_many_arguments)] + fn serialize_fields<'a, 'py, E>( + &'a self, + model: &'a Bound<'py, PyAny>, + filter: &'a SchemaFilter, + include: Option<&'a Bound<'py, PyAny>>, + exclude: Option<&'a Bound<'py, PyAny>>, + extra: &'a Extra, + convert_error: impl FnOnce(PyErr) -> E, + mut serialize: impl FnMut(ComputedFieldToSerialize<'a, 'py>) -> Result<(), E>, + ) -> Result<(), E> { if extra.round_trip { // Do not serialize computed fields return Ok(()); @@ -78,37 +133,46 @@ impl ComputedFields { for computed_field in &self.0 { let property_name_py = computed_field.property_name_py.bind(model.py()); + let (next_include, next_exclude) = match filter.key_filter(property_name_py, include, exclude) { + Ok(Some((next_include, next_exclude))) => (next_include, next_exclude), + Ok(None) => continue, + Err(e) => return Err(convert_error(e)), + }; - if let Some((next_include, next_exclude)) = filter - .key_filter(property_name_py, include, exclude) - .map_err(py_err_se_err)? - { - let value = model.getattr(property_name_py).map_err(py_err_se_err)?; - if extra.exclude_none && value.is_none() { - continue; + let value = match model.getattr(property_name_py) { + Ok(field_value) => field_value, + Err(e) => { + return Err(convert_error(e)); } - let field_extra = Extra { - field_name: Some(computed_field.property_name.as_str()), - ..*extra - }; - let cfs = ComputedFieldSerializer { - model, - computed_field, - include: next_include.as_ref(), - exclude: next_exclude.as_ref(), - extra: &field_extra, - }; - let key = match extra.serialize_by_alias_or(computed_field.serialize_by_alias) { - true => computed_field.alias.as_str(), - false => computed_field.property_name.as_str(), - }; - map.serialize_entry(key, &cfs)?; + }; + if extra.exclude_none && value.is_none() { + continue; } + + let field_extra = Extra { + field_name: Some(&computed_field.property_name), + ..*extra + }; + serialize(ComputedFieldToSerialize { + computed_field, + value, + include: next_include, + exclude: next_exclude, + field_extra, + })?; } Ok(()) } } +struct ComputedFieldToSerialize<'a, 'py> { + computed_field: &'a ComputedField, + value: Bound<'py, PyAny>, + include: Option>, + exclude: Option>, + field_extra: Extra<'a>, +} + #[derive(Debug)] struct ComputedField { property_name: String, @@ -143,44 +207,6 @@ impl ComputedField { serialize_by_alias: config.get_as(intern!(py, "serialize_by_alias"))?, }) } - - fn to_python( - &self, - model: &Bound<'_, PyAny>, - output_dict: &Bound<'_, PyDict>, - filter: &SchemaFilter, - include: Option<&Bound<'_, PyAny>>, - exclude: Option<&Bound<'_, PyAny>>, - extra: &Extra, - ) -> PyResult<()> { - let py = model.py(); - let property_name_py = self.property_name_py.bind(py); - - if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? { - let next_value = model.getattr(property_name_py)?; - - let value = self - .serializer - .to_python(&next_value, next_include.as_ref(), next_exclude.as_ref(), extra)?; - if extra.exclude_none && value.is_none(py) { - return Ok(()); - } - let key = match extra.serialize_by_alias_or(self.serialize_by_alias) { - true => self.alias_py.bind(py), - false => property_name_py, - }; - output_dict.set_item(key, value)?; - } - Ok(()) - } -} - -pub(crate) struct ComputedFieldSerializer<'py> { - model: &'py Bound<'py, PyAny>, - computed_field: &'py ComputedField, - include: Option<&'py Bound<'py, PyAny>>, - exclude: Option<&'py Bound<'py, PyAny>>, - extra: &'py Extra<'py>, } impl_py_gc_traverse!(ComputedField { serializer }); @@ -190,21 +216,3 @@ impl PyGcTraverse for ComputedFields { self.0.py_gc_traverse(visit) } } - -impl_py_gc_traverse!(ComputedFieldSerializer<'_> { computed_field }); - -impl Serialize for ComputedFieldSerializer<'_> { - fn serialize(&self, serializer: S) -> Result { - let py = self.model.py(); - let property_name_py = self.computed_field.property_name_py.bind(py); - let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?; - let s = PydanticSerializer::new( - &next_value, - &self.computed_field.serializer, - self.include, - self.exclude, - self.extra, - ); - s.serialize(serializer) - } -} diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 82c432342..e919beb8c 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -27,45 +27,6 @@ pub(crate) struct SerializationState { config: SerializationConfig, } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum DuckTypingSerMode { - // Don't check the type of the value, use the type of the schema - SchemaBased, - // Check the type of the value, use the type of the value - NeedsInference, - // We already checked the type of the value - // we don't want to infer again, but if we recurse down - // we do want to flip this back to NeedsInference for the - // fields / keys / items of any inner serializers - Inferred, -} - -impl DuckTypingSerMode { - pub fn from_bool(serialize_as_any: bool) -> Self { - if serialize_as_any { - DuckTypingSerMode::NeedsInference - } else { - DuckTypingSerMode::SchemaBased - } - } - - pub fn to_bool(self) -> bool { - match self { - DuckTypingSerMode::SchemaBased => false, - DuckTypingSerMode::NeedsInference => true, - DuckTypingSerMode::Inferred => true, - } - } - - pub fn next_mode(self) -> Self { - match self { - DuckTypingSerMode::SchemaBased => DuckTypingSerMode::SchemaBased, - DuckTypingSerMode::NeedsInference => DuckTypingSerMode::Inferred, - DuckTypingSerMode::Inferred => DuckTypingSerMode::NeedsInference, - } - } -} - impl SerializationState { pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult { let warnings = CollectWarnings::new(WarningsMode::None); @@ -88,7 +49,7 @@ impl SerializationState { round_trip: bool, serialize_unknown: bool, fallback: Option<&'py Bound<'_, PyAny>>, - duck_typing_ser_mode: DuckTypingSerMode, + serialize_as_any: bool, context: Option<&'py Bound<'_, PyAny>>, ) -> Extra<'py> { Extra::new( @@ -104,7 +65,7 @@ impl SerializationState { &self.rec_guard, serialize_unknown, fallback, - duck_typing_ser_mode, + serialize_as_any, context, ) } @@ -137,7 +98,7 @@ pub(crate) struct Extra<'a> { pub field_name: Option<&'a str>, pub serialize_unknown: bool, pub fallback: Option<&'a Bound<'a, PyAny>>, - pub duck_typing_ser_mode: DuckTypingSerMode, + pub serialize_as_any: bool, pub context: Option<&'a Bound<'a, PyAny>>, } @@ -156,7 +117,7 @@ impl<'a> Extra<'a> { rec_guard: &'a SerRecursionState, serialize_unknown: bool, fallback: Option<&'a Bound<'a, PyAny>>, - duck_typing_ser_mode: DuckTypingSerMode, + serialize_as_any: bool, context: Option<&'a Bound<'a, PyAny>>, ) -> Self { Self { @@ -175,7 +136,7 @@ impl<'a> Extra<'a> { field_name: None, serialize_unknown, fallback, - duck_typing_ser_mode, + serialize_as_any, context, } } @@ -243,7 +204,7 @@ pub(crate) struct ExtraOwned { field_name: Option, serialize_unknown: bool, pub fallback: Option, - duck_typing_ser_mode: DuckTypingSerMode, + serialize_as_any: bool, pub context: Option, } @@ -264,7 +225,7 @@ impl ExtraOwned { field_name: extra.field_name.map(ToString::to_string), serialize_unknown: extra.serialize_unknown, fallback: extra.fallback.map(|model| model.clone().into()), - duck_typing_ser_mode: extra.duck_typing_ser_mode, + serialize_as_any: extra.serialize_as_any, context: extra.context.map(|model| model.clone().into()), } } @@ -286,7 +247,7 @@ impl ExtraOwned { field_name: self.field_name.as_deref(), serialize_unknown: self.serialize_unknown, fallback: self.fallback.as_ref().map(|m| m.bind(py)), - duck_typing_ser_mode: self.duck_typing_ser_mode, + serialize_as_any: self.serialize_as_any, context: self.context.as_ref().map(|m| m.bind(py)), } } diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index d4bc8bb67..a5c5bc6b3 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -8,7 +8,6 @@ use serde::ser::SerializeMap; use smallvec::SmallVec; use crate::serializers::extra::SerCheck; -use crate::serializers::DuckTypingSerMode; use crate::PydanticSerializationUnexpectedValue; use super::computed_fields::ComputedFields; @@ -191,7 +190,7 @@ impl GeneralFieldsSerializer { Some(serializer) => { serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)? } - None => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?, + _ => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?, }; output_dict.set_item(key, value)?; } else if field_extra.check == SerCheck::Strict { @@ -332,20 +331,6 @@ impl TypeSerializer for GeneralFieldsSerializer { // If there is no model, we (a TypedDict) are the model let model = extra.model.map_or_else(|| Some(value), Some); - // If there is no model, use duck typing ser logic for TypedDict - // If there is a model, skip this step, as BaseModel and dataclass duck typing - // is handled in their respective serializers - if extra.model.is_none() { - let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode(); - let td_extra = Extra { - model, - duck_typing_ser_mode, - ..*extra - }; - if td_extra.duck_typing_ser_mode == DuckTypingSerMode::Inferred { - return infer_to_python(value, include, exclude, &td_extra); - } - } let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) { main_extra_dict } else { @@ -367,7 +352,7 @@ impl TypeSerializer for GeneralFieldsSerializer { Some(serializer) => { serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), extra)? } - None => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), extra)?, + _ => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), extra)?, }; output_dict.set_item(key, value)?; } @@ -401,20 +386,6 @@ impl TypeSerializer for GeneralFieldsSerializer { // If there is no model, we (a TypedDict) are the model let model = extra.model.map_or_else(|| Some(value), Some); - // If there is no model, use duck typing ser logic for TypedDict - // If there is a model, skip this step, as BaseModel and dataclass duck typing - // is handled in their respective serializers - if extra.model.is_none() { - let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode(); - let td_extra = Extra { - model, - duck_typing_ser_mode, - ..*extra - }; - if td_extra.duck_typing_ser_mode == DuckTypingSerMode::Inferred { - return infer_serialize(value, serializer, include, exclude, &td_extra); - } - } let expected_len = match self.mode { FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(), _ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(), diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 70910aeb1..eb067b81d 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -21,7 +21,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError}; use super::extra::{Extra, SerMode}; use super::filter::{AnyFilter, SchemaFilter}; use super::ob_type::ObType; -use super::shared::{any_dataclass_iter, PydanticSerializer, TypeSerializer}; +use super::shared::any_dataclass_iter; use super::SchemaSerializer; pub(crate) fn infer_to_python( @@ -106,10 +106,14 @@ pub(crate) fn infer_to_python_known( extra.rec_guard, extra.serialize_unknown, extra.fallback, - extra.duck_typing_ser_mode, + extra.serialize_as_any, extra.context, ); - serializer.serializer.to_python(value, include, exclude, &extra) + // Avoid falling immediately back into inference because we need to use the serializer + // to drive the next step of serialization + serializer + .serializer + .to_python_no_infer(value, include, exclude, &extra) }; let value = match extra.mode { @@ -489,6 +493,7 @@ pub(crate) fn infer_serialize_known( .getattr(intern!(py, "__pydantic_serializer__")) .map_err(py_err_se_err)?; let extracted_serializer: PyRef = py_serializer.extract().map_err(py_err_se_err)?; + let extra = extracted_serializer.build_extra( py, extra.mode, @@ -501,12 +506,14 @@ pub(crate) fn infer_serialize_known( extra.rec_guard, extra.serialize_unknown, extra.fallback, - extra.duck_typing_ser_mode, + extra.serialize_as_any, extra.context, ); - let pydantic_serializer = - PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); - pydantic_serializer.serialize(serializer) + // Avoid falling immediately back into inference because we need to use the serializer + // to drive the next step of serialization + extracted_serializer + .serializer + .serde_serialize_no_infer(value, serializer, include, exclude, &extra) } ObType::Dataclass => { let (pairs_iter, fields_dict) = any_dataclass_iter(value).map_err(py_err_se_err)?; diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 5516f81e1..0da5f188f 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -4,6 +4,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyTuple, PyType}; use pyo3::{PyTraverseError, PyVisit}; +use type_serializers::any::AnySerializer; use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::py_gc::PyGcTraverse; @@ -12,9 +13,9 @@ pub(crate) use config::BytesMode; use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; use extra::{CollectWarnings, SerRecursionState, WarningsMode}; -pub(crate) use extra::{DuckTypingSerMode, Extra, SerMode, SerializationState}; +pub(crate) use extra::{Extra, SerMode, SerializationState}; +use shared::to_json_bytes; pub use shared::CombinedSerializer; -use shared::{to_json_bytes, TypeSerializer}; mod computed_fields; mod config; @@ -63,7 +64,7 @@ impl SchemaSerializer { rec_guard: &'a SerRecursionState, serialize_unknown: bool, fallback: Option<&'a Bound<'a, PyAny>>, - duck_typing_ser_mode: DuckTypingSerMode, + serialize_as_any: bool, context: Option<&'a Bound<'a, PyAny>>, ) -> Extra<'b> { Extra::new( @@ -79,7 +80,7 @@ impl SchemaSerializer { rec_guard, serialize_unknown, fallback, - duck_typing_ser_mode, + serialize_as_any, context, ) } @@ -133,7 +134,6 @@ impl SchemaSerializer { }; let warnings = CollectWarnings::new(warnings_mode); let rec_guard = SerRecursionState::default(); - let duck_typing_ser_mode = DuckTypingSerMode::from_bool(serialize_as_any); let extra = self.build_extra( py, &mode, @@ -146,7 +146,7 @@ impl SchemaSerializer { &rec_guard, false, fallback, - duck_typing_ser_mode, + serialize_as_any, context, ); let v = self.serializer.to_python(value, include, exclude, &extra)?; @@ -181,7 +181,6 @@ impl SchemaSerializer { }; let warnings = CollectWarnings::new(warnings_mode); let rec_guard = SerRecursionState::default(); - let duck_typing_ser_mode = DuckTypingSerMode::from_bool(serialize_as_any); let extra = self.build_extra( py, &SerMode::Json, @@ -194,7 +193,7 @@ impl SchemaSerializer { &rec_guard, false, fallback, - duck_typing_ser_mode, + serialize_as_any, context, ); let bytes = to_json_bytes( @@ -261,7 +260,6 @@ pub fn to_json( context: Option<&Bound<'_, PyAny>>, ) -> PyResult { let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?; - let duck_typing_ser_mode = DuckTypingSerMode::from_bool(serialize_as_any); let extra = state.extra( py, &SerMode::Json, @@ -270,11 +268,10 @@ pub fn to_json( round_trip, serialize_unknown, fallback, - duck_typing_ser_mode, + serialize_as_any, context, ); - let serializer = type_serializers::any::AnySerializer.into(); - let bytes = to_json_bytes(value, &serializer, include, exclude, &extra, indent, 1024)?; + let bytes = to_json_bytes(value, AnySerializer::get(), include, exclude, &extra, indent, 1024)?; state.final_check(py)?; let py_bytes = PyBytes::new(py, &bytes); Ok(py_bytes.into()) @@ -302,7 +299,6 @@ pub fn to_jsonable_python( context: Option<&Bound<'_, PyAny>>, ) -> PyResult { let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?; - let duck_typing_ser_mode = DuckTypingSerMode::from_bool(serialize_as_any); let extra = state.extra( py, &SerMode::Json, @@ -311,7 +307,7 @@ pub fn to_jsonable_python( round_trip, serialize_unknown, fallback, - duck_typing_ser_mode, + serialize_as_any, context, ); let v = infer::infer_to_python(value, include, exclude, &extra)?; diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs index 3f1cf1c68..323dbe2e9 100644 --- a/src/serializers/prebuilt.rs +++ b/src/serializers/prebuilt.rs @@ -39,11 +39,11 @@ impl TypeSerializer for PrebuiltSerializer { self.schema_serializer .get() .serializer - .to_python(value, include, exclude, extra) + .to_python_no_infer(value, include, exclude, extra) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - self.schema_serializer.get().serializer.json_key(key, extra) + self.schema_serializer.get().serializer.json_key_no_infer(key, extra) } fn serde_serialize( @@ -57,7 +57,7 @@ impl TypeSerializer for PrebuiltSerializer { self.schema_serializer .get() .serializer - .serde_serialize(value, serializer, include, exclude, extra) + .serde_serialize_no_infer(value, serializer, include, exclude, extra) } fn get_name(&self) -> &str { diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index e28ae9cee..1255e1128 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -20,7 +20,7 @@ use crate::tools::{py_err, SchemaDict}; use super::errors::se_err_py_err; use super::extra::Extra; -use super::infer::infer_json_key; +use super::infer::{infer_json_key, infer_serialize, infer_to_python}; use super::ob_type::{IsType, ObType}; pub(crate) trait BuildSerializer: Sized { @@ -221,6 +221,74 @@ impl CombinedSerializer { Self::find_serializer(type_, schema, config, definitions) } + + /// Main recursive way to call serializers, supports possible recursive type inference by + /// switching to type inference mode eagerly. + pub fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + if extra.serialize_as_any { + infer_to_python(value, include, exclude, extra) + } else { + self.to_python_no_infer(value, include, exclude, extra) + } + } + + /// Variant of the above which does not fall back to inference mode immediately + #[inline] + pub fn to_python_no_infer( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + TypeSerializer::to_python(self, value, include, exclude, extra) + } + + pub fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + if extra.serialize_as_any { + infer_json_key(key, extra) + } else { + self.json_key_no_infer(key, extra) + } + } + + #[inline] + pub fn json_key_no_infer<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + TypeSerializer::json_key(self, key, extra) + } + + pub fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + if extra.serialize_as_any { + infer_serialize(value, serializer, include, exclude, extra) + } else { + self.serde_serialize_no_infer(value, serializer, include, exclude, extra) + } + } + + #[inline] + pub fn serde_serialize_no_infer( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + TypeSerializer::serde_serialize(self, value, serializer, include, exclude, extra) + } } impl BuildSerializer for CombinedSerializer { diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index a080a9b4f..996a80d72 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -8,7 +8,6 @@ use serde::ser::SerializeMap; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; use crate::definitions::DefinitionsBuilder; -use crate::serializers::DuckTypingSerMode; use crate::tools::SchemaDict; use super::{ @@ -141,19 +140,11 @@ impl TypeSerializer for DataclassSerializer { extra: &Extra, ) -> PyResult { let model = Some(value); - let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode(); - let dc_extra = Extra { - model, - duck_typing_ser_mode, - ..*extra - }; - if dc_extra.duck_typing_ser_mode == DuckTypingSerMode::Inferred { - return infer_to_python(value, include, exclude, &dc_extra); - } + let dc_extra = Extra { model, ..*extra }; if self.allow_value(value, &dc_extra)? { let py = value.py(); if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { - let output_dict = fields_serializer.main_to_python( + let output_dict: Bound = fields_serializer.main_to_python( py, known_dataclass_iter(&self.fields, value), include, @@ -190,16 +181,8 @@ impl TypeSerializer for DataclassSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode(); let model = Some(value); - let dc_extra = Extra { - model, - duck_typing_ser_mode, - ..*extra - }; - if dc_extra.duck_typing_ser_mode == DuckTypingSerMode::Inferred { - return infer_serialize(value, serializer, include, exclude, &dc_extra); - } + let dc_extra = Extra { model, ..*extra }; if self.allow_value(value, &dc_extra).map_err(py_err_se_err)? { if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { let expected_len = self.fields.len() + fields_serializer.computed_field_count(); diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 38fd7cea5..4f517b6f2 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -94,12 +94,12 @@ impl TypeSerializer for DefinitionRefSerializer { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); let mut guard = extra.recursion_guard(value, self.definition.id())?; - comb_serializer.to_python(value, include, exclude, guard.state()) + comb_serializer.to_python_no_infer(value, include, exclude, guard.state()) }) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - self.definition.read(|s| s.unwrap().json_key(key, extra)) + self.definition.read(|s| s.unwrap().json_key_no_infer(key, extra)) } fn serde_serialize( @@ -115,7 +115,7 @@ impl TypeSerializer for DefinitionRefSerializer { let mut guard = extra .recursion_guard(value, self.definition.id()) .map_err(py_err_se_err)?; - comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state()) + comb_serializer.serde_serialize_no_infer(value, serializer, include, exclude, guard.state()) }) } diff --git a/src/serializers/type_serializers/function.rs b/src/serializers/type_serializers/function.rs index 3c66583e2..9a351694c 100644 --- a/src/serializers/type_serializers/function.rs +++ b/src/serializers/type_serializers/function.rs @@ -497,6 +497,10 @@ impl SerializationCallable { value: &Bound<'_, PyAny>, index_key: Option<&Bound<'_, PyAny>>, ) -> PyResult> { + // NB wrap serializers have strong coupling to their inner type, + // so use to_python_no_infer so that type inference can't apply + // at this layer + let include = self.include.as_ref().map(|o| o.bind(py)); let exclude = self.exclude.as_ref().map(|o| o.bind(py)); let extra = self.extra_owned.to_extra(py); @@ -508,16 +512,16 @@ impl SerializationCallable { self.filter.key_filter(index_key, include, exclude)? }; if let Some((next_include, next_exclude)) = filter { - let v = self - .serializer - .to_python(value, next_include.as_ref(), next_exclude.as_ref(), &extra)?; + let v = + self.serializer + .to_python_no_infer(value, next_include.as_ref(), next_exclude.as_ref(), &extra)?; extra.warnings.final_check(py)?; Ok(Some(v)) } else { Err(PydanticOmit::new_err()) } } else { - let v = self.serializer.to_python(value, include, exclude, &extra)?; + let v = self.serializer.to_python_no_infer(value, include, exclude, &extra)?; extra.warnings.final_check(py)?; Ok(Some(v)) } @@ -581,7 +585,7 @@ impl SerializationInfo { exclude_none: extra.exclude_none, round_trip: extra.round_trip, field_name: Some(field_name.to_string()), - serialize_as_any: extra.duck_typing_ser_mode.to_bool(), + serialize_as_any: extra.serialize_as_any, }), _ => Err(PyRuntimeError::new_err( "Model field context expected for field serialization info but no model field was found", @@ -599,7 +603,7 @@ impl SerializationInfo { exclude_none: extra.exclude_none, round_trip: extra.round_trip, field_name: None, - serialize_as_any: extra.duck_typing_ser_mode.to_bool(), + serialize_as_any: extra.serialize_as_any, }) } } diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 4bae243fc..21f7ce024 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -16,7 +16,6 @@ use crate::build_tools::py_schema_err; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; use crate::definitions::DefinitionsBuilder; use crate::serializers::errors::PydanticSerializationUnexpectedValue; -use crate::serializers::extra::DuckTypingSerMode; use crate::tools::SchemaDict; const ROOT_FIELD: &str = "root"; @@ -171,16 +170,8 @@ impl TypeSerializer for ModelSerializer { extra: &Extra, ) -> PyResult { let model = Some(value); - let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode(); - let model_extra = Extra { - model, - duck_typing_ser_mode, - ..*extra - }; - if model_extra.duck_typing_ser_mode == DuckTypingSerMode::Inferred { - return infer_to_python(value, include, exclude, &model_extra); - } + let model_extra = Extra { model, ..*extra }; if self.root_model { let field_name = Some(ROOT_FIELD); let root_extra = Extra { @@ -198,7 +189,10 @@ impl TypeSerializer for ModelSerializer { self.serializer.to_python(&root, include, exclude, &root_extra) } else if self.allow_value(value, &model_extra)? { let inner_value = self.get_inner_value(value, &model_extra)?; - self.serializer.to_python(&inner_value, include, exclude, &model_extra) + // There is strong coupling between a model serializer and its child, we should + // not fall back to type inference in the midddle. + self.serializer + .to_python_no_infer(&inner_value, include, exclude, &model_extra) } else { extra.warnings.on_fallback_py(self.get_name(), value, &model_extra)?; infer_to_python(value, include, exclude, &model_extra) @@ -223,15 +217,7 @@ impl TypeSerializer for ModelSerializer { extra: &Extra, ) -> Result { let model = Some(value); - let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode(); - let model_extra = Extra { - model, - duck_typing_ser_mode, - ..*extra - }; - if model_extra.duck_typing_ser_mode == DuckTypingSerMode::Inferred { - return infer_serialize(value, serializer, include, exclude, &model_extra); - } + let model_extra = Extra { model, ..*extra }; if self.root_model { let field_name = Some(ROOT_FIELD); let root_extra = Extra { @@ -244,8 +230,10 @@ impl TypeSerializer for ModelSerializer { .serde_serialize(&root, serializer, include, exclude, &root_extra) } else if self.allow_value(value, &model_extra).map_err(py_err_se_err)? { let inner_value = self.get_inner_value(value, &model_extra).map_err(py_err_se_err)?; + // There is strong coupling between a model serializer and its child, we should + // not fall back to type inference in the midddle. self.serializer - .serde_serialize(&inner_value, serializer, include, exclude, &model_extra) + .serde_serialize_no_infer(&inner_value, serializer, include, exclude, &model_extra) } else { extra .warnings diff --git a/tests/serializers/test_serialize_as_any.py b/tests/serializers/test_serialize_as_any.py index 1d1be238b..ee841ffa8 100644 --- a/tests/serializers/test_serialize_as_any.py +++ b/tests/serializers/test_serialize_as_any.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional from typing_extensions import TypedDict @@ -155,3 +156,217 @@ class Other: 'x': 1, 'y': 'hopefully not a secret', } + + +def test_serialize_as_any_with_nested_models() -> None: + class Parent: + x: int + + class Other(Parent): + y: str + + class Outer: + p: Parent + + Parent.__pydantic_core_schema__ = core_schema.model_schema( + Parent, + core_schema.model_fields_schema( + { + 'x': core_schema.model_field(core_schema.int_schema()), + } + ), + ref='Parent', + ) + Parent.__pydantic_validator__ = SchemaValidator(Parent.__pydantic_core_schema__) + Parent.__pydantic_serializer__ = SchemaSerializer(Parent.__pydantic_core_schema__) + + Other.__pydantic_core_schema__ = core_schema.model_schema( + Other, + core_schema.model_fields_schema( + { + 'x': core_schema.model_field(core_schema.int_schema()), + 'y': core_schema.model_field(core_schema.str_schema()), + } + ), + config=core_schema.CoreConfig(extra_fields_behavior='allow'), + ) + Other.__pydantic_validator__ = SchemaValidator(Other.__pydantic_core_schema__) + Other.__pydantic_serializer__ = SchemaSerializer(Other.__pydantic_core_schema__) + + Outer.__pydantic_core_schema__ = core_schema.definitions_schema( + core_schema.model_schema( + Outer, + core_schema.model_fields_schema( + { + 'p': core_schema.model_field(core_schema.definition_reference_schema('Parent')), + } + ), + ), + [ + Parent.__pydantic_core_schema__, + ], + ) + Outer.__pydantic_validator__ = SchemaValidator(Outer.__pydantic_core_schema__) + Outer.__pydantic_serializer__ = SchemaSerializer(Outer.__pydantic_core_schema__) + + other = Other.__pydantic_validator__.validate_python({'x': 1, 'y': 'hopefully not a secret'}) + outer = Outer() + outer.p = other + + assert Outer.__pydantic_serializer__.to_python(outer, serialize_as_any=False) == { + 'p': {'x': 1}, + } + assert Outer.__pydantic_serializer__.to_python(outer, serialize_as_any=True) == { + 'p': { + 'x': 1, + 'y': 'hopefully not a secret', + } + } + + assert Outer.__pydantic_serializer__.to_json(outer, serialize_as_any=False) == b'{"p":{"x":1}}' + assert ( + Outer.__pydantic_serializer__.to_json(outer, serialize_as_any=True) + == b'{"p":{"x":1,"y":"hopefully not a secret"}}' + ) + + +def test_serialize_with_recursive_models() -> None: + class Node: + next: Optional['Node'] = None + value: int = 42 + + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('Node'), + [ + core_schema.model_schema( + Node, + core_schema.model_fields_schema( + { + 'value': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=42) + ), + 'next': core_schema.model_field( + core_schema.with_default_schema( + core_schema.nullable_schema(core_schema.definition_reference_schema('Node')), + default=None, + ) + ), + } + ), + ref='Node', + ) + ], + ) + + Node.__pydantic_core_schema__ = schema + Node.__pydantic_validator__ = SchemaValidator(Node.__pydantic_core_schema__) + Node.__pydantic_serializer__ = SchemaSerializer(Node.__pydantic_core_schema__) + other = Node.__pydantic_validator__.validate_python({'next': {'value': 4}}) + + assert Node.__pydantic_serializer__.to_python(other, serialize_as_any=False) == { + 'next': {'next': None, 'value': 4}, + 'value': 42, + } + assert Node.__pydantic_serializer__.to_python(other, serialize_as_any=True) == { + 'next': {'next': None, 'value': 4}, + 'value': 42, + } + + +def test_serialize_with_custom_type_and_subclasses(): + class Node: + x: int + + Node.__pydantic_core_schema__ = core_schema.model_schema( + Node, + core_schema.model_fields_schema( + { + 'x': core_schema.model_field(core_schema.int_schema()), + } + ), + ref='Node', + ) + Node.__pydantic_validator__ = SchemaValidator(Node.__pydantic_core_schema__) + Node.__pydantic_serializer__ = SchemaSerializer(Node.__pydantic_core_schema__) + + class NodeSubClass(Node): + y: int + + NodeSubClass.__pydantic_core_schema__ = core_schema.model_schema( + NodeSubClass, + core_schema.model_fields_schema( + { + 'x': core_schema.model_field(core_schema.int_schema()), + 'y': core_schema.model_field(core_schema.int_schema()), + } + ), + ) + NodeSubClass.__pydantic_validator__ = SchemaValidator(NodeSubClass.__pydantic_core_schema__) + NodeSubClass.__pydantic_serializer__ = SchemaSerializer(NodeSubClass.__pydantic_core_schema__) + + class CustomType: + values: list[Node] + + CustomType.__pydantic_core_schema__ = core_schema.model_schema( + CustomType, + core_schema.definitions_schema( + core_schema.model_fields_schema( + { + 'values': core_schema.model_field( + core_schema.list_schema(core_schema.definition_reference_schema('Node')) + ), + } + ), + [ + Node.__pydantic_core_schema__, + ], + ), + ) + CustomType.__pydantic_validator__ = SchemaValidator(CustomType.__pydantic_core_schema__) + CustomType.__pydantic_serializer__ = SchemaSerializer(CustomType.__pydantic_core_schema__) + + value = CustomType.__pydantic_validator__.validate_python({'values': [{'x': 1}, {'x': 2}]}) + value.values.append(NodeSubClass.__pydantic_validator__.validate_python({'x': 3, 'y': 4})) + + assert CustomType.__pydantic_serializer__.to_python(value, serialize_as_any=False) == { + 'values': [{'x': 1}, {'x': 2}, {'x': 3}], + } + assert CustomType.__pydantic_serializer__.to_python(value, serialize_as_any=True) == { + 'values': [{'x': 1}, {'x': 2}, {'x': 3, 'y': 4}], + } + + +def test_serialize_as_any_wrap_serializer_applied_once() -> None: + # https://github.com/pydantic/pydantic/issues/11139 + + class InnerModel: + an_inner_field: int + + InnerModel.__pydantic_core_schema__ = core_schema.model_schema( + InnerModel, + core_schema.model_fields_schema({'an_inner_field': core_schema.model_field(core_schema.int_schema())}), + ) + InnerModel.__pydantic_validator__ = SchemaValidator(InnerModel.__pydantic_core_schema__) + InnerModel.__pydantic_serializer__ = SchemaSerializer(InnerModel.__pydantic_core_schema__) + + class MyModel: + a_field: InnerModel + + def a_model_serializer(self, handler, info): + return {k + '_wrapped': v for k, v in handler(self).items()} + + MyModel.__pydantic_core_schema__ = core_schema.model_schema( + MyModel, + core_schema.model_fields_schema({'a_field': core_schema.model_field(InnerModel.__pydantic_core_schema__)}), + serialization=core_schema.wrap_serializer_function_ser_schema( + MyModel.a_model_serializer, + info_arg=True, + ), + ) + MyModel.__pydantic_validator__ = SchemaValidator(MyModel.__pydantic_core_schema__) + MyModel.__pydantic_serializer__ = SchemaSerializer(MyModel.__pydantic_core_schema__) + + instance = MyModel.__pydantic_validator__.validate_python({'a_field': {'an_inner_field': 1}}) + assert MyModel.__pydantic_serializer__.to_python(instance, serialize_as_any=True) == { + 'a_field_wrapped': {'an_inner_field': 1}, + }