Skip to content

Commit e518d29

Browse files
Add exception for unknown schema_ref
1 parent 2812533 commit e518d29

File tree

5 files changed

+64
-43
lines changed

5 files changed

+64
-43
lines changed

python/pydantic_core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from ._pydantic_core import (
77
ArgsKwargs,
8+
GatherInvalidDefinitionError,
89
MultiHostUrl,
910
PydanticCustomError,
1011
PydanticKnownError,
@@ -63,6 +64,7 @@
6364
'PydanticUseDefault',
6465
'PydanticSerializationError',
6566
'PydanticSerializationUnexpectedValue',
67+
'GatherInvalidDefinitionError',
6668
'TzInfo',
6769
'to_json',
6870
'from_json',

python/pydantic_core/_pydantic_core.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ __all__ = [
2828
'PydanticSerializationUnexpectedValue',
2929
'PydanticUndefined',
3030
'PydanticUndefinedType',
31+
'GatherInvalidDefinitionError',
3132
'Some',
3233
'to_json',
3334
'from_json',
@@ -1165,6 +1166,9 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C
11651166
We may also remove this function altogether, do not rely on it being present if you are
11661167
using pydantic-core directly.
11671168
"""
1169+
@final
1170+
class GatherInvalidDefinitionError(Exception):
1171+
"""Internal error raised by `gather_schemas_for_cleaning` when a definition-ref's `schema_ref` is not found in the definitions."""
11681172

11691173
def gather_schemas_for_cleaning(
11701174
schema: CoreSchema,

src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ pub use build_tools::SchemaError;
3636
pub use errors::{
3737
list_all_errors, PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault, ValidationError,
3838
};
39-
pub use schema_traverse::gather_schemas_for_cleaning;
39+
pub use schema_traverse::{gather_schemas_for_cleaning, GatherInvalidDefinitionError};
4040
pub use serializers::{
4141
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
4242
WarningsArg,
@@ -131,6 +131,10 @@ fn _pydantic_core(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
131131
m.add_class::<ArgsKwargs>()?;
132132
m.add_class::<SchemaSerializer>()?;
133133
m.add_class::<TzInfo>()?;
134+
m.add(
135+
"GatherInvalidDefinitionError",
136+
py.get_type_bound::<GatherInvalidDefinitionError>(),
137+
)?;
134138
m.add_function(wrap_pyfunction!(to_json, m)?)?;
135139
m.add_function(wrap_pyfunction!(from_json, m)?)?;
136140
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;

src/schema_traverse.rs

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use crate::tools::py_err;
2-
use pyo3::exceptions::PyKeyError;
2+
use pyo3::exceptions::{PyException, PyKeyError};
33
use pyo3::prelude::*;
44
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
5-
use pyo3::{intern, Bound, PyResult};
5+
use pyo3::{create_exception, intern, Bound, PyResult};
66
use std::collections::HashSet;
77

8+
create_exception!(pydantic_core._pydantic_core, GatherInvalidDefinitionError, PyException);
9+
810
macro_rules! get {
911
($dict: expr, $key: expr) => {
1012
$dict.get_item(intern!($dict.py(), $key))?
@@ -42,44 +44,47 @@ macro_rules! defaultdict_list_append {
4244
}
4345

4446
fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
45-
if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") {
46-
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
47-
let schema_ref_str = schema_ref_pystr.to_str()?;
47+
let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") else {
48+
return py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref");
49+
};
50+
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
51+
let schema_ref_str = schema_ref_pystr.to_str()?;
4852

49-
if !ctx.recursively_seen_refs.contains(schema_ref_str) {
50-
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
53+
if !ctx.recursively_seen_refs.contains(schema_ref_str) {
54+
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
5155

52-
// TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
53-
if let Some(definition) = ctx.definitions.get_item(schema_ref_pystr)? {
54-
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
56+
let Some(definition) = ctx.definitions.get_item(schema_ref_pystr)? else {
57+
return py_err!(GatherInvalidDefinitionError; "Unknown schema_ref: {}", schema_ref_str);
58+
};
5559

56-
gather_schema(definition.downcast_exact::<PyDict>()?, ctx)?;
57-
traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx);
58-
gather_meta(schema_ref_dict, ctx)?;
60+
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
5961

60-
ctx.recursively_seen_refs.remove(schema_ref_str);
61-
}
62-
} else {
63-
ctx.recursive_def_refs.add(schema_ref_pystr)?;
64-
for seen_ref in &ctx.recursively_seen_refs {
65-
let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref);
66-
ctx.recursive_def_refs.add(seen_ref_pystr)?;
67-
}
68-
}
69-
Ok(())
62+
gather_schema(definition.downcast_exact()?, ctx)?;
63+
traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx);
64+
gather_meta(schema_ref_dict, ctx)?;
65+
66+
ctx.recursively_seen_refs.remove(schema_ref_str);
7067
} else {
71-
py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref")
68+
ctx.recursive_def_refs.add(schema_ref_pystr)?;
69+
for seen_ref in &ctx.recursively_seen_refs {
70+
let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref);
71+
ctx.recursive_def_refs.add(seen_ref_pystr)?;
72+
}
7273
}
74+
Ok(())
7375
}
7476

7577
fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
76-
if let Some((res, find_keys)) = &ctx.meta_with_keys {
77-
if let Some(meta) = get!(schema, "metadata") {
78-
for (k, _) in meta.downcast_exact::<PyDict>()?.iter() {
79-
if find_keys.contains(&k)? {
80-
defaultdict_list_append!(res, &k, schema);
81-
}
82-
}
78+
let Some((res, find_keys)) = &ctx.meta_with_keys else {
79+
return Ok(());
80+
};
81+
let Some(meta) = get!(schema, "metadata") else {
82+
return Ok(());
83+
};
84+
let meta_dict = meta.downcast_exact::<PyDict>()?;
85+
for k in find_keys.iter() {
86+
if meta_dict.contains(&k)? {
87+
defaultdict_list_append!(res, &k, schema);
8388
}
8489
}
8590
Ok(())
@@ -120,11 +125,10 @@ fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyRes
120125
// Has 100% coverage on the Pydantic side; this function is used exclusively there.
121126
#[cfg_attr(has_coverage_attribute, coverage(off))]
122127
fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
123-
let type_ = get!(schema, "type");
124-
if type_.is_none() {
128+
let Some(type_) = get!(schema, "type") else {
125129
return py_err!(PyKeyError; "Schema type missing");
126-
}
127-
match type_.unwrap().downcast_exact::<PyString>()?.to_str()? {
130+
};
131+
match type_.downcast_exact::<PyString>()?.to_str()? {
128132
"definition-ref" => gather_definition_ref(schema, ctx),
129133
"definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx),
130134
"list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx),
@@ -167,14 +171,12 @@ pub fn gather_schemas_for_cleaning<'py>(
167171
find_meta_with_keys: &Bound<'py, PyAny>,
168172
) -> PyResult<Bound<'py, PyDict>> {
169173
let py = schema.py();
170-
let meta_with_keys = if find_meta_with_keys.is_none() {
171-
None
172-
} else {
173-
Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::<PySet>()?))
174-
};
175174
let mut ctx = GatherCtx {
176175
definitions: definitions.downcast_exact()?,
177-
meta_with_keys,
176+
meta_with_keys: match find_meta_with_keys.is_none() {
177+
true => None,
178+
false => Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::<PySet>()?)),
179+
},
178180
def_refs: PyDict::new_bound(py),
179181
recursive_def_refs: PySet::empty_bound(py)?,
180182
recursively_seen_refs: HashSet::new(),

tests/test_gather_schemas_for_cleaning.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from pydantic_core import core_schema, gather_schemas_for_cleaning
1+
import pytest
2+
3+
from pydantic_core import GatherInvalidDefinitionError, core_schema, gather_schemas_for_cleaning
24

35

46
def test_no_refs():
@@ -112,3 +114,10 @@ class Model:
112114
assert res['schemas_with_meta_keys']['find_meta1'][0] is field1
113115
assert res['schemas_with_meta_keys']['find_meta1'][1] is field2
114116
assert res['schemas_with_meta_keys']['find_meta2'][0] is field2
117+
118+
119+
def test_unknown_ref():
120+
ref1 = core_schema.definition_reference_schema('ref1')
121+
schema = core_schema.tuple_schema([core_schema.int_schema(), ref1])
122+
with pytest.raises(GatherInvalidDefinitionError, match='Unknown schema_ref: ref1'):
123+
gather_schemas_for_cleaning(schema, definitions={}, find_meta_with_keys=None)

0 commit comments

Comments
 (0)