Skip to content

Commit 373b6fa

Browse files
authored
Only use RecursiveContainerValidator when necessary (#243)
* Only use RecursiveContainerValidator when necessary
* tweak tests
1 parent 146309f commit 373b6fa

File tree

2 files changed

+94
-20
lines changed

2 files changed

+94
-20
lines changed

src/validators/mod.rs

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ use std::fmt::Debug;
22

33
use enum_dispatch::enum_dispatch;
44

5+
use ahash::AHashSet;
56
use pyo3::exceptions::PyTypeError;
67
use pyo3::intern;
78
use pyo3::once_cell::GILOnceCell;
89
use pyo3::prelude::*;
9-
use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyString};
10+
use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyList, PyString};
1011

1112
use crate::build_tools::{py_error, SchemaDict, SchemaError};
1213
use crate::errors::{ErrorKind, ValError, ValLineError, ValResult, ValidationError};
@@ -69,7 +70,10 @@ impl SchemaValidator {
6970
.map_err(|e| SchemaError::from_val_error(py, e))?;
7071
let schema = schema_obj.as_ref(py);
7172

72-
let mut build_context = BuildContext::default();
73+
let mut used_refs = AHashSet::new();
74+
extract_used_refs(schema, &mut used_refs)?;
75+
let mut build_context = BuildContext::new(used_refs);
76+
7377
let mut validator = build_validator(schema, config, &mut build_context)?;
7478
validator.complete(&build_context)?;
7579
let slots = build_context.into_slots()?;
@@ -219,7 +223,12 @@ impl SchemaValidator {
219223
py.run(code, None, Some(locals))?;
220224
let self_schema: &PyDict = locals.get_as_req(intern!(py, "self_schema"))?;
221225

222-
let mut build_context = BuildContext::default();
226+
let mut used_refs = AHashSet::new();
227+
// NOTE: we don't call `extract_used_refs` for performance reasons; if more recursive references
228+
// are used, they would need to be manually added here.
229+
used_refs.insert("root-schema".to_string());
230+
let mut build_context = BuildContext::new(used_refs);
231+
223232
let validator = match build_validator(self_schema, None, &mut build_context) {
224233
Ok(v) => v,
225234
Err(err) => return Err(SchemaError::new_err(format!("Error building self-schema:\n {}", err))),
@@ -260,26 +269,29 @@ pub trait BuildValidator: Sized {
260269
-> PyResult<CombinedValidator>;
261270
}
262271

272+
/// Logic to create a particular validator, called in the `validator_match` macro, then in turn by `build_validator`
263273
fn build_single_validator<'a, T: BuildValidator>(
264274
val_type: &str,
265275
schema_dict: &'a PyDict,
266276
config: Option<&'a PyDict>,
267277
build_context: &mut BuildContext,
268278
) -> PyResult<CombinedValidator> {
269279
let py = schema_dict.py();
270-
let val: CombinedValidator = if let Some(schema_ref) = schema_dict.get_as::<String>(intern!(py, "ref"))? {
271-
let slot_id = build_context.prepare_slot(schema_ref)?;
272-
let inner_val = T::build(schema_dict, config, build_context)
273-
.map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?;
274-
let name = inner_val.get_name().to_string();
275-
build_context.complete_slot(slot_id, inner_val)?;
276-
recursive::RecursiveContainerValidator::create(slot_id, name)
277-
} else {
278-
T::build(schema_dict, config, build_context)
279-
.map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?
280-
};
280+
if let Some(schema_ref) = schema_dict.get_as::<String>(intern!(py, "ref"))? {
281+
// we only want to use a RecursiveContainerValidator if the ref is actually used,
282+
// this means refs can always be set without having an effect on the validator which is generated
283+
// unless it's used/referenced
284+
if build_context.ref_used(&schema_ref) {
285+
let slot_id = build_context.prepare_slot(schema_ref)?;
286+
let inner_val = T::build(schema_dict, config, build_context)?;
287+
let name = inner_val.get_name().to_string();
288+
build_context.complete_slot(slot_id, inner_val)?;
289+
return Ok(recursive::RecursiveContainerValidator::create(slot_id, name));
290+
}
291+
}
281292

282-
Ok(val)
293+
T::build(schema_dict, config, build_context)
294+
.map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))
283295
}
284296

