Skip to content

Commit f2b4fc2

Browse files
Add schema tree node gathering for cleaning in pydantic GenerateSchema
1 parent 92a259e commit f2b4fc2

File tree

7 files changed

+324
-3
lines changed

7 files changed

+324
-3
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ jobs:
212212
steps:
213213
- uses: actions/checkout@v4
214214
with:
215-
repository: pydantic/pydantic
215+
repository: MarkusSintonen/pydantic # TODO remove before merging
216+
ref: optimized-schema-building
216217
path: pydantic
217218

218219
- uses: actions/checkout@v4

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ python-source = "python"
5353
module-name = "pydantic_core._pydantic_core"
5454
bindings = 'pyo3'
5555
features = ["pyo3/extension-module"]
56+
profile = "release" # TEMPORARY: remove this
5657

5758
[tool.ruff]
5859
line-length = 120

python/pydantic_core/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@
2323
ValidationError,
2424
__version__,
2525
from_json,
26+
gather_schemas_for_cleaning,
2627
to_json,
2728
to_jsonable_python,
2829
validate_core_schema,
2930
)
30-
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType
31+
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, DefinitionReferenceSchema, ErrorType
3132

3233
if _sys.version_info < (3, 11):
3334
from typing_extensions import NotRequired as _NotRequired
@@ -67,6 +68,7 @@
6768
'from_json',
6869
'to_jsonable_python',
6970
'validate_core_schema',
71+
'gather_schemas_for_cleaning',
7072
]
7173

7274

@@ -137,3 +139,11 @@ class MultiHostHost(_TypedDict):
137139
"""The host part of this host, or `None`."""
138140
port: int | None
139141
"""The port part of this host, or `None`."""
142+
143+
144+
class GatherResult(_TypedDict):
145+
"""Internal result of gathering schemas for cleaning."""
146+
147+
definition_refs: dict[str, list[DefinitionReferenceSchema]]
148+
recursive_refs: set[str]
149+
deferred_discriminators: list[tuple[CoreSchema, _Any]]

python/pydantic_core/_pydantic_core.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from typing import Any, Callable, Generic, Literal, TypeVar, final
55
from _typeshed import SupportsAllComparisons
66
from typing_extensions import LiteralString, Self, TypeAlias
77

8-
from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost
8+
from pydantic_core import ErrorDetails, ErrorTypeInfo, GatherResult, InitErrorDetails, MultiHostHost
99
from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType
1010

