@@ -2,11 +2,12 @@ use std::fmt::Debug;
22
33use enum_dispatch:: enum_dispatch;
44
5+ use ahash:: AHashSet ;
56use pyo3:: exceptions:: PyTypeError ;
67use pyo3:: intern;
78use pyo3:: once_cell:: GILOnceCell ;
89use pyo3:: prelude:: * ;
9- use pyo3:: types:: { PyAny , PyByteArray , PyBytes , PyDict , PyString } ;
10+ use pyo3:: types:: { PyAny , PyByteArray , PyBytes , PyDict , PyList , PyString } ;
1011
1112use crate :: build_tools:: { py_error, SchemaDict , SchemaError } ;
1213use crate :: errors:: { ErrorKind , ValError , ValLineError , ValResult , ValidationError } ;
@@ -69,7 +70,10 @@ impl SchemaValidator {
6970 . map_err ( |e| SchemaError :: from_val_error ( py, e) ) ?;
7071 let schema = schema_obj. as_ref ( py) ;
7172
72- let mut build_context = BuildContext :: default ( ) ;
73+ let mut used_refs = AHashSet :: new ( ) ;
74+ extract_used_refs ( schema, & mut used_refs) ?;
75+ let mut build_context = BuildContext :: new ( used_refs) ;
76+
7377 let mut validator = build_validator ( schema, config, & mut build_context) ?;
7478 validator. complete ( & build_context) ?;
7579 let slots = build_context. into_slots ( ) ?;
@@ -219,7 +223,12 @@ impl SchemaValidator {
219223 py. run ( code, None , Some ( locals) ) ?;
220224 let self_schema: & PyDict = locals. get_as_req ( intern ! ( py, "self_schema" ) ) ?;
221225
222- let mut build_context = BuildContext :: default ( ) ;
226+ let mut used_refs = AHashSet :: new ( ) ;
227+ // NOTE: we don't call `extract_used_refs` for performance reasons, if more recursive references
228+ // are used, they would need to be manually added here.
229+ used_refs. insert ( "root-schema" . to_string ( ) ) ;
230+ let mut build_context = BuildContext :: new ( used_refs) ;
231+
223232 let validator = match build_validator ( self_schema, None , & mut build_context) {
224233 Ok ( v) => v,
225234 Err ( err) => return Err ( SchemaError :: new_err ( format ! ( "Error building self-schema:\n {}" , err) ) ) ,
@@ -260,26 +269,29 @@ pub trait BuildValidator: Sized {
260269 -> PyResult < CombinedValidator > ;
261270}
262271
272+ /// Logic to create a particular validator, called in the `validator_match` macro, then in turn by `build_validator`
263273fn build_single_validator < ' a , T : BuildValidator > (
264274 val_type : & str ,
265275 schema_dict : & ' a PyDict ,
266276 config : Option < & ' a PyDict > ,
267277 build_context : & mut BuildContext ,
268278) -> PyResult < CombinedValidator > {
269279 let py = schema_dict. py ( ) ;
270- let val: CombinedValidator = if let Some ( schema_ref) = schema_dict. get_as :: < String > ( intern ! ( py, "ref" ) ) ? {
271- let slot_id = build_context. prepare_slot ( schema_ref) ?;
272- let inner_val = T :: build ( schema_dict, config, build_context)
273- . map_err ( |err| SchemaError :: new_err ( format ! ( "Error building \" {}\" validator:\n {}" , val_type, err) ) ) ?;
274- let name = inner_val. get_name ( ) . to_string ( ) ;
275- build_context. complete_slot ( slot_id, inner_val) ?;
276- recursive:: RecursiveContainerValidator :: create ( slot_id, name)
277- } else {
278- T :: build ( schema_dict, config, build_context)
279- . map_err ( |err| SchemaError :: new_err ( format ! ( "Error building \" {}\" validator:\n {}" , val_type, err) ) ) ?
280- } ;
280+ if let Some ( schema_ref) = schema_dict. get_as :: < String > ( intern ! ( py, "ref" ) ) ? {
281+ // we only want to use a RecursiveContainerValidator if the ref is actually used,
282+ // this means refs can always be set without having an effect on the validator which is generated
283+ // unless it's used/referenced
284+ if build_context. ref_used ( & schema_ref) {
285+ let slot_id = build_context. prepare_slot ( schema_ref) ?;
286+ let inner_val = T :: build ( schema_dict, config, build_context) ?;
287+ let name = inner_val. get_name ( ) . to_string ( ) ;
288+ build_context. complete_slot ( slot_id, inner_val) ?;
289+ return Ok ( recursive:: RecursiveContainerValidator :: create ( slot_id, name) ) ;
290+ }
291+ }
281292
282- Ok ( val)
293+ T :: build ( schema_dict, config, build_context)
294+ . map_err ( |err| SchemaError :: new_err ( format ! ( "Error building \" {}\" validator:\n {}" , val_type, err) ) )
283295}
284296
285297// macro to build the match statement for validator selection
@@ -523,10 +535,23 @@ pub trait Validator: Send + Sync + Clone + Debug {
523535/// and therefore can't be owned by them directly.
524536#[ derive( Default , Clone ) ]
525537pub struct BuildContext {
538+ used_refs : AHashSet < String > ,
526539 slots : Vec < ( String , Option < CombinedValidator > ) > ,
527540}
528541
529542impl BuildContext {
543+ pub fn new ( used_refs : AHashSet < String > ) -> Self {
544+ Self {
545+ used_refs,
546+ ..Default :: default ( )
547+ }
548+ }
549+
550+ /// check if a ref is used elsewhere in the schema
551+ pub fn ref_used ( & self , ref_ : & str ) -> bool {
552+ self . used_refs . contains ( ref_)
553+ }
554+
530555 /// First of two part process to add a new validator slot, we add the `slot_ref` to the array, but not the
531556 /// actual `validator`, we can't add the validator until it's build.
532557 /// We need the `id` to build the validator, hence this two-step process.
@@ -584,3 +609,21 @@ impl BuildContext {
584609 . collect ( )
585610 }
586611}
612+
613+ fn extract_used_refs ( schema : & PyAny , refs : & mut AHashSet < String > ) -> PyResult < ( ) > {
614+ if let Ok ( dict) = schema. cast_as :: < PyDict > ( ) {
615+ let py = schema. py ( ) ;
616+ if matches ! ( dict. get_as( intern!( py, "type" ) ) , Ok ( Some ( "recursive-ref" ) ) ) {
617+ refs. insert ( dict. get_as_req ( intern ! ( py, "schema_ref" ) ) ?) ;
618+ } else {
619+ for ( _, value) in dict. iter ( ) {
620+ extract_used_refs ( value, refs) ?;
621+ }
622+ }
623+ } else if let Ok ( list) = schema. cast_as :: < PyList > ( ) {
624+ for item in list. iter ( ) {
625+ extract_used_refs ( item, refs) ?;
626+ }
627+ }
628+ Ok ( ( ) )
629+ }
0 commit comments