Skip to content

Commit 5c9c9e9

Browse files
Add schema tree node gathering for cleaning in pydantic GenerateSchema
1 parent 061711f commit 5c9c9e9

File tree

6 files changed

+386
-2
lines changed

6 files changed

+386
-2
lines changed

python/pydantic_core/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from ._pydantic_core import (
77
ArgsKwargs,
8+
GatherInvalidDefinitionError,
89
MultiHostUrl,
910
PydanticCustomError,
1011
PydanticKnownError,
@@ -23,11 +24,12 @@
2324
ValidationError,
2425
__version__,
2526
from_json,
27+
gather_schemas_for_cleaning,
2628
to_json,
2729
to_jsonable_python,
2830
validate_core_schema,
2931
)
30-
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType
32+
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, DefinitionReferenceSchema, ErrorType
3133

3234
if _sys.version_info < (3, 11):
3335
from typing_extensions import NotRequired as _NotRequired
@@ -62,11 +64,13 @@
6264
'PydanticUseDefault',
6365
'PydanticSerializationError',
6466
'PydanticSerializationUnexpectedValue',
67+
'GatherInvalidDefinitionError',
6568
'TzInfo',
6669
'to_json',
6770
'from_json',
6871
'to_jsonable_python',
6972
'validate_core_schema',
73+
'gather_schemas_for_cleaning',
7074
]
7175

7276

@@ -137,3 +141,11 @@ class MultiHostHost(_TypedDict):
137141
"""The host part of this host, or `None`."""
138142
port: int | None
139143
"""The port part of this host, or `None`."""
144+
145+
146+
class GatherResult(_TypedDict):
147+
"""Internal result of gathering schemas for cleaning."""
148+
149+
inlinable_def_refs: dict[str, DefinitionReferenceSchema | None]
150+
recursive_refs: set[str]
151+
schemas_with_meta_keys: dict[str, list[CoreSchema]] | None

python/pydantic_core/_pydantic_core.pyi

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from typing import Any, Callable, Generic, Literal, TypeVar, final
55
from _typeshed import SupportsAllComparisons
66
from typing_extensions import LiteralString, Self, TypeAlias
77

8-
from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost
8+
from pydantic_core import ErrorDetails, ErrorTypeInfo, GatherResult, InitErrorDetails, MultiHostHost
99
from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType
1010

