Skip to content

Commit 6630bda

Browse files
Use PySet
1 parent 31b7491 commit 6630bda

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

src/schema_traverse.rs

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use pyo3::exceptions::{PyException, PyKeyError};
33
use pyo3::prelude::*;
44
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
55
use pyo3::{create_exception, intern, Bound, PyResult};
6-
use std::collections::HashSet;
76

87
create_exception!(pydantic_core._pydantic_core, GatherInvalidDefinitionError, PyException);
98

@@ -46,27 +45,25 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
4645
let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") else {
4746
return py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref");
4847
};
49-
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
50-
let schema_ref_str = schema_ref_pystr.to_str()?;
48+
let schema_ref = schema_ref.downcast_exact::<PyString>()?;
5149

52-
if !ctx.recursively_seen_refs.contains(schema_ref_str) {
53-
let Some(definition) = ctx.definitions.get_item(schema_ref_pystr)? else {
54-
return py_err!(GatherInvalidDefinitionError; "{}", schema_ref_str);
50+
if !ctx.recursively_seen_refs.contains(schema_ref)? {
51+
let Some(definition) = ctx.definitions.get_item(schema_ref)? else {
52+
return py_err!(GatherInvalidDefinitionError; "{}", schema_ref.to_str()?);
5553
};
56-
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
54+
defaultdict_list_append!(&ctx.def_refs, schema_ref, schema_ref_dict);
5755

58-
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
56+
ctx.recursively_seen_refs.add(schema_ref)?;
5957

6058
gather_schema(definition.downcast_exact()?, ctx)?;
6159
traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx);
6260
gather_meta(schema_ref_dict, ctx)?;
6361

64-
ctx.recursively_seen_refs.remove(schema_ref_str);
62+
ctx.recursively_seen_refs.discard(schema_ref)?;
6563
} else {
66-
ctx.recursive_def_refs.add(schema_ref_pystr)?;
67-
for seen_ref in &ctx.recursively_seen_refs {
68-
let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref);
69-
ctx.recursive_def_refs.add(seen_ref_pystr)?;
64+
ctx.recursive_def_refs.add(schema_ref)?;
65+
for seen_ref in ctx.recursively_seen_refs.iter() {
66+
ctx.recursive_def_refs.add(seen_ref)?;
7067
}
7168
}
7269
Ok(())
@@ -159,7 +156,7 @@ struct GatherCtx<'a, 'py> {
159156
meta_with_keys: Option<(Bound<'py, PyDict>, &'a Bound<'py, PySet>)>,
160157
def_refs: Bound<'py, PyDict>,
161158
recursive_def_refs: Bound<'py, PySet>,
162-
recursively_seen_refs: HashSet<String>,
159+
recursively_seen_refs: Bound<'py, PySet>,
163160
}
164161

165162
#[pyfunction(signature = (schema, definitions, find_meta_with_keys))]
@@ -177,7 +174,7 @@ pub fn gather_schemas_for_cleaning<'py>(
177174
},
178175
def_refs: PyDict::new_bound(py),
179176
recursive_def_refs: PySet::empty_bound(py)?,
180-
recursively_seen_refs: HashSet::new(),
177+
recursively_seen_refs: PySet::empty_bound(py)?,
181178
};
182179
gather_schema(schema.downcast_exact()?, &mut ctx)?;
183180

0 commit comments

Comments
 (0)