1111
__all__ = [
@@ -35,6 +35,7 @@ __all__ = [
3535
'list_all_errors',
3636
'TzInfo',
3737
'validate_core_schema',
38+
'gather_schemas_for_cleaning',
3839
]
3940
__version__: str
4041
build_profile: str
@@ -1164,3 +1165,7 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C
11641165
We may also remove this function altogether, do not rely on it being present if you are
11651166
using pydantic-core directly.
11661167
"""
1168+
1169+
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
1170+
"""Used internally for schema cleaning when schemas are generated.
1171+
Gathers information from the schema tree for the cleaning."""

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod errors;
2222
mod input;
2323
mod lookup_key;
2424
mod recursion_guard;
25+
mod schema_traverse;
2526
mod serializers;
2627
mod tools;
2728
mod url;
@@ -35,6 +36,7 @@ pub use build_tools::SchemaError;
3536
pub use errors::{
3637
list_all_errors, PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault, ValidationError,
3738
};
39+
pub use schema_traverse::gather_schemas_for_cleaning;
3840
pub use serializers::{
3941
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
4042
WarningsArg,
@@ -133,6 +135,7 @@ fn _pydantic_core(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
133135
m.add_function(wrap_pyfunction!(from_json, m)?)?;
134136
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
135137
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
138+
m.add_function(wrap_pyfunction!(gather_schemas_for_cleaning, m)?)?;
136139
m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?;
137140
Ok(())
138141
}

src/schema_traverse.rs

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
use crate::tools::py_err;
2+
use pyo3::exceptions::PyKeyError;
3+
use pyo3::prelude::*;
4+
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
5+
use pyo3::{intern, Bound, PyResult};
6+
use std::collections::HashSet;
7+
8+
const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator";
9+
10+
macro_rules! get {
11+
($dict: expr, $key: expr) => {
12+
$dict.get_item(intern!($dict.py(), $key))?
13+
};
14+
}
15+
16+
macro_rules! traverse_key_fn {
17+
($key: expr, $func: expr, $dict: expr, $ctx: expr) => {{
18+
if let Some(v) = get!($dict, $key) {
19+
$func(v.downcast_exact()?, $ctx)?
20+
}
21+
}};
22+
}
23+
24+
macro_rules! traverse {
25+
($($key:expr => $func:expr),*; $dict: expr, $ctx: expr) => {{
26+
$(traverse_key_fn!($key, $func, $dict, $ctx);)*
27+
gather_serialization($dict, $ctx)?;
28+
gather_meta($dict, $ctx)?;
29+
}}
30+
}
31+
32+
macro_rules! defaultdict_list_append {
33+
($dict: expr, $key: expr, $value: expr) => {{
34+
match $dict.get_item($key)? {
35+
None => {
36+
let list = PyList::empty_bound($dict.py());
37+
list.append($value)?;
38+
$dict.set_item($key, list)?;
39+
}
40+
// Safety: we know that the value is a PyList as we just created it above
41+
Some(list) => unsafe { list.downcast_unchecked::<PyList>() }.append($value)?,
42+
};
43+
}};
44+
}
45+
46+
fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
47+
if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") {
48+
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
49+
let schema_ref_str = schema_ref_pystr.to_str()?;
50+
51+
if !ctx.recursively_seen_refs.contains(schema_ref_str) {
52+
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
53+
54+
// TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
55+
if let Some(definition) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
56+
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
57+
58+
gather_schema(definition.downcast_exact::<PyDict>()?, ctx)?;
59+
gather_serialization(schema_ref_dict, ctx)?;
60+
gather_meta(schema_ref_dict, ctx)?;
61+
62+
ctx.recursively_seen_refs.remove(schema_ref_str);
63+
}
64+
} else {
65+
ctx.recursive_def_refs.add(schema_ref_pystr)?;
66+
for seen_ref in &ctx.recursively_seen_refs {
67+
let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref);
68+
ctx.recursive_def_refs.add(seen_ref_pystr)?;
69+
}
70+
}
71+
Ok(())
72+
} else {
73+
py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref")
74+
}
75+
}
76+
77+
fn gather_serialization(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
78+
if let Some(ser) = get!(schema, "serialization") {
79+
let ser_dict = ser.downcast_exact::<PyDict>()?;
80+
traverse!("schema" => gather_schema, "return_schema" => gather_schema; ser_dict, ctx);
81+
}
82+
Ok(())
83+
}
84+
85+
fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
86+
if let Some(meta) = get!(schema, "metadata") {
87+
let meta_dict = meta.downcast_exact::<PyDict>()?;
88+
if let Some(discriminator) = get!(meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY) {
89+
let schema_discriminator = PyTuple::new_bound(schema.py(), vec![schema.as_any(), &discriminator]);
90+
ctx.discriminators.append(schema_discriminator)?;
91+
}
92+
}
93+
Ok(())
94+
}
95+
96+
fn gather_list(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
97+
for v in schema_list.iter() {
98+
gather_schema(v.downcast_exact()?, ctx)?;
99+
}
100+
Ok(())
101+
}
102+
103+
fn gather_dict(schemas_by_key: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
104+
for (_, v) in schemas_by_key.iter() {
105+
gather_schema(v.downcast_exact()?, ctx)?;
106+
}
107+
Ok(())
108+
}
109+
110+
fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
111+
for v in schema_list.iter() {
112+
if let Ok(tup) = v.downcast_exact::<PyTuple>() {
113+
gather_schema(tup.get_item(0)?.downcast_exact()?, ctx)?;
114+
} else {
115+
gather_schema(v.downcast_exact()?, ctx)?;
116+
}
117+
}
118+
Ok(())
119+
}
120+
121+
fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
122+
for v in arguments.iter() {
123+
traverse_key_fn!("schema", gather_schema, v.downcast_exact::<PyDict>()?, ctx);
124+
}
125+
Ok(())
126+
}
127+
128+
// Has 100% coverage in Pydantic side. This is exclusively used there
129+
#[cfg_attr(has_coverage_attribute, coverage(off))]
130+
fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
131+
let type_ = get!(schema, "type");
132+
if type_.is_none() {
133+
return py_err!(PyKeyError; "Schema type missing");
134+
}
135+
match type_.unwrap().downcast_exact::<PyString>()?.to_str()? {
136+
"definition-ref" => gather_definition_ref(schema, ctx)?,
137+
"definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx),
138+
"list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx),
139+
"tuple" => traverse!("items_schema" => gather_list; schema, ctx),
140+
"dict" => traverse!("keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx),
141+
"union" => traverse!("choices" => gather_union_choices; schema, ctx),
142+
"tagged-union" => traverse!("choices" => gather_dict; schema, ctx),
143+
"chain" => traverse!("steps" => gather_list; schema, ctx),
144+
"lax-or-strict" => traverse!("lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx),
145+
"json-or-python" => traverse!("json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx),
146+
"model-fields" | "typed-dict" => traverse!(
147+
"extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx
148+
),
149+
"dataclass-args" => traverse!("computed_fields" => gather_list, "fields" => gather_list; schema, ctx),
150+
"arguments" => traverse!(
151+
"arguments_schema" => gather_arguments,
152+
"var_args_schema" => gather_schema,
153+
"var_kwargs_schema" => gather_schema;
154+
schema, ctx
155+
),
156+
"call" => traverse!("arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx),
157+
"computed-field" | "function-plain" => traverse!("return_schema" => gather_schema; schema, ctx),
158+
"function-wrap" => traverse!("return_schema" => gather_schema, "schema" => gather_schema; schema, ctx),
159+
_ => traverse!("schema" => gather_schema; schema, ctx),
160+
};
161+
Ok(())
162+
}
163+
164+
pub struct GatherCtx<'a, 'py> {
165+
pub definitions_dict: &'a Bound<'py, PyDict>,
166+
pub def_refs: Bound<'py, PyDict>,
167+
pub recursive_def_refs: Bound<'py, PySet>,
168+
pub discriminators: Bound<'py, PyList>,
169+
recursively_seen_refs: HashSet<String>,
170+
}
171+
172+
#[pyfunction(signature = (schema, definitions))]
173+
pub fn gather_schemas_for_cleaning<'py>(
174+
schema: &Bound<'py, PyAny>,
175+
definitions: &Bound<'py, PyAny>,
176+
) -> PyResult<Bound<'py, PyDict>> {
177+
let py = schema.py();
178+
let schema_dict = schema.downcast_exact::<PyDict>()?;
179+
180+
let mut ctx = GatherCtx {
181+
definitions_dict: definitions.downcast_exact()?,
182+
def_refs: PyDict::new_bound(definitions.py()),
183+
recursive_def_refs: PySet::empty_bound(definitions.py())?,
184+
discriminators: PyList::empty_bound(definitions.py()),
185+
recursively_seen_refs: HashSet::new(),
186+
};
187+
gather_schema(schema_dict, &mut ctx)?;
188+
189+
let res = PyDict::new_bound(py);
190+
res.set_item(intern!(py, "definition_refs"), ctx.def_refs)?;
191+
res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?;
192+
res.set_item(intern!(py, "deferred_discriminators"), ctx.discriminators)?;
193+
Ok(res)
194+
}

0 commit comments

Comments
 (0)