Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/validators/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,8 @@ impl Validator for AnyValidator {
fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
vec![]
}
}
60 changes: 57 additions & 3 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone, Copy)]
enum VarKwargsMode {
Uniform,
UnpackedTypedDict,
Expand All @@ -45,7 +45,7 @@ struct Parameter {
name: String,
kwarg_key: Option<Py<PyString>>,
validator: Arc<CombinedValidator>,
lookup_key_collection: LookupKeyCollection,
lookup_key_collection: Arc<LookupKeyCollection>,
mode: String,
}

Expand Down Expand Up @@ -129,7 +129,7 @@ impl BuildValidator for ArgumentsValidator {
}

let validation_alias = arg.get_item(intern!(py, "alias"))?;
let lookup_key_collection = LookupKeyCollection::new(py, validation_alias, name.as_str())?;
let lookup_key_collection = Arc::new(LookupKeyCollection::new(py, validation_alias, name.as_str())?);

parameters.push(Parameter {
positional,
Expand Down Expand Up @@ -404,4 +404,58 @@ impl Validator for ArgumentsValidator {
fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
let mut children: Vec<&Arc<CombinedValidator>> = self.parameters.iter().map(|param| &param.validator).collect();
if let Some(var_args_validator) = &self.var_args_validator {
children.push(var_args_validator);
}
if let Some(var_kwargs_validator) = &self.var_kwargs_validator {
children.push(var_kwargs_validator);
}
children
}

fn with_new_children(&self, children: Vec<Arc<CombinedValidator>>) -> PyResult<Arc<CombinedValidator>> {
let expected_len = self.parameters.len()
+ self.var_args_validator.as_ref().map_or(0, |_| 1)
+ self.var_kwargs_validator.as_ref().map_or(0, |_| 1);

if children.len() != expected_len {
return py_schema_err!("Expected {} children, got {}", expected_len, children.len());
}

let mut child_iter = children.into_iter();
Ok(CombinedValidator::Arguments(Self {
parameters: self
.parameters
.iter()
.map(|param| Parameter {
positional: param.positional,
name: param.name.clone(),
kwarg_key: param.kwarg_key.clone(),
validator: child_iter.next().unwrap(),
lookup_key_collection: param.lookup_key_collection.clone(),
mode: param.mode.clone(),
})
.collect(),
positional_params_count: self.positional_params_count,
var_args_validator: if self.var_args_validator.is_some() {
Some(child_iter.next().unwrap())
} else {
None
},
var_kwargs_mode: self.var_kwargs_mode,
var_kwargs_validator: if self.var_kwargs_validator.is_some() {
Some(child_iter.next().unwrap())
} else {
None
},
loc_by_alias: self.loc_by_alias,
extra: self.extra,
validate_by_alias: self.validate_by_alias,
validate_by_name: self.validate_by_name,
})
.into())
}
}
36 changes: 33 additions & 3 deletions src/validators/arguments_v3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone, Copy)]
enum ParameterMode {
PositionalOnly,
PositionalOrKeyword,
Expand Down Expand Up @@ -52,7 +52,7 @@ impl FromStr for ParameterMode {
struct Parameter {
name: String,
mode: ParameterMode,
lookup_key_collection: LookupKeyCollection,
lookup_key_collection: Arc<LookupKeyCollection>,
validator: Arc<CombinedValidator>,
}

Expand Down Expand Up @@ -185,7 +185,7 @@ impl BuildValidator for ArgumentsV3Validator {
}

let validation_alias = arg.get_item(intern!(py, "alias"))?;
let lookup_key_collection = LookupKeyCollection::new(py, validation_alias, name.as_str())?;
let lookup_key_collection = Arc::new(LookupKeyCollection::new(py, validation_alias, name.as_str())?);

parameters.push(Parameter {
name,
Expand Down Expand Up @@ -783,4 +783,34 @@ impl Validator for ArgumentsV3Validator {
fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
self.parameters.iter().map(|p| &p.validator).collect()
}

fn with_new_children(&self, children: Vec<Arc<CombinedValidator>>) -> PyResult<Arc<CombinedValidator>> {
if children.len() != self.parameters.len() {
return py_schema_err!("Expected {} children, got {}", self.parameters.len(), children.len());
}

Ok(CombinedValidator::ArgumentsV3(Self {
parameters: self
.parameters
.iter()
.zip(children.into_iter())
.map(|(p, v)| Parameter {
name: p.name.clone(),
mode: p.mode,
lookup_key_collection: p.lookup_key_collection.clone(),
validator: v,
})
.collect(),
positional_params_count: self.positional_params_count,
loc_by_alias: self.loc_by_alias,
extra: self.extra,
validate_by_alias: self.validate_by_alias,
validate_by_name: self.validate_by_name,
})
.into())
}
}
4 changes: 4 additions & 0 deletions src/validators/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,8 @@ impl Validator for BoolValidator {
fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
vec![]
}
}
8 changes: 8 additions & 0 deletions src/validators/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ impl Validator for BytesValidator {
fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
vec![]
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -112,6 +116,10 @@ impl Validator for BytesConstrainedValidator {
fn get_name(&self) -> &'static str {
"constrained-bytes"
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
vec![]
}
}

impl BytesConstrainedValidator {
Expand Down
33 changes: 33 additions & 0 deletions src/validators/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::prelude::*;
use pyo3::types::PyString;
use pyo3::types::{PyDict, PyTuple};

use crate::build_tools::py_schema_err;
use crate::errors::ValResult;
use crate::input::Input;

Expand Down Expand Up @@ -106,4 +107,36 @@ impl Validator for CallValidator {
fn get_name(&self) -> &str {
&self.name
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
let mut children = vec![&self.arguments_validator];
if let Some(return_validator) = &self.return_validator {
children.push(return_validator);
}
children
}

fn with_new_children(&self, children: Vec<Arc<CombinedValidator>>) -> PyResult<Arc<CombinedValidator>> {
if children.len() != if self.return_validator.is_some() { 2 } else { 1 } {
return py_schema_err!(
"Expected {} children, got {}",
if self.return_validator.is_some() { 2 } else { 1 },
children.len()
);
}

let mut child_iter = children.into_iter();

Ok(CombinedValidator::FunctionCall(Self {
function: self.function.clone(),
arguments_validator: child_iter.next().unwrap(),
return_validator: if self.return_validator.is_some() {
Some(child_iter.next().unwrap())
} else {
None
},
name: self.name.clone(),
})
.into())
}
}
4 changes: 4 additions & 0 deletions src/validators/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,8 @@ impl Validator for CallableValidator {
fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
vec![]
}
}
46 changes: 30 additions & 16 deletions src/validators/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,7 @@ impl BuildValidator for ChainValidator {
.flatten()
.collect();

match steps.len() {
0 => py_schema_err!("One or more steps are required for a chain validator"),
1 => {
let step = steps.into_iter().next().unwrap();
Ok(step)
}
_ => {
let descr = steps.iter().map(|v| v.get_name()).collect::<Vec<_>>().join(",");

Ok(CombinedValidator::Chain(Self {
steps,
name: format!("{}[{descr}]", Self::EXPECTED_TYPE),
})
.into())
}
}
ChainValidator::from_steps(steps)
}
}

Expand All @@ -71,6 +56,27 @@ fn build_validator_steps(

impl_py_gc_traverse!(ChainValidator { steps });

impl ChainValidator {
fn from_steps(steps: Vec<Arc<CombinedValidator>>) -> PyResult<Arc<CombinedValidator>> {
match steps.len() {
0 => py_schema_err!("One or more steps are required for a chain validator"),
1 => {
let step = steps.into_iter().next().unwrap();
Ok(step)
}
_ => {
let descr = steps.iter().map(|v| v.get_name()).collect::<Vec<_>>().join(",");

Ok(CombinedValidator::Chain(Self {
steps,
name: format!("{}[{descr}]", Self::EXPECTED_TYPE),
})
.into())
}
}
}
}

impl Validator for ChainValidator {
fn validate<'py>(
&self,
Expand All @@ -88,4 +94,12 @@ impl Validator for ChainValidator {
fn get_name(&self) -> &str {
&self.name
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
self.steps.iter().collect()
}

fn with_new_children(&self, children: Vec<Arc<CombinedValidator>>) -> PyResult<Arc<CombinedValidator>> {
ChainValidator::from_steps(children)
}
}
4 changes: 4 additions & 0 deletions src/validators/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ impl Validator for ComplexValidator {
fn get_name(&self) -> &'static str {
"complex"
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
vec![]
}
}

pub(crate) fn string_to_complex<'py>(
Expand Down
16 changes: 16 additions & 0 deletions src/validators/custom_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,20 @@ impl Validator for CustomErrorValidator {
fn get_name(&self) -> &str {
&self.name
}

fn children(&self) -> Vec<&Arc<CombinedValidator>> {
vec![&self.validator]
}

fn with_new_children(&self, children: Vec<Arc<CombinedValidator>>) -> PyResult<Arc<CombinedValidator>> {
if children.len() != 1 {
return py_schema_err!("Exactly one child is required for a custom-error validator");
}
Ok(CombinedValidator::CustomError(Self {
validator: children.into_iter().next().unwrap(),
custom_error: self.custom_error.clone(),
name: self.name.clone(),
})
.into())
}
}
Loading
Loading