Skip to content

Commit 10144b0

Browse files
count recursion count
1 parent a6d9868 commit 10144b0

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/schema_traverse.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
use std::collections::HashMap;
12
use crate::tools::py_err;
23
use pyo3::exceptions::{PyKeyError, PyValueError};
34
use pyo3::prelude::*;
45
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
56
use pyo3::{intern, Bound, PyResult};
6-
use std::collections::HashSet;
77

88
const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator";
99

@@ -40,13 +40,13 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
4040
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
4141
let schema_ref_str = schema_ref_pystr.to_str()?;
4242

43-
if !ctx.seen_refs.contains(schema_ref_str) {
43+
if *ctx.refs_recursion_count.entry(schema_ref_str.to_string()).or_insert(0) == 0 {
4444
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
4545

4646
if let Some(def) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
47-
ctx.seen_refs.insert(schema_ref_str.to_string());
47+
*ctx.refs_recursion_count.get_mut(schema_ref_str).unwrap() += 1;
4848
gather_schema(def.downcast_exact::<PyDict>()?, ctx)?;
49-
ctx.seen_refs.remove(schema_ref_str);
49+
*ctx.refs_recursion_count.get_mut(schema_ref_str).unwrap() -= 1;
5050
}
5151
Ok(false)
5252
} else {
@@ -183,7 +183,7 @@ pub struct GatherCtx<'a, 'py> {
183183
pub def_refs: Bound<'py, PyDict>,
184184
pub recursive_def_refs: Bound<'py, PySet>,
185185
pub discriminators: Bound<'py, PyList>,
186-
seen_refs: HashSet<String>,
186+
refs_recursion_count: HashMap<String, i32>,
187187
}
188188

189189
impl<'a, 'py> GatherCtx<'a, 'py> {
@@ -193,7 +193,7 @@ impl<'a, 'py> GatherCtx<'a, 'py> {
193193
def_refs: PyDict::new_bound(definitions.py()),
194194
recursive_def_refs: PySet::empty_bound(definitions.py())?,
195195
discriminators: PyList::empty_bound(definitions.py()),
196-
seen_refs: HashSet::new(),
196+
refs_recursion_count: HashMap::default(),
197197
};
198198
Ok(ctx)
199199
}

0 commit comments

Comments
 (0)