|
1 | 1 | use crate::tools::py_err; |
2 | | -use pyo3::exceptions::PyKeyError; |
| 2 | +use pyo3::exceptions::{PyException, PyKeyError}; |
3 | 3 | use pyo3::prelude::*; |
4 | 4 | use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple}; |
5 | | -use pyo3::{intern, Bound, PyResult}; |
| 5 | +use pyo3::{create_exception, intern, Bound, PyResult}; |
6 | 6 | use std::collections::HashSet; |
7 | 7 |
|
| 8 | +create_exception!(pydantic_core._pydantic_core, GatherInvalidDefinitionError, PyException); |
| 9 | + |
8 | 10 | macro_rules! get { |
9 | 11 | ($dict: expr, $key: expr) => { |
10 | 12 | $dict.get_item(intern!($dict.py(), $key))? |
@@ -42,44 +44,47 @@ macro_rules! defaultdict_list_append { |
42 | 44 | } |
43 | 45 |
|
44 | 46 | fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { |
45 | | - if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") { |
46 | | - let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?; |
47 | | - let schema_ref_str = schema_ref_pystr.to_str()?; |
| 47 | + let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") else { |
| 48 | + return py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref"); |
| 49 | + }; |
| 50 | + let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?; |
| 51 | + let schema_ref_str = schema_ref_pystr.to_str()?; |
48 | 52 |
|
49 | | - if !ctx.recursively_seen_refs.contains(schema_ref_str) { |
50 | | - defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict); |
| 53 | + if !ctx.recursively_seen_refs.contains(schema_ref_str) { |
| 54 | + defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict); |
51 | 55 |
|
52 | | - // TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side |
53 | | - if let Some(definition) = ctx.definitions.get_item(schema_ref_pystr)? { |
54 | | - ctx.recursively_seen_refs.insert(schema_ref_str.to_string()); |
| 56 | + let Some(definition) = ctx.definitions.get_item(schema_ref_pystr)? else { |
| 57 | + return py_err!(GatherInvalidDefinitionError; "Unknown schema_ref: {}", schema_ref_str); |
| 58 | + }; |
55 | 59 |
|
56 | | - gather_schema(definition.downcast_exact::<PyDict>()?, ctx)?; |
57 | | - traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx); |
58 | | - gather_meta(schema_ref_dict, ctx)?; |
| 60 | + ctx.recursively_seen_refs.insert(schema_ref_str.to_string()); |
59 | 61 |
|
60 | | - ctx.recursively_seen_refs.remove(schema_ref_str); |
61 | | - } |
62 | | - } else { |
63 | | - ctx.recursive_def_refs.add(schema_ref_pystr)?; |
64 | | - for seen_ref in &ctx.recursively_seen_refs { |
65 | | - let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref); |
66 | | - ctx.recursive_def_refs.add(seen_ref_pystr)?; |
67 | | - } |
68 | | - } |
69 | | - Ok(()) |
| 62 | + gather_schema(definition.downcast_exact()?, ctx)?; |
| 63 | + traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx); |
| 64 | + gather_meta(schema_ref_dict, ctx)?; |
| 65 | + |
| 66 | + ctx.recursively_seen_refs.remove(schema_ref_str); |
70 | 67 | } else { |
71 | | - py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref") |
| 68 | + ctx.recursive_def_refs.add(schema_ref_pystr)?; |
| 69 | + for seen_ref in &ctx.recursively_seen_refs { |
| 70 | + let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref); |
| 71 | + ctx.recursive_def_refs.add(seen_ref_pystr)?; |
| 72 | + } |
72 | 73 | } |
| 74 | + Ok(()) |
73 | 75 | } |
74 | 76 |
|
75 | 77 | fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { |
76 | | - if let Some((res, find_keys)) = &ctx.meta_with_keys { |
77 | | - if let Some(meta) = get!(schema, "metadata") { |
78 | | - for (k, _) in meta.downcast_exact::<PyDict>()?.iter() { |
79 | | - if find_keys.contains(&k)? { |
80 | | - defaultdict_list_append!(res, &k, schema); |
81 | | - } |
82 | | - } |
| 78 | + let Some((res, find_keys)) = &ctx.meta_with_keys else { |
| 79 | + return Ok(()); |
| 80 | + }; |
| 81 | + let Some(meta) = get!(schema, "metadata") else { |
| 82 | + return Ok(()); |
| 83 | + }; |
| 84 | + let meta_dict = meta.downcast_exact::<PyDict>()?; |
| 85 | + for k in find_keys.iter() { |
| 86 | + if meta_dict.contains(&k)? { |
| 87 | + defaultdict_list_append!(res, &k, schema); |
83 | 88 | } |
84 | 89 | } |
85 | 90 | Ok(()) |
@@ -120,11 +125,10 @@ fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyRes |
120 | 125 | // Has 100% coverage in Pydantic side. This is exclusively used there |
121 | 126 | #[cfg_attr(has_coverage_attribute, coverage(off))] |
122 | 127 | fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { |
123 | | - let type_ = get!(schema, "type"); |
124 | | - if type_.is_none() { |
| 128 | + let Some(type_) = get!(schema, "type") else { |
125 | 129 | return py_err!(PyKeyError; "Schema type missing"); |
126 | | - } |
127 | | - match type_.unwrap().downcast_exact::<PyString>()?.to_str()? { |
| 130 | + }; |
| 131 | + match type_.downcast_exact::<PyString>()?.to_str()? { |
128 | 132 | "definition-ref" => gather_definition_ref(schema, ctx), |
129 | 133 | "definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx), |
130 | 134 | "list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx), |
@@ -167,14 +171,12 @@ pub fn gather_schemas_for_cleaning<'py>( |
167 | 171 | find_meta_with_keys: &Bound<'py, PyAny>, |
168 | 172 | ) -> PyResult<Bound<'py, PyDict>> { |
169 | 173 | let py = schema.py(); |
170 | | - let meta_with_keys = if find_meta_with_keys.is_none() { |
171 | | - None |
172 | | - } else { |
173 | | - Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::<PySet>()?)) |
174 | | - }; |
175 | 174 | let mut ctx = GatherCtx { |
176 | 175 | definitions: definitions.downcast_exact()?, |
177 | | - meta_with_keys, |
| 176 | + meta_with_keys: match find_meta_with_keys.is_none() { |
| 177 | + true => None, |
| 178 | + false => Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::<PySet>()?)), |
| 179 | + }, |
178 | 180 | def_refs: PyDict::new_bound(py), |
179 | 181 | recursive_def_refs: PySet::empty_bound(py)?, |
180 | 182 | recursively_seen_refs: HashSet::new(), |
|
0 commit comments