Skip to content

Commit 2b87d5b

Browse files
Clean up. Use shorthand traversal helper
1 parent 3e63e0b commit 2b87d5b

File tree

1 file changed

+68
-103
lines changed

1 file changed

+68
-103
lines changed

src/schema_traverse.rs

Lines changed: 68 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,27 @@ use std::collections::HashSet;
88
const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator";
99

1010
macro_rules! get {
11-
($dict: expr, $key: expr) => {{
11+
($dict: expr, $key: expr) => {
1212
$dict.get_item(intern!($dict.py(), $key))?
13-
}};
13+
};
1414
}
1515

16-
macro_rules! traverse {
17-
($func: expr, $key: expr, $dict: expr, $ctx: expr) => {{
18-
if let Some(v) = $dict.get_item(intern!($dict.py(), $key))? {
16+
macro_rules! traverse_key_fn {
17+
($key: expr, $func: expr, $dict: expr, $ctx: expr) => {{
18+
if let Some(v) = get!($dict, $key) {
1919
$func(v.downcast_exact()?, $ctx)?
2020
}
2121
}};
2222
}
2323

24+
macro_rules! traverse {
25+
($($key:expr => $func:expr),*; $dict: expr, $ctx: expr) => {{
26+
$(traverse_key_fn!($key, $func, $dict, $ctx);)*
27+
gather_serialization($dict, $ctx)?;
28+
gather_meta($dict, $ctx)?;
29+
}}
30+
}
31+
2432
macro_rules! defaultdict_list_append {
2533
($dict: expr, $key: expr, $value: expr) => {{
2634
match $dict.get_item($key)? {
@@ -35,53 +43,51 @@ macro_rules! defaultdict_list_append {
3543
}};
3644
}
3745

38-
fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<bool> {
46+
fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
3947
if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") {
4048
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
4149
let schema_ref_str = schema_ref_pystr.to_str()?;
4250
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
4351

4452
if !ctx.recursively_seen_refs.contains(schema_ref_str) {
4553
// TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
46-
if let Some(def) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
54+
if let Some(definition) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
4755
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
48-
gather_schema(def.downcast_exact::<PyDict>()?, ctx)?;
56+
57+
gather_schema(definition.downcast_exact::<PyDict>()?, ctx)?;
58+
gather_serialization(schema_ref_dict, ctx)?;
59+
gather_meta(schema_ref_dict, ctx)?;
60+
4961
ctx.recursively_seen_refs.remove(schema_ref_str);
5062
}
51-
Ok(false)
5263
} else {
5364
ctx.recursive_def_refs.add(schema_ref_pystr)?;
54-
for r in &ctx.recursively_seen_refs {
55-
ctx.recursive_def_refs.add(PyString::new_bound(schema_ref.py(), r))?;
65+
for seen_ref in &ctx.recursively_seen_refs {
66+
let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref);
67+
ctx.recursive_def_refs.add(seen_ref_pystr)?;
5668
}
57-
Ok(true)
5869
}
70+
Ok(())
5971
} else {
60-
py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref")?
72+
py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref")
6173
}
6274
}
6375

64-
fn gather_meta(schema: &Bound<'_, PyDict>, meta_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
65-
if let Some(discriminator) = get!(meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY) {
66-
let schema_discriminator = PyTuple::new_bound(schema.py(), vec![schema.as_any(), &discriminator]);
67-
ctx.discriminators.append(schema_discriminator)?;
76+
fn gather_serialization(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
77+
if let Some(ser) = get!(schema, "serialization") {
78+
let ser_dict = ser.downcast_exact::<PyDict>()?;
79+
traverse!("schema" => gather_schema, "return_schema" => gather_schema; ser_dict, ctx);
6880
}
6981
Ok(())
7082
}
7183

72-
fn gather_node(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
73-
let type_ = get!(schema, "type");
74-
if type_.is_none() {
75-
return py_err!(PyValueError; "Schema type missing");
76-
}
77-
if type_.unwrap().downcast_exact::<PyString>()?.to_str()? == "definition-ref" {
78-
let recursive = gather_definition_ref(schema, ctx)?;
79-
if recursive {
80-
return Ok(());
81-
}
82-
}
84+
fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
8385
if let Some(meta) = get!(schema, "metadata") {
84-
gather_meta(schema, meta.downcast_exact()?, ctx)?;
86+
let meta_dict = meta.downcast_exact::<PyDict>()?;
87+
if let Some(discriminator) = get!(meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY) {
88+
let schema_discriminator = PyTuple::new_bound(schema.py(), vec![schema.as_any(), &discriminator]);
89+
ctx.discriminators.append(schema_discriminator)?;
90+
}
8591
}
8692
Ok(())
8793
}
@@ -113,71 +119,42 @@ fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) ->
113119

114120
fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
115121
for v in arguments.iter() {
116-
if let Some(schema) = get!(v.downcast_exact::<PyDict>()?, "schema") {
117-
gather_schema(schema.downcast_exact()?, ctx)?;
118-
}
122+
traverse_key_fn!("schema", gather_schema, v.downcast_exact::<PyDict>()?, ctx);
119123
}
120124
Ok(())
121125
}
122126

123-
fn traverse_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
127+
fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
124128
let type_ = get!(schema, "type");
125129
if type_.is_none() {
126130
return py_err!(PyValueError; "Schema type missing");
127131
}
128132
match type_.unwrap().downcast_exact::<PyString>()?.to_str()? {
129-
"definitions" => {
130-
traverse!(gather_schema, "schema", schema, ctx);
131-
traverse!(gather_list, "definitions", schema, ctx);
132-
}
133-
"list" | "set" | "frozenset" | "generator" => traverse!(gather_schema, "items_schema", schema, ctx),
134-
"tuple" => traverse!(gather_list, "items_schema", schema, ctx),
135-
"dict" => {
136-
traverse!(gather_schema, "keys_schema", schema, ctx);
137-
traverse!(gather_schema, "values_schema", schema, ctx);
138-
}
139-
"union" => traverse!(gather_union_choices, "choices", schema, ctx),
140-
"tagged-union" => traverse!(gather_dict, "choices", schema, ctx),
141-
"chain" => traverse!(gather_list, "steps", schema, ctx),
142-
"lax-or-strict" => {
143-
traverse!(gather_schema, "lax_schema", schema, ctx);
144-
traverse!(gather_schema, "strict_schema", schema, ctx);
145-
}
146-
"json-or-python" => {
147-
traverse!(gather_schema, "json_schema", schema, ctx);
148-
traverse!(gather_schema, "python_schema", schema, ctx);
149-
}
150-
"model-fields" | "typed-dict" => {
151-
traverse!(gather_schema, "extras_schema", schema, ctx);
152-
traverse!(gather_list, "computed_fields", schema, ctx);
153-
traverse!(gather_dict, "fields", schema, ctx);
154-
}
155-
"dataclass-args" => {
156-
traverse!(gather_list, "computed_fields", schema, ctx);
157-
traverse!(gather_list, "fields", schema, ctx);
158-
}
159-
"arguments" => {
160-
traverse!(gather_arguments, "arguments_schema", schema, ctx);
161-
traverse!(gather_schema, "var_args_schema", schema, ctx);
162-
traverse!(gather_schema, "var_kwargs_schema", schema, ctx);
163-
}
164-
"computed-field" | "function-plain" => traverse!(gather_schema, "return_schema", schema, ctx),
165-
"function-wrap" => {
166-
traverse!(gather_schema, "return_schema", schema, ctx);
167-
traverse!(gather_schema, "schema", schema, ctx);
168-
}
169-
"call" => {
170-
traverse!(gather_schema, "arguments_schema", schema, ctx);
171-
traverse!(gather_schema, "return_schema", schema, ctx);
172-
}
173-
_ => traverse!(gather_schema, "schema", schema, ctx),
133+
"definition-ref" => gather_definition_ref(schema, ctx)?,
134+
"definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx),
135+
"list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx),
136+
"tuple" => traverse!("items_schema" => gather_list; schema, ctx),
137+
"dict" => traverse!("keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx),
138+
"union" => traverse!("choices" => gather_union_choices; schema, ctx),
139+
"tagged-union" => traverse!("choices" => gather_dict; schema, ctx),
140+
"chain" => traverse!("steps" => gather_list; schema, ctx),
141+
"lax-or-strict" => traverse!("lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx),
142+
"json-or-python" => traverse!("json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx),
143+
"model-fields" | "typed-dict" => traverse!(
144+
"extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx
145+
),
146+
"dataclass-args" => traverse!("computed_fields" => gather_list, "fields" => gather_list; schema, ctx),
147+
"arguments" => traverse!(
148+
"arguments_schema" => gather_arguments,
149+
"var_args_schema" => gather_schema,
150+
"var_kwargs_schema" => gather_schema;
151+
schema, ctx
152+
),
153+
"call" => traverse!("arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx),
154+
"computed-field" | "function-plain" => traverse!("return_schema" => gather_schema; schema, ctx),
155+
"function-wrap" => traverse!("return_schema" => gather_schema, "schema" => gather_schema; schema, ctx),
156+
_ => traverse!("schema" => gather_schema; schema, ctx),
174157
};
175-
176-
if let Some(ser) = get!(schema, "serialization") {
177-
let ser_dict = ser.downcast_exact::<PyDict>()?;
178-
traverse!(gather_schema, "schema", ser_dict, ctx);
179-
traverse!(gather_schema, "return_schema", ser_dict, ctx);
180-
}
181158
Ok(())
182159
}
183160

@@ -189,24 +166,6 @@ pub struct GatherCtx<'a, 'py> {
189166
recursively_seen_refs: HashSet<String>,
190167
}
191168

192-
impl<'a, 'py> GatherCtx<'a, 'py> {
193-
pub fn new(definitions: &'a Bound<'py, PyDict>) -> PyResult<Self> {
194-
let ctx = GatherCtx {
195-
definitions_dict: definitions,
196-
def_refs: PyDict::new_bound(definitions.py()),
197-
recursive_def_refs: PySet::empty_bound(definitions.py())?,
198-
discriminators: PyList::empty_bound(definitions.py()),
199-
recursively_seen_refs: HashSet::new(),
200-
};
201-
Ok(ctx)
202-
}
203-
}
204-
205-
fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
206-
traverse_schema(schema, ctx)?;
207-
gather_node(schema, ctx)
208-
}
209-
210169
#[pyfunction(signature = (schema, definitions))]
211170
pub fn gather_schemas_for_cleaning<'py>(
212171
schema: &Bound<'py, PyAny>,
@@ -215,7 +174,13 @@ pub fn gather_schemas_for_cleaning<'py>(
215174
let py = schema.py();
216175
let schema_dict = schema.downcast_exact::<PyDict>()?;
217176

218-
let mut ctx = GatherCtx::new(definitions.downcast_exact()?)?;
177+
let mut ctx = GatherCtx {
178+
definitions_dict: definitions.downcast_exact()?,
179+
def_refs: PyDict::new_bound(definitions.py()),
180+
recursive_def_refs: PySet::empty_bound(definitions.py())?,
181+
discriminators: PyList::empty_bound(definitions.py()),
182+
recursively_seen_refs: HashSet::new(),
183+
};
219184
gather_schema(schema_dict, &mut ctx)?;
220185

221186
let res = PyDict::new_bound(py);

0 commit comments

Comments
 (0)