Skip to content

Commit 5b85f8f

Browse files
Add schema cleaning info gathering
1 parent 92a259e commit 5b85f8f

File tree

5 files changed

+244
-3
lines changed

5 files changed

+244
-3
lines changed

python/pydantic_core/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sys as _sys
44
from typing import Any as _Any
5+
from typing import NamedTuple
56

67
from ._pydantic_core import (
78
ArgsKwargs,
@@ -23,11 +24,12 @@
2324
ValidationError,
2425
__version__,
2526
from_json,
27+
gather_schemas_for_cleaning,
2628
to_json,
2729
to_jsonable_python,
2830
validate_core_schema,
2931
)
30-
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType
32+
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, DefinitionReferenceSchema, ErrorType
3133

3234
if _sys.version_info < (3, 11):
3335
from typing_extensions import NotRequired as _NotRequired
@@ -67,6 +69,7 @@
6769
'from_json',
6870
'to_jsonable_python',
6971
'validate_core_schema',
72+
'gather_schemas_for_cleaning',
7073
]
7174

7275

@@ -137,3 +140,9 @@ class MultiHostHost(_TypedDict):
137140
"""The host part of this host, or `None`."""
138141
port: int | None
139142
"""The port part of this host, or `None`."""
143+
144+
145+
class GatherResult(NamedTuple):
146+
definition_refs: dict[str, list[DefinitionReferenceSchema]]
147+
recursive_refs: set[str]
148+
deferred_discriminators: list[tuple[CoreSchema, _Any]]

python/pydantic_core/_pydantic_core.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from typing import Any, Callable, Generic, Literal, TypeVar, final
55
from _typeshed import SupportsAllComparisons
66
from typing_extensions import LiteralString, Self, TypeAlias
77

8-
from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost
8+
from pydantic_core import ErrorDetails, ErrorTypeInfo, GatherResult, InitErrorDetails, MultiHostHost
99
from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType
1010

