@@ -3,7 +3,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
33use pyo3:: prelude:: * ;
44use pyo3:: types:: { PyDict , PyList , PySet , PyString , PyTuple } ;
55use pyo3:: { intern, Bound , PyResult } ;
6- use std:: collections:: HashMap ;
6+ use std:: collections:: HashSet ;
77
88const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY : & str = "pydantic.internal.union_discriminator" ;
99
@@ -41,21 +41,19 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
4141 let schema_ref_str = schema_ref_pystr. to_str ( ) ?;
4242 defaultdict_list_append ! ( & ctx. def_refs, schema_ref_pystr, schema_ref_dict) ;
4343
44- if * ctx. refs_recursion_count . entry ( schema_ref_str. to_string ( ) ) . or_insert ( 0 ) == 0 {
44+ if !ctx. recursively_seen_refs . contains ( schema_ref_str) {
45+ // TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
4546 if let Some ( def) = ctx. definitions_dict . get_item ( schema_ref_pystr) ? {
46- * ctx. refs_recursion_count . get_mut ( schema_ref_str) . unwrap ( ) += 1 ;
47- ctx. def_refs_chain . push ( schema_ref_str. to_string ( ) ) ;
47+ ctx. recursively_seen_refs . insert ( schema_ref_str. to_string ( ) ) ;
4848 gather_schema ( def. downcast_exact :: < PyDict > ( ) ?, ctx) ?;
49- ctx. def_refs_chain . pop ( ) ;
50- * ctx. refs_recursion_count . get_mut ( schema_ref_str) . unwrap ( ) -= 1 ;
49+ ctx. recursively_seen_refs . remove ( schema_ref_str) ;
5150 }
5251 Ok ( false )
5352 } else {
5453 ctx. recursive_def_refs . add ( schema_ref_pystr) ?;
55- for r in & ctx. def_refs_chain {
54+ for r in & ctx. recursively_seen_refs {
5655 ctx. recursive_def_refs . add ( PyString :: new_bound ( schema_ref. py ( ) , r) ) ?;
5756 }
58- ctx. def_refs_chain . clear ( ) ;
5957 Ok ( true )
6058 }
6159 } else {
@@ -188,8 +186,7 @@ pub struct GatherCtx<'a, 'py> {
188186 pub def_refs : Bound < ' py , PyDict > ,
189187 pub recursive_def_refs : Bound < ' py , PySet > ,
190188 pub discriminators : Bound < ' py , PyList > ,
191- refs_recursion_count : HashMap < String , i32 > ,
192- def_refs_chain : Vec < String > ,
189+ recursively_seen_refs : HashSet < String > ,
193190}
194191
195192impl < ' a , ' py > GatherCtx < ' a , ' py > {
@@ -199,8 +196,7 @@ impl<'a, 'py> GatherCtx<'a, 'py> {
199196 def_refs : PyDict :: new_bound ( definitions. py ( ) ) ,
200197 recursive_def_refs : PySet :: empty_bound ( definitions. py ( ) ) ?,
201198 discriminators : PyList :: empty_bound ( definitions. py ( ) ) ,
202- refs_recursion_count : HashMap :: default ( ) ,
203- def_refs_chain : Vec :: new ( ) ,
199+ recursively_seen_refs : HashSet :: new ( ) ,
204200 } ;
205201 Ok ( ctx)
206202 }
0 commit comments