diff --git a/src/validators/any.rs b/src/validators/any.rs index 93b0761db..732849271 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -45,4 +45,8 @@ impl Validator for AnyValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index ad13e6ade..f9857128e 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -18,7 +18,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone, Copy)] enum VarKwargsMode { Uniform, UnpackedTypedDict, @@ -45,7 +45,7 @@ struct Parameter { name: String, kwarg_key: Option>, validator: Arc, - lookup_key_collection: LookupKeyCollection, + lookup_key_collection: Arc, mode: String, } @@ -129,7 +129,7 @@ impl BuildValidator for ArgumentsValidator { } let validation_alias = arg.get_item(intern!(py, "alias"))?; - let lookup_key_collection = LookupKeyCollection::new(py, validation_alias, name.as_str())?; + let lookup_key_collection = Arc::new(LookupKeyCollection::new(py, validation_alias, name.as_str())?); parameters.push(Parameter { positional, @@ -404,4 +404,58 @@ impl Validator for ArgumentsValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + let mut children: Vec<&Arc> = self.parameters.iter().map(|param| ¶m.validator).collect(); + if let Some(var_args_validator) = &self.var_args_validator { + children.push(var_args_validator); + } + if let Some(var_kwargs_validator) = &self.var_kwargs_validator { + children.push(var_kwargs_validator); + } + children + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + let expected_len = self.parameters.len() + + self.var_args_validator.as_ref().map_or(0, |_| 1) + + self.var_kwargs_validator.as_ref().map_or(0, |_| 1); + + if children.len() != expected_len { + return py_schema_err!("Expected {} children, got {}", expected_len, children.len()); + } + + let mut child_iter = children.into_iter(); + Ok(CombinedValidator::Arguments(Self { + parameters: self + .parameters + .iter() + .map(|param| Parameter { + positional: param.positional, + name: param.name.clone(), + kwarg_key: param.kwarg_key.clone(), + validator: child_iter.next().unwrap(), + lookup_key_collection: param.lookup_key_collection.clone(), + mode: param.mode.clone(), + }) + .collect(), + positional_params_count: self.positional_params_count, + var_args_validator: if self.var_args_validator.is_some() { + Some(child_iter.next().unwrap()) + } else { + None + }, + var_kwargs_mode: self.var_kwargs_mode, + var_kwargs_validator: if self.var_kwargs_validator.is_some() { + Some(child_iter.next().unwrap()) + } else { + None + }, + loc_by_alias: self.loc_by_alias, + extra: self.extra, + validate_by_alias: self.validate_by_alias, + validate_by_name: self.validate_by_name, + }) + .into()) + } } diff --git a/src/validators/arguments_v3.rs b/src/validators/arguments_v3.rs index 8e5781051..af13a9e73 100644 --- a/src/validators/arguments_v3.rs +++ b/src/validators/arguments_v3.rs @@ -22,7 +22,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone, Copy)] enum ParameterMode { PositionalOnly, PositionalOrKeyword, @@ -52,7 +52,7 @@ impl FromStr for ParameterMode { struct Parameter { name: String, mode: ParameterMode, - lookup_key_collection: LookupKeyCollection, + lookup_key_collection: Arc, validator: Arc, } @@ -185,7 +185,7 @@ impl BuildValidator for ArgumentsV3Validator { } let validation_alias = arg.get_item(intern!(py, "alias"))?; - let lookup_key_collection = LookupKeyCollection::new(py, validation_alias, name.as_str())?; + let lookup_key_collection = Arc::new(LookupKeyCollection::new(py, validation_alias, name.as_str())?); parameters.push(Parameter { name, @@ -783,4 +783,34 @@ impl Validator for ArgumentsV3Validator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + self.parameters.iter().map(|p| &p.validator).collect() + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != self.parameters.len() { + return py_schema_err!("Expected {} children, got {}", self.parameters.len(), children.len()); + } + + Ok(CombinedValidator::ArgumentsV3(Self { + parameters: self + .parameters + .iter() + .zip(children.into_iter()) + .map(|(p, v)| Parameter { + name: p.name.clone(), + mode: p.mode, + lookup_key_collection: p.lookup_key_collection.clone(), + validator: v, + }) + .collect(), + positional_params_count: self.positional_params_count, + loc_by_alias: self.loc_by_alias, + extra: self.extra, + validate_by_alias: self.validate_by_alias, + validate_by_name: self.validate_by_name, + }) + .into()) + } } diff --git a/src/validators/bool.rs b/src/validators/bool.rs index 094343791..969920c71 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -55,4 +55,8 @@ impl Validator for BoolValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 6092b0771..08ddaaeba 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -60,6 +60,10 @@ impl Validator for BytesValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } #[derive(Debug, Clone)] @@ -112,6 +116,10 @@ impl Validator for BytesConstrainedValidator { fn get_name(&self) -> &'static str { "constrained-bytes" } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } impl BytesConstrainedValidator { diff --git a/src/validators/call.rs b/src/validators/call.rs index 48509a92f..aee7dc973 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -6,6 +6,7 @@ use pyo3::prelude::*; use pyo3::types::PyString; use pyo3::types::{PyDict, PyTuple}; +use crate::build_tools::py_schema_err; use crate::errors::ValResult; use crate::input::Input; @@ -106,4 +107,36 @@ impl Validator for CallValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + let mut children = vec![&self.arguments_validator]; + if let Some(return_validator) = &self.return_validator { + children.push(return_validator); + } + children + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != if self.return_validator.is_some() { 2 } else { 1 } { + return py_schema_err!( + "Expected {} children, got {}", + if self.return_validator.is_some() { 2 } else { 1 }, + children.len() + ); + } + + let mut child_iter = children.into_iter(); + + Ok(CombinedValidator::FunctionCall(Self { + function: self.function.clone(), + arguments_validator: child_iter.next().unwrap(), + return_validator: if self.return_validator.is_some() { + Some(child_iter.next().unwrap()) + } else { + None + }, + name: self.name.clone(), + }) + .into()) + } } diff --git a/src/validators/callable.rs b/src/validators/callable.rs index fc421af61..3ed69a0ea 100644 --- a/src/validators/callable.rs +++ b/src/validators/callable.rs @@ -48,4 +48,8 @@ impl Validator for CallableValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } diff --git a/src/validators/chain.rs b/src/validators/chain.rs index b0a6d9979..4ba3ca020 100644 --- a/src/validators/chain.rs +++ b/src/validators/chain.rs @@ -35,22 +35,7 @@ impl BuildValidator for ChainValidator { .flatten() .collect(); - match steps.len() { - 0 => py_schema_err!("One or more steps are required for a chain validator"), - 1 => { - let step = steps.into_iter().next().unwrap(); - Ok(step) - } - _ => { - let descr = steps.iter().map(|v| v.get_name()).collect::>().join(","); - - Ok(CombinedValidator::Chain(Self { - steps, - name: format!("{}[{descr}]", Self::EXPECTED_TYPE), - }) - .into()) - } - } + ChainValidator::from_steps(steps) } } @@ -71,6 +56,27 @@ fn build_validator_steps( impl_py_gc_traverse!(ChainValidator { steps }); +impl ChainValidator { + fn from_steps(steps: Vec>) -> PyResult> { + match steps.len() { + 0 => py_schema_err!("One or more steps are required for a chain validator"), + 1 => { + let step = steps.into_iter().next().unwrap(); + Ok(step) + } + _ => { + let descr = steps.iter().map(|v| v.get_name()).collect::>().join(","); + + Ok(CombinedValidator::Chain(Self { + steps, + name: format!("{}[{descr}]", Self::EXPECTED_TYPE), + }) + .into()) + } + } + } +} + impl Validator for ChainValidator { fn validate<'py>( &self, @@ -88,4 +94,12 @@ impl Validator for ChainValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + self.steps.iter().collect() + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + ChainValidator::from_steps(children) + } } diff --git a/src/validators/complex.rs b/src/validators/complex.rs index 2e6cc3c50..dc986a0c0 100644 --- a/src/validators/complex.rs +++ b/src/validators/complex.rs @@ -61,6 +61,10 @@ impl Validator for ComplexValidator { fn get_name(&self) -> &'static str { "complex" } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } pub(crate) fn string_to_complex<'py>( diff --git a/src/validators/custom_error.rs b/src/validators/custom_error.rs index 426a191ed..d31dddf4c 100644 --- a/src/validators/custom_error.rs +++ b/src/validators/custom_error.rs @@ -104,4 +104,20 @@ impl Validator for CustomErrorValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Exactly one child is required for a custom-error validator"); + } + Ok(CombinedValidator::CustomError(Self { + validator: children.into_iter().next().unwrap(), + custom_error: self.custom_error.clone(), + name: self.name.clone(), + }) + .into()) + } } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index dbd797f12..7a891a1b1 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -29,7 +29,7 @@ struct Field { name_py: Py, init: bool, init_only: bool, - lookup_key_collection: LookupKeyCollection, + lookup_key_collection: Arc, validator: Arc, frozen: bool, } @@ -96,7 +96,7 @@ impl BuildValidator for DataclassArgsValidator { } let validation_alias = field.get_item(intern!(py, "validation_alias"))?; - let lookup_key_collection = LookupKeyCollection::new(py, validation_alias, name.as_str())?; + let lookup_key_collection = Arc::new(LookupKeyCollection::new(py, validation_alias, name.as_str())?); fields.push(Field { kw_only, @@ -449,6 +449,55 @@ impl Validator for DataclassArgsValidator { fn get_name(&self) -> &str { &self.validator_name } + + fn children(&self) -> Vec<&Arc> { + let mut children: Vec<&Arc> = self.fields.iter().map(|f| &f.validator).collect(); + if let Some(extras_validator) = &self.extras_validator { + children.push(extras_validator); + } + children + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + let expected_len = self.fields.len() + if self.extras_validator.is_some() { 1 } else { 0 }; + if children.len() != expected_len { + return py_schema_err!("Expected {} children, got {}", expected_len, children.len()); + } + let mut children_iter = children.into_iter(); + + let fields = self + .fields + .iter() + .map(|f| Field { + kw_only: f.kw_only, + name: f.name.clone(), + name_py: f.name_py.clone(), + lookup_key_collection: f.lookup_key_collection.clone(), + validator: children_iter.next().unwrap(), + init: f.init, + init_only: f.init_only, + frozen: f.frozen, + }) + .collect(); + + Ok(CombinedValidator::DataclassArgs(Self { + fields, + positional_count: self.positional_count, + init_only_count: self.init_only_count, + dataclass_name: self.dataclass_name.clone(), + validator_name: self.validator_name.clone(), + extra_behavior: self.extra_behavior, + extras_validator: if self.extras_validator.is_some() { + Some(children_iter.next().unwrap()) + } else { + None + }, + loc_by_alias: self.loc_by_alias, + validate_by_alias: self.validate_by_alias, + validate_by_name: self.validate_by_name, + }) + .into()) + } } #[derive(Debug)] @@ -621,6 +670,30 @@ impl Validator for DataclassValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Expected 1 child, got {}", children.len()); + } + + Ok(CombinedValidator::Dataclass(Self { + strict: self.strict, + validator: children.into_iter().next().unwrap(), + class: self.class.clone(), + generic_origin: self.generic_origin.clone(), + fields: self.fields.clone(), + post_init: self.post_init.clone(), + revalidate: self.revalidate, + name: self.name.clone(), + frozen: self.frozen, + slots: self.slots, + }) + .into()) + } } impl DataclassValidator { diff --git a/src/validators/date.rs b/src/validators/date.rs index 4836ec73c..a0fbc6c82 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -107,6 +107,10 @@ impl Validator for DateValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } /// In lax mode, if the input is not a date, we try parsing the input as a datetime, then check it is an diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 23fdfe6a4..92dd2be09 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -140,6 +140,10 @@ impl Validator for DateTimeValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } /// In lax mode, if the input is not a datetime, we try parsing the input as a date and add the "00:00:00" time. diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index 56f0aa766..b478932ba 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -277,6 +277,10 @@ impl Validator for DecimalValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } pub(crate) fn create_decimal<'py>(arg: &Bound<'py, PyAny>, input: impl ToErrorValue) -> ValResult> { diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 387d329ab..e164481aa 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -51,6 +51,10 @@ impl DefinitionRefValidator { pub fn new(definition: DefinitionRef>) -> Self { Self { definition } } + + pub fn definition(&self) -> &DefinitionRef> { + &self.definition + } } impl BuildValidator for DefinitionRefValidator { @@ -126,6 +130,12 @@ impl Validator for DefinitionRefValidator { fn get_name(&self) -> &str { self.definition.get_or_init_name(|v| v.get_name().into()) } + + fn children(&self) -> Vec<&Arc> { + // deliberately return empty to avoid circular walk during optimization passes + // TODO this may not be correct + vec![] + } } fn py_identity(obj: &Bound<'_, PyAny>) -> usize { diff --git a/src/validators/dict.rs b/src/validators/dict.rs index b204e1f88..b7b566ef9 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -5,6 +5,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::is_strict; +use crate::build_tools::py_schema_err; use crate::errors::{LocItem, ValError, ValLineError, ValResult}; use crate::input::BorrowInput; use crate::input::ConsumeIterator; @@ -92,6 +93,26 @@ impl Validator for DictValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.key_validator, &self.value_validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 2 { + return py_schema_err!("DictValidator expected 2 children, got {}", children.len()); + } + Ok(CombinedValidator::Dict(Self { + strict: self.strict, + key_validator: children[0].clone(), + value_validator: children[1].clone(), + min_length: self.min_length, + max_length: self.max_length, + fail_fast: self.fail_fast, + name: self.name.clone(), + }) + .into()) + } } struct ValidateToDict<'a, 's, 'py, I: Input<'py> + ?Sized> { diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index febb16cb7..7622f1ab6 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -164,6 +164,10 @@ impl Validator for EnumValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } #[derive(Debug, Clone)] diff --git a/src/validators/float.rs b/src/validators/float.rs index 59ea26801..622f1eda3 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -82,6 +82,10 @@ impl Validator for FloatValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } #[derive(Debug, Clone)] @@ -173,6 +177,10 @@ impl Validator for ConstrainedFloatValidator { fn get_name(&self) -> &'static str { "constrained-float" } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } impl BuildValidator for ConstrainedFloatValidator { diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index d1fdd6449..c4d0a543a 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use pyo3::types::{PyDict, PyFrozenSet}; use pyo3::{prelude::*, IntoPyObjectExt}; +use crate::build_tools::py_schema_err; use crate::errors::ValResult; use crate::input::{validate_iter_to_set, BorrowInput, ConsumeIterator, Input, ValidatedSet}; use crate::tools::SchemaDict; @@ -54,6 +55,25 @@ impl Validator for FrozenSetValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.item_validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("FrozenSetValidator expected 1 child, got {}", children.len()); + } + Ok(CombinedValidator::FrozenSet(Self { + strict: self.strict, + item_validator: children.into_iter().next().unwrap(), + min_length: self.min_length, + max_length: self.max_length, + name: self.name.clone(), + fail_fast: self.fail_fast, + }) + .into()) + } } struct ValidateToFrozenSet<'a, 's, 'py, I: Input<'py> + ?Sized> { diff --git a/src/validators/function.rs b/src/validators/function.rs index 6b71d82d6..292b27414 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -5,6 +5,7 @@ use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; +use crate::build_tools::py_schema_err; use crate::errors::{ ErrorType, PydanticCustomError, PydanticKnownError, PydanticOmit, ToErrorValue, ValError, ValResult, ValidationError, @@ -153,6 +154,25 @@ impl Validator for FunctionBeforeValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Exactly one child is required for a function-before validator"); + } + Ok(CombinedValidator::FunctionBefore(Self { + validator: children.into_iter().next().unwrap(), + func: self.func.clone(), + config: self.config.clone(), + name: self.name.clone(), + field_name: self.field_name.clone(), + info_arg: self.info_arg, + }) + .into()) + } } #[derive(Debug)] @@ -227,6 +247,25 @@ impl Validator for FunctionAfterValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Exactly one child is required for a function-after validator"); + } + Ok(CombinedValidator::FunctionAfter(Self { + validator: children.into_iter().next().unwrap(), + func: self.func.clone(), + config: self.config.clone(), + name: self.name.clone(), + field_name: self.field_name.clone(), + info_arg: self.info_arg, + }) + .into()) + } } #[derive(Debug, Clone)] @@ -289,6 +328,10 @@ impl Validator for FunctionPlainValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } #[derive(Debug)] @@ -414,6 +457,27 @@ impl Validator for FunctionWrapValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Exactly one child is required for a function-wrap validator"); + } + Ok(CombinedValidator::FunctionWrap(Self { + validator: children.into_iter().next().unwrap(), + func: self.func.clone(), + config: self.config.clone(), + name: self.name.clone(), + field_name: self.field_name.clone(), + info_arg: self.info_arg, + hide_input_in_errors: self.hide_input_in_errors, + validation_error_cause: self.validation_error_cause, + }) + .into()) + } } #[pyclass(module = "pydantic_core._pydantic_core")] diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 6dc63a752..c432fbfaa 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use pyo3::types::{PyDict, PyString}; use pyo3::{prelude::*, IntoPyObjectExt, PyTraverseError, PyVisit}; -use crate::build_tools::ExtraBehavior; +use crate::build_tools::{py_schema_err, ExtraBehavior}; use crate::errors::{ErrorType, LocItem, ValError, ValResult}; use crate::input::{BorrowInput, GenericIterator, Input}; use crate::py_gc::PyGcTraverse; @@ -95,6 +95,35 @@ impl Validator for GeneratorValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + match &self.item_validator { + Some(v) => vec![v], + None => vec![], + } + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if self.item_validator.is_none() { + if !children.is_empty() { + return py_schema_err!("GeneratorValidator expected 0 children, got {}", children.len()); + } + return Ok(CombinedValidator::Generator(self.clone()).into()); + } else { + if children.len() != 1 { + return py_schema_err!("GeneratorValidator expected 1 child, got {}", children.len()); + } + Ok(CombinedValidator::Generator(Self { + item_validator: Some(children.into_iter().next().unwrap()), + min_length: self.min_length, + max_length: self.max_length, + name: self.name.clone(), + hide_input_in_errors: self.hide_input_in_errors, + validation_error_cause: self.validation_error_cause, + }) + .into()) + } + } } #[pyclass(module = "pydantic_core._pydantic_core")] diff --git a/src/validators/int.rs b/src/validators/int.rs index 1d427eb76..4674a5c5c 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -84,6 +84,10 @@ impl Validator for IntValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } #[derive(Debug, Clone)] @@ -184,4 +188,8 @@ impl Validator for ConstrainedIntValidator { fn get_name(&self) -> &'static str { "constrained-int" } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index 62d995b54..2b53c2d41 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -81,6 +81,10 @@ impl Validator for IsInstanceValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } pub fn class_repr(schema: &Bound<'_, PyDict>, class: &Bound<'_, PyAny>) -> PyResult { diff --git a/src/validators/is_subclass.rs b/src/validators/is_subclass.rs index 6e6b8598b..e5f0d71d9 100644 --- a/src/validators/is_subclass.rs +++ b/src/validators/is_subclass.rs @@ -76,4 +76,8 @@ impl Validator for IsSubclassValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } diff --git a/src/validators/json.rs b/src/validators/json.rs index 5b2d6e563..6975d6da1 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -6,6 +6,7 @@ use pyo3::types::PyDict; use jiter::{FloatMode, JsonValue, PythonParse}; +use crate::build_tools::py_schema_err; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{EitherBytes, Input, InputType, ValidationMatch}; use crate::serializers::BytesMode; @@ -87,6 +88,32 @@ impl Validator for JsonValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + self.validator.iter().collect() + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if self.validator.is_none() { + if !children.is_empty() { + return py_schema_err!("No children expected for 'json' validator without schema"); + } + return Ok(CombinedValidator::Json(Self { + validator: None, + name: self.name.clone(), + }) + .into()); + } else { + if children.len() != 1 { + return py_schema_err!("Exactly one child expected for 'json' validator with schema"); + } + return Ok(CombinedValidator::Json(Self { + validator: Some(children[0].clone()), + name: self.name.clone(), + }) + .into()); + } + } } pub fn validate_json_bytes<'a, 'py>( diff --git a/src/validators/json_or_python.rs b/src/validators/json_or_python.rs index 6b482ef38..b0332bc16 100644 --- a/src/validators/json_or_python.rs +++ b/src/validators/json_or_python.rs @@ -4,6 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; +use crate::build_tools::py_schema_err; use crate::definitions::DefinitionsBuilder; use crate::errors::ValResult; use crate::input::Input; @@ -61,4 +62,20 @@ impl Validator for JsonOrPython { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.json, &self.python] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 2 { + return py_schema_err!("JsonOrPython must have exactly two children: json and python"); + } + Ok(CombinedValidator::JsonOrPython(Self { + json: children[0].clone(), + python: children[1].clone(), + name: self.name.clone(), + }) + .into()) + } } diff --git a/src/validators/lax_or_strict.rs b/src/validators/lax_or_strict.rs index 09d5fef7a..cada37953 100644 --- a/src/validators/lax_or_strict.rs +++ b/src/validators/lax_or_strict.rs @@ -5,6 +5,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::is_strict; +use crate::build_tools::py_schema_err; use crate::errors::ValResult; use crate::input::Input; use crate::tools::SchemaDict; @@ -83,4 +84,22 @@ impl Validator for LaxOrStrictValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.lax_validator, &self.strict_validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 2 { + return py_schema_err!("LaxOrStrict must have exactly two children: lax and strict"); + } + let mut iter = children.into_iter(); + Ok(CombinedValidator::LaxOrStrict(Self { + lax_validator: iter.next().unwrap(), + strict_validator: iter.next().unwrap(), + name: self.name.clone(), + strict: self.strict, + }) + .into()) + } } diff --git a/src/validators/list.rs b/src/validators/list.rs index 0188d73ea..3ca176680 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -177,6 +177,28 @@ impl Validator for ListValidator { } } } + + fn children(&self) -> Vec<&Arc> { + match &self.item_validator { + Some(v) => vec![v], + None => vec![], + } + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() > 1 { + return crate::build_tools::py_schema_err!("List must have zero or one child"); + } + Ok(CombinedValidator::List(Self { + strict: self.strict, + item_validator: children.into_iter().next(), + min_length: self.min_length, + max_length: self.max_length, + name: OnceLock::new(), + fail_fast: self.fail_fast, + }) + .into()) + } } struct ValidateToVec<'a, 's, 'py, I: Input<'py> + ?Sized> { diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 9c8ffae10..97a156690 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -302,6 +302,10 @@ impl Validator for LiteralValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } pub fn expected_repr_name(mut repr_args: Vec, base_name: &'static str) -> (String, String) { diff --git a/src/validators/missing_sentinel.rs b/src/validators/missing_sentinel.rs index fa897f928..dc61facbd 100644 --- a/src/validators/missing_sentinel.rs +++ b/src/validators/missing_sentinel.rs @@ -49,4 +49,8 @@ impl Validator for MissingSentinelValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 0456566ae..33b63e8f3 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -128,8 +128,13 @@ impl SchemaValidator { pub fn py_new(py: Python, schema: &Bound<'_, PyAny>, config: Option<&Bound<'_, PyDict>>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); - let validator = build_validator_base(schema, config, &mut definitions_builder)?; + let validator = Arc::new(build_validator_base(schema, config, &mut definitions_builder)?); let definitions = definitions_builder.finish()?; + + // now optimize the schemas: + // - replace any `DefinitionRef` with actual definition (unless it's recursive) + let validator = inline_definitions(&validator, &definitions)?; + let py_schema = schema.clone().unbind(); let py_config = match config { Some(c) if !c.is_empty() => Some(c.clone().into()), @@ -880,4 +885,120 @@ pub trait Validator: Send + Sync + Debug { /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator /// this is used in the error location in unions, and in the top level message in `ValidationError` fn get_name(&self) -> &str; + + /// Return a list of all child validators, if any + fn children(&self) -> Vec<&Arc>; + + /// Create a new instance of this validator with different children, used by optimization passes + fn with_new_children(&self, _children: Vec>) -> PyResult> { + unimplemented!( + "with_new_children not implemented for {}, should be implemented if `children` returns non-empty", + self.get_name() + ) + } +} + +trait TreeNodeTransformer { + /// Called when entering a node, before visiting children. + /// + /// If it returns `Some`, the traversal will skip this node's children and call to + /// `transform_up`, and return to the parent. + fn transform_down(&mut self, node: &Arc) -> PyResult>> { + Ok(None) + } + + /// Called when exiting a node, after visiting children. + /// + /// If any of the children were replaced, `node` will be the modified one, not the original one + /// passed to `transform_down`. + fn transform_up(&mut self, node: &Arc) -> PyResult>> { + Ok(None) + } +} + +/// Applies `f` to each node in the validator tree. If `f` returns `Some` then the parent nodes will +/// be rebuilt with the new child, otherwise the original node is kept. +fn transform_validator_tree( + validator: &Arc, + f: &mut impl TreeNodeTransformer, +) -> PyResult> { + if let Some(new_validator) = f.transform_down(validator)? { + return Ok(new_validator); + } + + let children = validator.children(); + if children.is_empty() { + return Ok(validator.clone()); + } + let mut new_children = Vec::with_capacity(children.len()); + let mut changed = false; + for child in children { + let new_child = transform_validator_tree(child, f)?; + if !changed && !Arc::ptr_eq(child, &new_child) { + changed = true; + } + new_children.push(new_child); + } + + let output = if changed { + validator.with_new_children(new_children)? + } else { + validator.clone() + }; + + f.transform_up(validator)?; + Ok(output) +} + +/// Inlines `definition` validators where they are used, except in the case of recursive models +fn inline_definitions( + root_validator: &Arc, + definitions: &Definitions>, +) -> PyResult> { + struct DefinitionInliner<'a> { + stack: Vec>, + // if recursion was detected, all parents above this index should be considered recursive + // (if we are on a recursive path, we don't inline any definitions) + stack_contains_recursion: Option, + definitions: &'a Definitions>, + } + + impl TreeNodeTransformer for DefinitionInliner<'_> { + fn transform_down(&mut self, node: &Arc) -> PyResult>> { + if let CombinedValidator::DefinitionRef(def_ref) = node.as_ref() { + let Some(def) = def_ref.definition().read(|def| def.cloned()) else { + // todo make a nicer error + panic!("definition was never filled"); + }; + + let stack_len = self.stack.len(); + let is_recursive = stack_len > 1 && self.stack[..stack_len - 1].iter().any(|n| { + std::ptr::eq(n, node) + }); + + if !is_recursive { + self.stack.push(def.clone()); + return Ok(Some(def)); + } + } + + self.stack.push(node.clone()); + Ok(None) + } + + fn transform_up(&mut self, node: &Arc) -> PyResult>> { + // TODO: on the way up, we should insert a recursion guard if we have a recursive path + self.stack.pop(); + Ok(None) + } + } + + transform_validator_tree( + root_validator, + &mut DefinitionInliner { + stack: Vec::new(), + stack_contains_recursion: None, + definitions, + }, + ) } diff --git a/src/validators/model.rs b/src/validators/model.rs index 3385f68b6..402208ff1 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -24,7 +24,7 @@ const DUNDER_FIELDS_SET_KEY: &str = "__pydantic_fields_set__"; const DUNDER_MODEL_EXTRA_KEY: &str = "__pydantic_extra__"; const DUNDER_MODEL_PRIVATE_KEY: &str = "__pydantic_private__"; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub(super) enum Revalidate { Always, Never, @@ -243,6 +243,29 @@ impl Validator for ModelValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Model must have exactly one child: the inner validator"); + } + + Ok(Arc::new(CombinedValidator::Model(ModelValidator { + revalidate: self.revalidate, + validator: children.into_iter().next().unwrap(), + class: self.class.clone(), + generic_origin: self.generic_origin.clone(), + post_init: self.post_init.clone(), + frozen: self.frozen, + custom_init: self.custom_init, + root_model: self.root_model, + undefined: self.undefined.clone(), + name: self.name.clone(), + }))) + } } impl ModelValidator { diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 672cd55ce..b7887e2e1 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -22,7 +22,7 @@ use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuild #[derive(Debug)] struct Field { name: String, - lookup_key_collection: LookupKeyCollection, + lookup_key_collection: Arc, name_py: Py, validator: Arc, frozen: bool, @@ -90,7 +90,7 @@ impl BuildValidator for ModelFieldsValidator { }; let validation_alias = field_info.get_item(intern!(py, "validation_alias"))?; - let lookup_key_collection = LookupKeyCollection::new(py, validation_alias, field_name)?; + let lookup_key_collection = Arc::new(LookupKeyCollection::new(py, validation_alias, field_name)?); fields.push(Field { name: field_name.to_string(), @@ -486,4 +486,62 @@ impl Validator for ModelFieldsValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + let mut children = Vec::with_capacity(self.fields.len() + 2); + for field in &self.fields { + children.push(&field.validator); + } + if let Some(ref v) = self.extras_validator { + children.push(v); + } + if let Some(ref v) = self.extras_keys_validator { + children.push(v); + } + children + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + let expected_len = self.fields.len() + + if self.extras_validator.is_some() { 1 } else { 0 } + + if self.extras_keys_validator.is_some() { 1 } else { 0 }; + if children.len() != expected_len { + return py_schema_err!("ModelFields must have exactly {} children", expected_len); + } + let mut iter = children.into_iter(); + let new_fields = self + .fields + .iter() + .map(|field| Field { + name: field.name.clone(), + lookup_key_collection: field.lookup_key_collection.clone(), + name_py: field.name_py.clone(), + validator: iter.next().unwrap(), + frozen: field.frozen, + }) + .collect(); + let extras_validator = if self.extras_validator.is_some() { + Some(iter.next().unwrap()) + } else { + None + }; + let extras_keys_validator = if self.extras_keys_validator.is_some() { + Some(iter.next().unwrap()) + } else { + None + }; + Ok(CombinedValidator::ModelFields(Self { + fields: new_fields, + model_name: self.model_name.clone(), + extra_behavior: self.extra_behavior, + extras_validator, + extras_keys_validator, + strict: self.strict, + from_attributes: self.from_attributes, + loc_by_alias: self.loc_by_alias, + validate_by_alias: self.validate_by_alias, + validate_by_name: self.validate_by_name, + }) + .into()) + } } diff --git a/src/validators/none.rs b/src/validators/none.rs index 9d35eec86..7c2c4727e 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -44,4 +44,8 @@ impl Validator for NoneValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 00e702f88..a143f7d9d 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -4,6 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; +use crate::build_tools::py_schema_err; use crate::errors::ValResult; use crate::input::Input; use crate::tools::SchemaDict; @@ -50,4 +51,19 @@ impl Validator for NullableValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Nullable must have exactly one child"); + } + Ok(CombinedValidator::Nullable(Self { + validator: children.into_iter().next().unwrap(), + name: self.name.clone(), + }) + .into()) + } } diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs index 77e6ba758..9b67f3e6f 100644 --- a/src/validators/prebuilt.rs +++ b/src/validators/prebuilt.rs @@ -43,4 +43,10 @@ impl Validator for PrebuiltValidator { fn get_name(&self) -> &str { self.schema_validator.get().validator.get_name() } + + fn children(&self) -> Vec<&std::sync::Arc> { + // Treat "prebuilt" as a leaf node as it may contain config boundary etc, do not want to + // optimize through it + vec![] + } } diff --git a/src/validators/set.rs b/src/validators/set.rs index 8c62c6d1a..2bc8a48af 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use pyo3::types::{PyDict, PySet}; use pyo3::{prelude::*, IntoPyObjectExt}; +use crate::build_tools::py_schema_err; use crate::errors::ValResult; use crate::input::{validate_iter_to_set, BorrowInput, ConsumeIterator, Input, ValidatedSet}; use crate::tools::SchemaDict; @@ -83,6 +84,25 @@ impl Validator for SetValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.item_validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Set must have exactly one child"); + } + Ok(CombinedValidator::Set(Self { + strict: self.strict, + item_validator: children.into_iter().next().unwrap(), + min_length: self.min_length, + max_length: self.max_length, + name: self.name.clone(), + fail_fast: self.fail_fast, + }) + .into()) + } } struct ValidateToSet<'a, 's, 'py, I: Input<'py> + ?Sized> { diff --git a/src/validators/string.rs b/src/validators/string.rs index 705c4f2b1..7666fd4bc 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -86,6 +86,10 @@ impl Validator for StrValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } /// Any new properties set here must be reflected in `has_constraints_set` @@ -175,6 +179,10 @@ impl Validator for StrConstrainedValidator { fn get_name(&self) -> &'static str { "constrained-str" } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } impl StrConstrainedValidator { diff --git a/src/validators/time.rs b/src/validators/time.rs index 6fb5037db..1982911a0 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -86,6 +86,10 @@ impl Validator for TimeValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } fn convert_pytime(schema: &Bound<'_, PyDict>, key: &Bound<'_, PyString>) -> PyResult> { diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index bf2950df7..6a83eef30 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -113,6 +113,10 @@ impl Validator for TimeDeltaValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } fn pydelta_to_human_readable(py_delta: Bound<'_, PyDelta>) -> String { let total_seconds = py_delta.get_seconds(); diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 5ede4bad2..7fab9fef8 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -6,6 +6,7 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; use crate::build_tools::is_strict; +use crate::build_tools::py_schema_err; use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::ConsumeIterator; use crate::input::{BorrowInput, Input, ValidatedTuple}; @@ -316,6 +317,26 @@ impl Validator for TupleValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + self.validators.iter().collect() + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != self.validators.len() { + return py_schema_err!("Tuple must have exactly {} children", self.validators.len()); + } + Ok(CombinedValidator::Tuple(Self { + strict: self.strict, + validators: children, + variadic_item_index: self.variadic_item_index, + min_length: self.min_length, + max_length: self.max_length, + name: self.name.clone(), + fail_fast: self.fail_fast, + }) + .into()) + } } struct ValidateToTuple<'a, 's, 'py, I: Input<'py> + ?Sized> { diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index fdec72143..b09cd2e8e 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -22,7 +22,7 @@ use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuild #[derive(Debug)] struct TypedDictField { name: String, - lookup_key_collection: LookupKeyCollection, + lookup_key_collection: Arc, name_py: Py, required: bool, validator: Arc, @@ -121,7 +121,7 @@ impl BuildValidator for TypedDictValidator { } let validation_alias = field_info.get_item(intern!(py, "validation_alias"))?; - let lookup_key_collection = LookupKeyCollection::new(py, validation_alias, field_name)?; + let lookup_key_collection = Arc::new(LookupKeyCollection::new(py, validation_alias, field_name)?); fields.push(TypedDictField { name: field_name.to_string(), @@ -385,4 +385,51 @@ impl Validator for TypedDictValidator { fn get_name(&self) -> &str { self.cls_name.as_deref().unwrap_or(Self::EXPECTED_TYPE) } + + fn children(&self) -> Vec<&Arc> { + let mut children: Vec<&Arc> = self.fields.iter().map(|f| &f.validator).collect(); + if let Some(extras_validator) = &self.extras_validator { + children.push(extras_validator); + } + children + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + let expected_len = self.fields.len() + if self.extras_validator.is_some() { 1 } else { 0 }; + if children.len() != expected_len { + return py_schema_err!("Expected {} children for TypedDict validator", expected_len); + } + + let (fields_validators, extras_validator) = if let Some(_) = &self.extras_validator { + let (fields, extras) = children.split_at(self.fields.len()); + (fields.to_vec(), Some(extras[0].clone())) + } else { + (children, None) + }; + + let new_fields: Vec = self + .fields + .iter() + .zip(fields_validators.into_iter()) + .map(|(field, validator)| TypedDictField { + name: field.name.clone(), + validator, + lookup_key_collection: field.lookup_key_collection.clone(), + name_py: field.name_py.clone(), + required: field.required, + }) + .collect(); + + Ok(CombinedValidator::TypedDict(Self { + fields: new_fields, + extra_behavior: self.extra_behavior, + extras_validator, + strict: self.strict, + loc_by_alias: self.loc_by_alias, + validate_by_alias: self.validate_by_alias, + validate_by_name: self.validate_by_name, + cls_name: self.cls_name.clone(), + }) + .into()) + } } diff --git a/src/validators/union.rs b/src/validators/union.rs index 200e2b339..e44525eba 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -21,7 +21,7 @@ use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator, }; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] enum UnionMode { Smart, LeftToRight, @@ -224,6 +224,31 @@ impl Validator for UnionValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + self.choices.iter().map(|(v, _)| v).collect() + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != self.choices.len() { + return py_schema_err!( + "Number of children for union must be the same as the existing number of choices ({})", + self.choices.len() + ); + } + let new_choices = children + .into_iter() + .zip(self.choices.iter().map(|(_, label)| label.clone())) + .collect(); + + Ok(CombinedValidator::Union(Self { + mode: self.mode, + choices: new_choices, + custom_error: self.custom_error.clone(), + name: self.name.clone(), + }) + .into()) + } } struct ChoiceLineErrors<'a> { @@ -282,7 +307,7 @@ impl<'a> MaybeErrors<'a> { #[derive(Debug)] pub struct TaggedUnionValidator { - discriminator: Discriminator, + discriminator: Arc, lookup: LiteralLookup>, from_attributes: bool, custom_error: Option, @@ -300,7 +325,10 @@ impl BuildValidator for TaggedUnionValidator { definitions: &mut DefinitionsBuilder>, ) -> PyResult> { let py = schema.py(); - let discriminator = Discriminator::new(py, &schema.get_as_req(intern!(py, "discriminator"))?)?; + let discriminator = Arc::new(Discriminator::new( + py, + &schema.get_as_req(intern!(py, "discriminator"))?, + )?); let discriminator_repr = discriminator.to_string_py(py)?; let choices = PyDict::new(py); @@ -351,7 +379,7 @@ impl Validator for TaggedUnionValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult> { - match &self.discriminator { + match self.discriminator.as_ref() { Discriminator::LookupKey(lookup_key) => { let from_attributes = state.extra().from_attributes.unwrap_or(self.from_attributes); let dict = input.validate_model_fields(state.strict_or(false), from_attributes)?; @@ -377,6 +405,33 @@ impl Validator for TaggedUnionValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + self.lookup.values.iter().collect() + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != self.lookup.values.len() { + return py_schema_err!( + "Number of children for tagged-union must be the same as the existing number of choices ({})", + self.lookup.values.len() + ); + } + + let mut new_lookup = self.lookup.clone(); + new_lookup.values = children; + + Ok(CombinedValidator::TaggedUnion(Self { + discriminator: self.discriminator.clone(), + lookup: new_lookup, + from_attributes: self.from_attributes, + custom_error: self.custom_error.clone(), + tags_repr: self.tags_repr.clone(), + discriminator_repr: self.discriminator_repr.clone(), + name: self.name.clone(), + }) + .into()) + } } impl TaggedUnionValidator { diff --git a/src/validators/url.rs b/src/validators/url.rs index bcd6e8786..a41b11fae 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -105,6 +105,10 @@ impl Validator for UrlValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } impl UrlValidator { @@ -266,6 +270,10 @@ impl Validator for MultiHostUrlValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } impl MultiHostUrlValidator { diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index 25b80c55a..d9fefe536 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -161,6 +161,10 @@ impl Validator for UuidValidator { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn children(&self) -> Vec<&Arc> { + vec![] + } } impl UuidValidator { diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index c097f40d0..bb7306827 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -212,6 +212,26 @@ impl Validator for WithDefaultValidator { fn get_name(&self) -> &str { &self.name } + + fn children(&self) -> Vec<&Arc> { + vec![&self.validator] + } + + fn with_new_children(&self, children: Vec>) -> PyResult> { + if children.len() != 1 { + return py_schema_err!("Expected 1 child, got {}", children.len()); + } + Ok(CombinedValidator::WithDefault(Self { + default: self.default.clone(), + on_error: self.on_error.clone(), + validator: children.into_iter().next().unwrap(), + validate_default: self.validate_default, + copy_default: self.copy_default, + name: self.name.clone(), + undefined: self.undefined.clone(), + }) + .into()) + } } impl WithDefaultValidator { diff --git a/tests/validators/test_definitions.py b/tests/validators/test_definitions.py index 967eeae2d..eef44efb1 100644 --- a/tests/validators/test_definitions.py +++ b/tests/validators/test_definitions.py @@ -16,6 +16,8 @@ def test_list_with_def(): assert v.validate_json(b'[1, 2, "3"]') == [1, 2, 3] r = plain_repr(v) assert r.startswith('SchemaValidator(title="list[int]",') + # definition should have been inlined, not recursive + assert 'DefinitionRef' not in r def test_ignored_def():