1111
__all__ = [
@@ -35,6 +35,7 @@ __all__ = [
3535
'list_all_errors',
3636
'TzInfo',
3737
'validate_core_schema',
38+
'gather_schemas_for_cleaning',
3839
]
3940
__version__: str
4041
build_profile: str
@@ -1164,3 +1165,7 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C
11641165
We may also remove this function altogether, do not rely on it being present if you are
11651166
using pydantic-core directly.
11661167
"""
1168+
1169+
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
1170+
"""Used internally for schema cleaning when schemas are generated.
1171+
Gathers information from the schema tree for the cleaning."""

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod errors;
2222
mod input;
2323
mod lookup_key;
2424
mod recursion_guard;
25+
mod schema_traverse;
2526
mod serializers;
2627
mod tools;
2728
mod url;
@@ -39,7 +40,7 @@ pub use serializers::{
3940
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
4041
WarningsArg,
4142
};
42-
pub use validators::{validate_core_schema, PySome, SchemaValidator};
43+
pub use validators::{gather_schemas_for_cleaning, validate_core_schema, PySome, SchemaValidator};
4344

4445
use crate::input::Input;
4546

@@ -133,6 +134,7 @@ fn _pydantic_core(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
133134
m.add_function(wrap_pyfunction!(from_json, m)?)?;
134135
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
135136
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
137+
m.add_function(wrap_pyfunction!(gather_schemas_for_cleaning, m)?)?;
136138
m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?;
137139
Ok(())
138140
}

src/schema_traverse.rs

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
use crate::tools::py_err;
2+
use pyo3::exceptions::{PyKeyError, PyValueError};
3+
use pyo3::prelude::*;
4+
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
5+
use pyo3::{intern, Bound, PyResult};
6+
use std::collections::HashSet;
7+
8+
const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator";
9+
10+
macro_rules! get {
11+
($dict: expr, $key: expr) => {{
12+
$dict.get_item(intern!($dict.py(), $key))?
13+
}};
14+
}
15+
16+
macro_rules! traverse {
17+
($func: expr, $key: expr, $dict: expr, $ctx: expr) => {{
18+
if let Some(v) = $dict.get_item(intern!($dict.py(), $key))? {
19+
$func(v.downcast_exact()?, $ctx)?
20+
}
21+
}};
22+
}
23+
24+
macro_rules! defaultdict_list_append {
25+
($dict: expr, $key: expr, $value: expr) => {{
26+
match $dict.get_item($key)? {
27+
None => {
28+
let list = PyList::empty_bound($dict.py());
29+
list.append($value)?;
30+
$dict.set_item($key, list)?;
31+
}
32+
// Safety: we know that the value is a PyList as we just created it above
33+
Some(list) => unsafe { list.downcast_unchecked::<PyList>() }.append($value)?,
34+
};
35+
}};
36+
}
37+
38+
fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<bool> {
39+
if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") {
40+
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
41+
let schema_ref_str = schema_ref_pystr.to_str()?;
42+
43+
if !ctx.seen_refs.contains(schema_ref_str) {
44+
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
45+
46+
if let Some(def) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
47+
ctx.seen_refs.insert(schema_ref_str.to_string());
48+
gather_schema(def.downcast_exact::<PyDict>()?, ctx)?;
49+
ctx.seen_refs.remove(schema_ref_str);
50+
}
51+
Ok(false)
52+
} else {
53+
ctx.recursive_def_refs.add(schema_ref_pystr)?;
54+
Ok(true)
55+
}
56+
} else {
57+
py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref")?
58+
}
59+
}
60+
61+
fn gather_meta(schema: &Bound<'_, PyDict>, meta_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
62+
if let Some(discriminator) = get!(meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY) {
63+
let schema_discriminator = PyTuple::new_bound(schema.py(), vec![schema.as_any(), &discriminator]);
64+
ctx.discriminators.append(schema_discriminator)?;
65+
}
66+
Ok(())
67+
}
68+
69+
fn gather_node(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
70+
let type_ = get!(schema, "type");
71+
if type_.is_none() {
72+
return py_err!(PyValueError; "Schema type missing");
73+
}
74+
if type_.unwrap().downcast_exact::<PyString>()?.to_str()? == "definition-ref" {
75+
let recursive = gather_definition_ref(schema, ctx)?;
76+
if recursive {
77+
return Ok(());
78+
}
79+
}
80+
if let Some(meta) = get!(schema, "metadata") {
81+
gather_meta(schema, meta.downcast_exact()?, ctx)?;
82+
}
83+
Ok(())
84+
}
85+
86+
fn gather_list(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
87+
for v in schema_list.iter() {
88+
gather_schema(v.downcast_exact()?, ctx)?;
89+
}
90+
Ok(())
91+
}
92+
93+
fn gather_dict(schemas_by_key: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
94+
for (_, v) in schemas_by_key.iter() {
95+
gather_schema(v.downcast_exact()?, ctx)?;
96+
}
97+
Ok(())
98+
}
99+
100+
fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
101+
for v in schema_list.iter() {
102+
if let Ok(tup) = v.downcast_exact::<PyTuple>() {
103+
gather_schema(tup.get_item(0)?.downcast_exact()?, ctx)?;
104+
} else {
105+
gather_schema(v.downcast_exact()?, ctx)?;
106+
}
107+
}
108+
Ok(())
109+
}
110+
111+
fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
112+
for v in arguments.iter() {
113+
if let Some(schema) = get!(v.downcast_exact::<PyDict>()?, "schema") {
114+
gather_schema(schema.downcast_exact()?, ctx)?;
115+
}
116+
}
117+
Ok(())
118+
}
119+
120+
fn traverse_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
121+
let type_ = get!(schema, "type");
122+
if type_.is_none() {
123+
return py_err!(PyValueError; "Schema type missing");
124+
}
125+
match type_.unwrap().downcast_exact::<PyString>()?.to_str()? {
126+
"definitions" => {
127+
traverse!(gather_schema, "schema", schema, ctx);
128+
traverse!(gather_list, "definitions", schema, ctx);
129+
}
130+
"list" | "set" | "frozenset" | "generator" => traverse!(gather_schema, "items_schema", schema, ctx),
131+
"tuple" => traverse!(gather_list, "items_schema", schema, ctx),
132+
"dict" => {
133+
traverse!(gather_schema, "keys_schema", schema, ctx);
134+
traverse!(gather_schema, "values_schema", schema, ctx);
135+
}
136+
"union" => traverse!(gather_union_choices, "choices", schema, ctx),
137+
"tagged-union" => traverse!(gather_dict, "choices", schema, ctx),
138+
"chain" => traverse!(gather_list, "steps", schema, ctx),
139+
"lax-or-strict" => {
140+
traverse!(gather_schema, "lax_schema", schema, ctx);
141+
traverse!(gather_schema, "strict_schema", schema, ctx);
142+
}
143+
"json-or-python" => {
144+
traverse!(gather_schema, "json_schema", schema, ctx);
145+
traverse!(gather_schema, "python_schema", schema, ctx);
146+
}
147+
"model-fields" | "typed-dict" => {
148+
traverse!(gather_schema, "extras_schema", schema, ctx);
149+
traverse!(gather_list, "computed_fields", schema, ctx);
150+
traverse!(gather_dict, "fields", schema, ctx);
151+
}
152+
"dataclass-args" => {
153+
traverse!(gather_list, "computed_fields", schema, ctx);
154+
traverse!(gather_list, "fields", schema, ctx);
155+
}
156+
"arguments" => {
157+
traverse!(gather_arguments, "arguments_schema", schema, ctx);
158+
traverse!(gather_schema, "var_args_schema", schema, ctx);
159+
traverse!(gather_schema, "var_kwargs_schema", schema, ctx);
160+
}
161+
"computed-field" | "function-plain" => traverse!(gather_schema, "return_schema", schema, ctx),
162+
"function-wrap" => {
163+
traverse!(gather_schema, "return_schema", schema, ctx);
164+
traverse!(gather_schema, "schema", schema, ctx);
165+
}
166+
"call" => {
167+
traverse!(gather_schema, "arguments_schema", schema, ctx);
168+
traverse!(gather_schema, "return_schema", schema, ctx);
169+
}
170+
_ => traverse!(gather_schema, "schema", schema, ctx),
171+
};
172+
173+
if let Some(ser) = get!(schema, "serialization") {
174+
let ser_dict = ser.downcast_exact::<PyDict>()?;
175+
traverse!(gather_schema, "schema", ser_dict, ctx);
176+
traverse!(gather_schema, "return_schema", ser_dict, ctx);
177+
}
178+
Ok(())
179+
}
180+
181+
pub struct GatherCtx<'a, 'py> {
182+
pub definitions_dict: &'a Bound<'py, PyDict>,
183+
pub def_refs: Bound<'py, PyDict>,
184+
pub recursive_def_refs: Bound<'py, PySet>,
185+
pub discriminators: Bound<'py, PyList>,
186+
seen_refs: HashSet<String>,
187+
}
188+
189+
impl<'a, 'py> GatherCtx<'a, 'py> {
190+
pub fn new(definitions: &'a Bound<'py, PyDict>) -> PyResult<Self> {
191+
let ctx = GatherCtx {
192+
definitions_dict: definitions,
193+
def_refs: PyDict::new_bound(definitions.py()),
194+
recursive_def_refs: PySet::empty_bound(definitions.py())?,
195+
discriminators: PyList::empty_bound(definitions.py()),
196+
seen_refs: HashSet::new(),
197+
};
198+
Ok(ctx)
199+
}
200+
}
201+
202+
pub fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
203+
traverse_schema(schema, ctx)?;
204+
gather_node(schema, ctx)
205+
}

src/validators/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ mod validation_state;
6464
mod with_default;
6565

6666
pub use self::validation_state::{Exactness, ValidationState};
67+
use crate::schema_traverse::{gather_schema, GatherCtx};
6768
pub use with_default::DefaultType;
6869

6970
#[pyclass(module = "pydantic_core._pydantic_core", name = "Some")]
@@ -443,6 +444,25 @@ impl<'py> SelfValidator<'py> {
443444
}
444445
}
445446

447+
#[pyfunction(signature = (schema, definitions))]
448+
pub fn gather_schemas_for_cleaning<'py>(
449+
schema: &Bound<'py, PyAny>,
450+
definitions: &Bound<'py, PyAny>,
451+
) -> PyResult<Bound<'py, PyTuple>> {
452+
let py = schema.py();
453+
let schema_dict = schema.downcast_exact::<PyDict>()?;
454+
455+
let mut ctx = GatherCtx::new(definitions.downcast_exact()?)?;
456+
gather_schema(schema_dict, &mut ctx)?;
457+
458+
let res = vec![
459+
ctx.def_refs.as_any(),
460+
ctx.recursive_def_refs.as_any(),
461+
ctx.discriminators.as_any(),
462+
];
463+
return Ok(PyTuple::new_bound(py, res));
464+
}
465+
446466
#[pyfunction(signature = (schema, *, strict = None))]
447467
pub fn validate_core_schema<'py>(schema: &Bound<'py, PyAny>, strict: Option<bool>) -> PyResult<Bound<'py, PyAny>> {
448468
let self_validator = SelfValidator::new(schema.py())?;

0 commit comments

Comments
 (0)