Skip to content

Commit 5dc5b6d

Browse files
Do not gather dupe ref instances
1 parent 6630bda commit 5dc5b6d

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

src/schema_traverse.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use pyo3::exceptions::{PyException, PyKeyError};
33
use pyo3::prelude::*;
44
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
55
use pyo3::{create_exception, intern, Bound, PyResult};
6+
use std::collections::HashSet;
67

78
create_exception!(pydantic_core._pydantic_core, GatherInvalidDefinitionError, PyException);
89

@@ -48,18 +49,23 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
4849
let schema_ref = schema_ref.downcast_exact::<PyString>()?;
4950

5051
if !ctx.recursively_seen_refs.contains(schema_ref)? {
51-
let Some(definition) = ctx.definitions.get_item(schema_ref)? else {
52-
return py_err!(GatherInvalidDefinitionError; "{}", schema_ref.to_str()?);
53-
};
54-
defaultdict_list_append!(&ctx.def_refs, schema_ref, schema_ref_dict);
52+
// Check if we already gathered this definition ref instance.
53+
// There is no need to process the same def ref instance again if it has already been gathered.
54+
if ctx.seen_ref_instances.insert(schema_ref_dict.as_ptr() as isize) {
55+
let Some(definition) = ctx.definitions.get_item(schema_ref)? else {
56+
return py_err!(GatherInvalidDefinitionError; "{}", schema_ref.to_str()?);
57+
};
58+
59+
defaultdict_list_append!(&ctx.def_refs, schema_ref, schema_ref_dict);
5560

56-
ctx.recursively_seen_refs.add(schema_ref)?;
61+
ctx.recursively_seen_refs.add(schema_ref)?;
5762

58-
gather_schema(definition.downcast_exact()?, ctx)?;
59-
traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx);
60-
gather_meta(schema_ref_dict, ctx)?;
63+
gather_schema(definition.downcast_exact()?, ctx)?;
64+
traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx);
65+
gather_meta(schema_ref_dict, ctx)?;
6166

62-
ctx.recursively_seen_refs.discard(schema_ref)?;
67+
ctx.recursively_seen_refs.discard(schema_ref)?;
68+
}
6369
} else {
6470
ctx.recursive_def_refs.add(schema_ref)?;
6571
for seen_ref in ctx.recursively_seen_refs.iter() {
@@ -157,6 +163,7 @@ struct GatherCtx<'a, 'py> {
157163
def_refs: Bound<'py, PyDict>,
158164
recursive_def_refs: Bound<'py, PySet>,
159165
recursively_seen_refs: Bound<'py, PySet>,
166+
seen_ref_instances: HashSet<isize>,
160167
}
161168

162169
#[pyfunction(signature = (schema, definitions, find_meta_with_keys))]
@@ -175,6 +182,7 @@ pub fn gather_schemas_for_cleaning<'py>(
175182
def_refs: PyDict::new_bound(py),
176183
recursive_def_refs: PySet::empty_bound(py)?,
177184
recursively_seen_refs: PySet::empty_bound(py)?,
185+
seen_ref_instances: HashSet::new(),
178186
};
179187
gather_schema(schema.downcast_exact()?, &mut ctx)?;
180188

tests/test_gather_schemas_for_cleaning.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,22 @@ def test_unknown_ref():
121121
schema = core_schema.tuple_schema([core_schema.int_schema(), ref1])
122122
with pytest.raises(GatherInvalidDefinitionError, match='ref1'):
123123
gather_schemas_for_cleaning(schema, definitions={}, find_meta_with_keys=None)
124+
125+
126+
def test_no_duplicate_ref_instances_gathered():
127+
schema1 = core_schema.tuple_schema([core_schema.str_schema(), core_schema.int_schema()])
128+
schema2 = core_schema.tuple_schema(
129+
[core_schema.definition_reference_schema('ref1'), core_schema.definition_reference_schema('ref1')]
130+
)
131+
schema3 = core_schema.tuple_schema(
132+
[core_schema.definition_reference_schema('ref2'), core_schema.definition_reference_schema('ref2')]
133+
)
134+
definitions = {'ref1': schema1, 'ref2': schema2}
135+
136+
res = gather_schemas_for_cleaning(schema3, definitions=definitions, find_meta_with_keys=None)
137+
assert res['definition_refs'] == {
138+
'ref1': [schema2['items_schema'][0], schema2['items_schema'][1]],
139+
'ref2': [schema3['items_schema'][0], schema3['items_schema'][1]],
140+
}
141+
assert res['recursive_refs'] == set()
142+
assert res['schemas_with_meta_keys'] is None

0 commit comments

Comments
 (0)