Skip to content

Commit dbd22f0

Browse files
Use HashSet to track
1 parent 6845b7c commit dbd22f0

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

src/schema_traverse.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
33
use pyo3::prelude::*;
44
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
55
use pyo3::{intern, Bound, PyResult};
6-
use std::collections::HashMap;
6+
use std::collections::HashSet;
77

88
const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator";
99

@@ -41,21 +41,19 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
4141
let schema_ref_str = schema_ref_pystr.to_str()?;
4242
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
4343

44-
if *ctx.refs_recursion_count.entry(schema_ref_str.to_string()).or_insert(0) == 0 {
44+
if !ctx.recursively_seen_refs.contains(schema_ref_str) {
45+
// TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
4546
if let Some(def) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
46-
*ctx.refs_recursion_count.get_mut(schema_ref_str).unwrap() += 1;
47-
ctx.def_refs_chain.push(schema_ref_str.to_string());
47+
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
4848
gather_schema(def.downcast_exact::<PyDict>()?, ctx)?;
49-
ctx.def_refs_chain.pop();
50-
*ctx.refs_recursion_count.get_mut(schema_ref_str).unwrap() -= 1;
49+
ctx.recursively_seen_refs.remove(schema_ref_str);
5150
}
5251
Ok(false)
5352
} else {
5453
ctx.recursive_def_refs.add(schema_ref_pystr)?;
55-
for r in &ctx.def_refs_chain {
54+
for r in &ctx.recursively_seen_refs {
5655
ctx.recursive_def_refs.add(PyString::new_bound(schema_ref.py(), r))?;
5756
}
58-
ctx.def_refs_chain.clear();
5957
Ok(true)
6058
}
6159
} else {
@@ -188,8 +186,7 @@ pub struct GatherCtx<'a, 'py> {
188186
pub def_refs: Bound<'py, PyDict>,
189187
pub recursive_def_refs: Bound<'py, PySet>,
190188
pub discriminators: Bound<'py, PyList>,
191-
refs_recursion_count: HashMap<String, i32>,
192-
def_refs_chain: Vec<String>,
189+
recursively_seen_refs: HashSet<String>,
193190
}
194191

195192
impl<'a, 'py> GatherCtx<'a, 'py> {
@@ -199,8 +196,7 @@ impl<'a, 'py> GatherCtx<'a, 'py> {
199196
def_refs: PyDict::new_bound(definitions.py()),
200197
recursive_def_refs: PySet::empty_bound(definitions.py())?,
201198
discriminators: PyList::empty_bound(definitions.py()),
202-
refs_recursion_count: HashMap::default(),
203-
def_refs_chain: Vec::new(),
199+
recursively_seen_refs: HashSet::new(),
204200
};
205201
Ok(ctx)
206202
}

0 commit comments

Comments
 (0)