1+ use std:: collections:: HashMap ;
12use crate :: tools:: py_err;
23use pyo3:: exceptions:: { PyKeyError , PyValueError } ;
34use pyo3:: prelude:: * ;
45use pyo3:: types:: { PyDict , PyList , PySet , PyString , PyTuple } ;
56use pyo3:: { intern, Bound , PyResult } ;
6- use std:: collections:: HashSet ;
77
88const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY : & str = "pydantic.internal.union_discriminator" ;
99
@@ -40,13 +40,13 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
4040 let schema_ref_pystr = schema_ref. downcast_exact :: < PyString > ( ) ?;
4141 let schema_ref_str = schema_ref_pystr. to_str ( ) ?;
4242
43- if ! ctx. seen_refs . contains ( schema_ref_str) {
43+ if * ctx. refs_recursion_count . entry ( schema_ref_str. to_string ( ) ) . or_insert ( 0 ) == 0 {
4444 defaultdict_list_append ! ( & ctx. def_refs, schema_ref_pystr, schema_ref_dict) ;
4545
4646 if let Some ( def) = ctx. definitions_dict . get_item ( schema_ref_pystr) ? {
47- ctx. seen_refs . insert ( schema_ref_str. to_string ( ) ) ;
47+ * ctx. refs_recursion_count . get_mut ( schema_ref_str) . unwrap ( ) += 1 ;
4848 gather_schema ( def. downcast_exact :: < PyDict > ( ) ?, ctx) ?;
49- ctx. seen_refs . remove ( schema_ref_str) ;
49+ * ctx. refs_recursion_count . get_mut ( schema_ref_str) . unwrap ( ) -= 1 ;
5050 }
5151 Ok ( false )
5252 } else {
@@ -183,7 +183,7 @@ pub struct GatherCtx<'a, 'py> {
183183 pub def_refs : Bound < ' py , PyDict > ,
184184 pub recursive_def_refs : Bound < ' py , PySet > ,
185185 pub discriminators : Bound < ' py , PyList > ,
186- seen_refs : HashSet < String > ,
186+ refs_recursion_count : HashMap < String , i32 > ,
187187}
188188
189189impl < ' a , ' py > GatherCtx < ' a , ' py > {
@@ -193,7 +193,7 @@ impl<'a, 'py> GatherCtx<'a, 'py> {
193193 def_refs : PyDict :: new_bound ( definitions. py ( ) ) ,
194194 recursive_def_refs : PySet :: empty_bound ( definitions. py ( ) ) ?,
195195 discriminators : PyList :: empty_bound ( definitions. py ( ) ) ,
196- seen_refs : HashSet :: new ( ) ,
196+ refs_recursion_count : HashMap :: default ( ) ,
197197 } ;
198198 Ok ( ctx)
199199 }
0 commit comments