285297
// macro to build the match statement for validator selection
@@ -523,10 +535,23 @@ pub trait Validator: Send + Sync + Clone + Debug {
523535
/// and therefore can't be owned by them directly.
524536
#[derive(Default, Clone)]
525537
pub struct BuildContext {
538+
used_refs: AHashSet<String>,
526539
slots: Vec<(String, Option<CombinedValidator>)>,
527540
}
528541

529542
impl BuildContext {
543+
pub fn new(used_refs: AHashSet<String>) -> Self {
544+
Self {
545+
used_refs,
546+
..Default::default()
547+
}
548+
}
549+
550+
/// check if a ref is used elsewhere in the schema
551+
pub fn ref_used(&self, ref_: &str) -> bool {
552+
self.used_refs.contains(ref_)
553+
}
554+
530555
/// First of a two-part process to add a new validator slot: we add the `slot_ref` to the array, but not the
531556
/// actual `validator`; we can't add the validator until it's built.
532557
/// We need the `id` to build the validator, hence this two-step process.
@@ -584,3 +609,21 @@ impl BuildContext {
584609
.collect()
585610
}
586611
}
612+
613+
fn extract_used_refs(schema: &PyAny, refs: &mut AHashSet<String>) -> PyResult<()> {
614+
if let Ok(dict) = schema.cast_as::<PyDict>() {
615+
let py = schema.py();
616+
if matches!(dict.get_as(intern!(py, "type")), Ok(Some("recursive-ref"))) {
617+
refs.insert(dict.get_as_req(intern!(py, "schema_ref"))?);
618+
} else {
619+
for (_, value) in dict.iter() {
620+
extract_used_refs(value, refs)?;
621+
}
622+
}
623+
} else if let Ok(list) = schema.cast_as::<PyList>() {
624+
for item in list.iter() {
625+
extract_used_refs(item, refs)?;
626+
}
627+
}
628+
Ok(())
629+
}

tests/validators/test_recursive.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from pydantic_core import SchemaError, SchemaValidator, ValidationError
77

8-
from ..conftest import Err
8+
from ..conftest import Err, plain_repr
99
from .test_typed_dict import Cls
1010

1111

@@ -19,10 +19,7 @@ def test_branch_nullable():
1919
'sub_branch': {
2020
'schema': {
2121
'type': 'default',
22-
'schema': {
23-
'type': 'union',
24-
'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'schema_ref': 'Branch'}],
25-
},
22+
'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}},
2623
'default': None,
2724
}
2825
},
@@ -31,6 +28,7 @@ def test_branch_nullable():
3128
)
3229

3330
assert v.validate_python({'name': 'root'}) == {'name': 'root', 'sub_branch': None}
31+
assert plain_repr(v).startswith('SchemaValidator(name="typed-dict",validator=Recursive(RecursiveContainerValidator')
3432

3533
assert v.validate_python({'name': 'root', 'sub_branch': {'name': 'b1'}}) == (
3634
{'name': 'root', 'sub_branch': {'name': 'b1', 'sub_branch': None}}
@@ -40,6 +38,14 @@ def test_branch_nullable():
4038
)
4139

4240

41+
def test_unused_ref():
42+
v = SchemaValidator(
43+
{'type': 'typed-dict', 'ref': 'Branch', 'fields': {'name': {'schema': 'str'}, 'other': {'schema': 'int'}}}
44+
)
45+
assert plain_repr(v).startswith('SchemaValidator(name="typed-dict",validator=TypedDict(TypedDictValidator')
46+
assert v.validate_python({'name': 'root', 'other': '4'}) == {'name': 'root', 'other': 4}
47+
48+
4349
def test_nullable_error():
4450
v = SchemaValidator(
4551
{
@@ -680,3 +686,28 @@ def test_many_uses_of_ref():
680686

681687
long_input = {'name': 'Anne', 'other_names': [f'p-{i}' for i in range(300)]}
682688
assert v.validate_python(long_input) == long_input
689+
690+
691+
def test_error_inside_recursive_wrapper():
692+
with pytest.raises(SchemaError) as exc_info:
693+
SchemaValidator(
694+
{
695+
'type': 'typed-dict',
696+
'ref': 'Branch',
697+
'fields': {
698+
'sub_branch': {
699+
'schema': {
700+
'type': 'default',
701+
'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}},
702+
'default': None,
703+
'default_factory': lambda x: 'foobar',
704+
}
705+
}
706+
},
707+
}
708+
)
709+
assert str(exc_info.value) == (
710+
'Field "sub_branch":\n'
711+
' SchemaError: Error building "default" validator:\n'
712+
" SchemaError: 'default' and 'default_factory' cannot be used together"
713+
)

0 commit comments

Comments
 (0)