1111
__all__ = [
@@ -28,13 +28,15 @@ __all__ = [
2828
'PydanticSerializationUnexpectedValue',
2929
'PydanticUndefined',
3030
'PydanticUndefinedType',
31+
'GatherInvalidDefinitionError',
3132
'Some',
3233
'to_json',
3334
'from_json',
3435
'to_jsonable_python',
3536
'list_all_errors',
3637
'TzInfo',
3738
'validate_core_schema',
39+
'gather_schemas_for_cleaning',
3840
]
3941
__version__: str
4042
build_profile: str
@@ -1011,3 +1013,14 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C
10111013
We may also remove this function altogether, do not rely on it being present if you are
10121014
using pydantic-core directly.
10131015
"""
1016+
@final
1017+
class GatherInvalidDefinitionError(Exception):
1018+
"""Internal error when encountering invalid definition refs"""
1019+
1020+
def gather_schemas_for_cleaning(
1021+
schema: CoreSchema,
1022+
definitions: dict[str, CoreSchema],
1023+
find_meta_with_keys: set[str] | None,
1024+
) -> GatherResult:
1025+
"""Used internally for schema cleaning when schemas are generated.
1026+
Gathers information from the schema tree for the cleaning."""

src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod errors;
2222
mod input;
2323
mod lookup_key;
2424
mod recursion_guard;
25+
mod schema_traverse;
2526
mod serializers;
2627
mod tools;
2728
mod url;
@@ -35,6 +36,7 @@ pub use build_tools::SchemaError;
3536
pub use errors::{
3637
list_all_errors, PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault, ValidationError,
3738
};
39+
pub use schema_traverse::{gather_schemas_for_cleaning, GatherInvalidDefinitionError};
3840
pub use serializers::{
3941
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
4042
WarningsArg,
@@ -129,10 +131,15 @@ fn _pydantic_core(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
129131
m.add_class::<ArgsKwargs>()?;
130132
m.add_class::<SchemaSerializer>()?;
131133
m.add_class::<TzInfo>()?;
134+
m.add(
135+
"GatherInvalidDefinitionError",
136+
py.get_type_bound::<GatherInvalidDefinitionError>(),
137+
)?;
132138
m.add_function(wrap_pyfunction!(to_json, m)?)?;
133139
m.add_function(wrap_pyfunction!(from_json, m)?)?;
134140
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
135141
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
142+
m.add_function(wrap_pyfunction!(gather_schemas_for_cleaning, m)?)?;
136143
m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?;
137144
Ok(())
138145
}

src/schema_traverse.rs

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
use crate::tools::py_err;
2+
use pyo3::exceptions::{PyException, PyKeyError};
3+
use pyo3::prelude::*;
4+
use pyo3::types::{PyDict, PyList, PyNone, PySet, PyString, PyTuple};
5+
use pyo3::{create_exception, intern, Bound, PyResult};
6+
7+
create_exception!(pydantic_core._pydantic_core, GatherInvalidDefinitionError, PyException);
8+
9+
macro_rules! none {
10+
($py: expr) => {
11+
PyNone::get_bound($py)
12+
};
13+
}
14+
15+
macro_rules! get {
16+
($dict: expr, $key: expr) => {
17+
$dict.get_item(intern!($dict.py(), $key))?
18+
};
19+
}
20+
21+
macro_rules! traverse_key_fn {
22+
($key: expr, $func: expr, $dict: expr, $ctx: expr) => {{
23+
if let Some(v) = get!($dict, $key) {
24+
$func(v.downcast_exact()?, $ctx)?
25+
}
26+
}};
27+
}
28+
29+
macro_rules! traverse {
30+
($($key:expr => $func:expr),*; $dict: expr, $ctx: expr) => {{
31+
$(traverse_key_fn!($key, $func, $dict, $ctx);)*
32+
traverse_key_fn!("serialization", gather_schema, $dict, $ctx);
33+
gather_meta($dict, $ctx)
34+
}}
35+
}
36+
37+
macro_rules! defaultdict_list_append {
38+
($dict: expr, $key: expr, $value: expr) => {{
39+
match $dict.get_item($key)? {
40+
None => {
41+
let list = PyList::new_bound($dict.py(), [$value]);
42+
$dict.set_item($key, list)?;
43+
}
44+
// Safety: we know that the value is a PyList as we just created it above
45+
Some(list) => unsafe { list.downcast_unchecked::<PyList>() }.append($value)?,
46+
};
47+
}};
48+
}
49+
50+
fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
51+
let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") else {
52+
return py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref");
53+
};
54+
let schema_ref = schema_ref.downcast_exact::<PyString>()?;
55+
let py = schema_ref_dict.py();
56+
57+
if !ctx.recursively_seen_refs.contains(schema_ref)? {
58+
// Def ref in no longer consider as inlinable if its re-encountered. Then its used multiple times.
59+
// No need to retraverse it either if we already encountered this.
60+
if !ctx.inline_def_ref_candidates.contains(schema_ref)? {
61+
let Some(definition) = ctx.definitions.get_item(schema_ref)? else {
62+
return py_err!(GatherInvalidDefinitionError; "{}", schema_ref.to_str()?);
63+
};
64+
65+
ctx.inline_def_ref_candidates.set_item(schema_ref, schema_ref_dict)?;
66+
ctx.recursively_seen_refs.add(schema_ref)?;
67+
68+
gather_schema(definition.downcast_exact()?, ctx)?;
69+
traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx);
70+
gather_meta(schema_ref_dict, ctx)?;
71+
72+
ctx.recursively_seen_refs.discard(schema_ref)?;
73+
} else {
74+
ctx.inline_def_ref_candidates.set_item(schema_ref, none!(py))?; // Mark not inlinable (used multiple times)
75+
}
76+
} else {
77+
ctx.inline_def_ref_candidates.set_item(schema_ref, none!(py))?; // Mark not inlinable (used in recursion)
78+
ctx.recursive_def_refs.add(schema_ref)?;
79+
for seen_ref in ctx.recursively_seen_refs.iter() {
80+
ctx.inline_def_ref_candidates.set_item(&seen_ref, none!(py))?; // Mark not inlinable (used in recursion)
81+
ctx.recursive_def_refs.add(seen_ref)?;
82+
}
83+
}
84+
Ok(())
85+
}
86+
87+
fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
88+
let Some((res, find_keys)) = &ctx.meta_with_keys else {
89+
return Ok(());
90+
};
91+
let Some(meta) = get!(schema, "metadata") else {
92+
return Ok(());
93+
};
94+
let meta_dict = meta.downcast_exact::<PyDict>()?;
95+
for k in find_keys.iter() {
96+
if meta_dict.contains(&k)? {
97+
defaultdict_list_append!(res, &k, schema);
98+
}
99+
}
100+
Ok(())
101+
}
102+
103+
fn gather_list(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
104+
for v in schema_list.iter() {
105+
gather_schema(v.downcast_exact()?, ctx)?;
106+
}
107+
Ok(())
108+
}
109+
110+
fn gather_dict(schemas_by_key: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
111+
for (_, v) in schemas_by_key.iter() {
112+
gather_schema(v.downcast_exact()?, ctx)?;
113+
}
114+
Ok(())
115+
}
116+
117+
fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
118+
for v in schema_list.iter() {
119+
if let Ok(tup) = v.downcast_exact::<PyTuple>() {
120+
gather_schema(tup.get_item(0)?.downcast_exact()?, ctx)?;
121+
} else {
122+
gather_schema(v.downcast_exact()?, ctx)?;
123+
}
124+
}
125+
Ok(())
126+
}
127+
128+
fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
129+
for v in arguments.iter() {
130+
traverse_key_fn!("schema", gather_schema, v.downcast_exact::<PyDict>()?, ctx);
131+
}
132+
Ok(())
133+
}
134+
135+
// Has 100% coverage in Pydantic side. This is exclusively used there
136+
#[cfg_attr(has_coverage_attribute, coverage(off))]
137+
fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
138+
let Some(type_) = get!(schema, "type") else {
139+
return py_err!(PyKeyError; "Schema type missing");
140+
};
141+
match type_.downcast_exact::<PyString>()?.to_str()? {
142+
"definition-ref" => gather_definition_ref(schema, ctx),
143+
"definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx),
144+
"list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx),
145+
"tuple" => traverse!("items_schema" => gather_list; schema, ctx),
146+
"dict" => traverse!("keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx),
147+
"union" => traverse!("choices" => gather_union_choices; schema, ctx),
148+
"tagged-union" => traverse!("choices" => gather_dict; schema, ctx),
149+
"chain" => traverse!("steps" => gather_list; schema, ctx),
150+
"lax-or-strict" => traverse!("lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx),
151+
"json-or-python" => traverse!("json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx),
152+
"model-fields" | "typed-dict" => traverse!(
153+
"extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx
154+
),
155+
"dataclass-args" => traverse!("computed_fields" => gather_list, "fields" => gather_list; schema, ctx),
156+
"arguments" => traverse!(
157+
"arguments_schema" => gather_arguments,
158+
"var_args_schema" => gather_schema,
159+
"var_kwargs_schema" => gather_schema;
160+
schema, ctx
161+
),
162+
"call" => traverse!("arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx),
163+
"computed-field" | "function-plain" => traverse!("return_schema" => gather_schema; schema, ctx),
164+
"function-wrap" => traverse!("return_schema" => gather_schema, "schema" => gather_schema; schema, ctx),
165+
_ => traverse!("schema" => gather_schema; schema, ctx),
166+
}
167+
}
168+
169+
struct GatherCtx<'a, 'py> {
170+
definitions: &'a Bound<'py, PyDict>,
171+
meta_with_keys: Option<(Bound<'py, PyDict>, &'a Bound<'py, PySet>)>,
172+
inline_def_ref_candidates: Bound<'py, PyDict>,
173+
recursive_def_refs: Bound<'py, PySet>,
174+
recursively_seen_refs: Bound<'py, PySet>,
175+
}
176+
177+
#[pyfunction(signature = (schema, definitions, find_meta_with_keys))]
178+
pub fn gather_schemas_for_cleaning<'py>(
179+
schema: &Bound<'py, PyAny>,
180+
definitions: &Bound<'py, PyAny>,
181+
find_meta_with_keys: &Bound<'py, PyAny>,
182+
) -> PyResult<Bound<'py, PyDict>> {
183+
let py = schema.py();
184+
let mut ctx = GatherCtx {
185+
definitions: definitions.downcast_exact()?,
186+
meta_with_keys: match find_meta_with_keys.is_none() {
187+
true => None,
188+
false => Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::<PySet>()?)),
189+
},
190+
inline_def_ref_candidates: PyDict::new_bound(py),
191+
recursive_def_refs: PySet::empty_bound(py)?,
192+
recursively_seen_refs: PySet::empty_bound(py)?,
193+
};
194+
gather_schema(schema.downcast_exact()?, &mut ctx)?;
195+
196+
let res = PyDict::new_bound(py);
197+
res.set_item(intern!(py, "inlinable_def_refs"), ctx.inline_def_ref_candidates)?;
198+
res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?;
199+
res.set_item(intern!(py, "schemas_with_meta_keys"), ctx.meta_with_keys.map(|v| v.0))?;
200+
Ok(res)
201+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Callable
2+
3+
from pydantic_core import gather_schemas_for_cleaning
4+
5+
from .nested_schema import inlined_schema, schema_using_defs
6+
7+
8+
def test_nested_schema_using_defs(benchmark: Callable[..., None]) -> None:
9+
schema = schema_using_defs()
10+
definitions = {def_schema['ref']: def_schema for def_schema in schema['definitions']}
11+
schema = schema['schema']
12+
benchmark(gather_schemas_for_cleaning, schema, definitions, None)
13+
14+
15+
def test_nested_schema_inlined(benchmark: Callable[..., None]) -> None:
16+
schema = inlined_schema()
17+
benchmark(gather_schemas_for_cleaning, schema, {}, {'some_meta_key'})

0 commit comments

Comments
 (0)