Skip to content

Commit 7eae801

Browse files
Add schema tree node gathering for cleaning in pydantic GenerateSchema
1 parent 92a259e commit 7eae801

File tree

7 files changed

+317
-3
lines changed

7 files changed

+317
-3
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -212,7 +212,8 @@ jobs:
212212
steps:
213213
- uses: actions/checkout@v4
214214
with:
215-
repository: pydantic/pydantic
215+
repository: MarkusSintonen/pydantic # TODO remove before merging
216+
ref: optimized-schema-building
216217
path: pydantic
217218

218219
- uses: actions/checkout@v4

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ python-source = "python"
5353
module-name = "pydantic_core._pydantic_core"
5454
bindings = 'pyo3'
5555
features = ["pyo3/extension-module"]
56+
profile = "release" # TEMPORARY: remove this
5657

5758
[tool.ruff]
5859
line-length = 120

python/pydantic_core/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,11 +23,12 @@
2323
ValidationError,
2424
__version__,
2525
from_json,
26+
gather_schemas_for_cleaning,
2627
to_json,
2728
to_jsonable_python,
2829
validate_core_schema,
2930
)
30-
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType
31+
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, DefinitionReferenceSchema, ErrorType
3132

3233
if _sys.version_info < (3, 11):
3334
from typing_extensions import NotRequired as _NotRequired
@@ -67,6 +68,7 @@
6768
'from_json',
6869
'to_jsonable_python',
6970
'validate_core_schema',
71+
'gather_schemas_for_cleaning',
7072
]
7173

7274

@@ -137,3 +139,11 @@ class MultiHostHost(_TypedDict):
137139
"""The host part of this host, or `None`."""
138140
port: int | None
139141
"""The port part of this host, or `None`."""
142+
143+
144+
class GatherResult(_TypedDict):
145+
"""Internal result of gathering schemas for cleaning."""
146+
147+
# Maps each referenced `schema_ref` string to the definition-reference schemas gathered for it.
definition_refs: dict[str, list[DefinitionReferenceSchema]]
148+
# Refs that were found to take part in a reference cycle while traversing the schema tree.
recursive_refs: set[str]
149+
# `(schema, discriminator)` pairs for schemas whose metadata carries the union discriminator placeholder.
deferred_discriminators: list[tuple[CoreSchema, _Any]]

python/pydantic_core/_pydantic_core.pyi

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ from typing import Any, Callable, Generic, Literal, TypeVar, final
55
from _typeshed import SupportsAllComparisons
66
from typing_extensions import LiteralString, Self, TypeAlias
77

8-
from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost
8+
from pydantic_core import ErrorDetails, ErrorTypeInfo, GatherResult, InitErrorDetails, MultiHostHost
99
from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType
1010

1111
__all__ = [
@@ -35,6 +35,7 @@ __all__ = [
3535
'list_all_errors',
3636
'TzInfo',
3737
'validate_core_schema',
38+
'gather_schemas_for_cleaning',
3839
]
3940
__version__: str
4041
build_profile: str
@@ -1164,3 +1165,7 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C
11641165
We may also remove this function altogether, do not rely on it being present if you are
11651166
using pydantic-core directly.
11661167
"""
1168+
1169+
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
1170+
"""Used internally for schema cleaning when schemas are generated.

Gathers information from the schema tree for the cleaning.

Args:
    schema: The root core schema to traverse.
    definitions: Mapping from reference name to the schema that the reference points to.

Returns:
    A `GatherResult` with the gathered definition references, the references involved in
    recursion, and the schemas carrying a deferred discriminator.
"""

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ mod errors;
2222
mod input;
2323
mod lookup_key;
2424
mod recursion_guard;
25+
mod schema_traverse;
2526
mod serializers;
2627
mod tools;
2728
mod url;
@@ -35,6 +36,7 @@ pub use build_tools::SchemaError;
3536
pub use errors::{
3637
list_all_errors, PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault, ValidationError,
3738
};
39+
pub use schema_traverse::gather_schemas_for_cleaning;
3840
pub use serializers::{
3941
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
4042
WarningsArg,
@@ -133,6 +135,7 @@ fn _pydantic_core(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
133135
m.add_function(wrap_pyfunction!(from_json, m)?)?;
134136
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
135137
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
138+
m.add_function(wrap_pyfunction!(gather_schemas_for_cleaning, m)?)?;
136139
m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?;
137140
Ok(())
138141
}

src/schema_traverse.rs

Lines changed: 187 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,187 @@
1+
use crate::tools::py_err;
2+
use pyo3::exceptions::PyKeyError;
3+
use pyo3::prelude::*;
4+
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
5+
use pyo3::{intern, Bound, PyResult};
6+
use std::collections::HashSet;
7+
8+
// Key looked up inside a schema's `metadata` dict by `gather_meta`; when present, the schema and
// the value stored under this key are collected into the discriminators list.
const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator";
9+
/// Looks up `$name` in the Python dict `$mapping`, using an interned key string.
///
/// The expansion ends in `?`, so a failed lookup returns early from the enclosing function;
/// on success it evaluates to the `Option` produced by `get_item`.
macro_rules! get {
    ($mapping:expr, $name:expr) => {
        $mapping.get_item(intern!($mapping.py(), $name))?
    };
}
15+
/// Runs `$visit(value, $state)` on the entry stored under `$name` in `$mapping`, if there is one.
///
/// The entry is converted with `downcast_exact` to whatever type `$visit` expects; a missing
/// key is silently skipped, while lookup, downcast and visitor errors propagate through `?`.
macro_rules! traverse_key_fn {
    ($name:expr, $visit:expr, $mapping:expr, $state:expr) => {{
        match get!($mapping, $name) {
            Some(entry) => $visit(entry.downcast_exact()?, $state)?,
            None => (),
        }
    }};
}
23+
/// Visits every listed `key => visitor` pair of `$mapping`, then its `"serialization"` entry,
/// and finally evaluates to the result of `gather_meta` for the same dict.
///
/// Pairs are handled in the order they are written; each one is skipped when its key is absent.
macro_rules! traverse {
    ($($name:expr => $visit:expr),*; $mapping:expr, $state:expr) => {{
        $(
            traverse_key_fn!($name, $visit, $mapping, $state);
        )*
        traverse_key_fn!("serialization", gather_serialization, $mapping, $state);
        gather_meta($mapping, $state)
    }}
}
31+
/// Appends `$value` to the list stored under `$key` in the Python dict `$dict`, creating the
/// list first when the key is not present yet (the behaviour of a `defaultdict(list)`).
///
/// All Python errors are propagated with `?`, so the macro must be used inside a function
/// returning `PyResult`.
macro_rules! defaultdict_list_append {
    ($dict: expr, $key: expr, $value: expr) => {{
        match $dict.get_item($key)? {
            // A checked downcast is used instead of `downcast_unchecked`: the macro cannot
            // guarantee that every value already stored in `$dict` was put there by this macro,
            // and a type check on a list is cheap compared to undefined behaviour.
            Some(existing) => existing.downcast::<PyList>()?.append($value)?,
            None => {
                let list = PyList::empty_bound($dict.py());
                list.append($value)?;
                $dict.set_item($key, list)?;
            }
        };
    }};
}
45+
46+
fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
47+
if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") {
48+
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
49+
let schema_ref_str = schema_ref_pystr.to_str()?;
50+
51+
if !ctx.recursively_seen_refs.contains(schema_ref_str) {
52+
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
53+
54+
// TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
55+
if let Some(definition) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
56+
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
57+
58+
gather_schema(definition.downcast_exact::<PyDict>()?, ctx)?;
59+
traverse_key_fn!("serialization", gather_serialization, schema_ref_dict, ctx);
60+
gather_meta(schema_ref_dict, ctx)?;
61+
62+
ctx.recursively_seen_refs.remove(schema_ref_str);
63+
}
64+
} else {
65+
ctx.recursive_def_refs.add(schema_ref_pystr)?;
66+
for seen_ref in &ctx.recursively_seen_refs {
67+
let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref);
68+
ctx.recursive_def_refs.add(seen_ref_pystr)?;
69+
}
70+
}
71+
Ok(())
72+
} else {
73+
py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref")
74+
}
75+
}
76+
77+
fn gather_serialization(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
78+
traverse!("schema" => gather_schema, "return_schema" => gather_schema; schema, ctx)
79+
}
80+
81+
fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
82+
if let Some(meta) = get!(schema, "metadata") {
83+
let meta_dict = meta.downcast_exact::<PyDict>()?;
84+
if let Some(discriminator) = get!(meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY) {
85+
let schema_discriminator = PyTuple::new_bound(schema.py(), vec![schema.as_any(), &discriminator]);
86+
ctx.discriminators.append(schema_discriminator)?;
87+
}
88+
}
89+
Ok(())
90+
}
91+
92+
fn gather_list(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
93+
for v in schema_list.iter() {
94+
gather_schema(v.downcast_exact()?, ctx)?;
95+
}
96+
Ok(())
97+
}
98+
99+
fn gather_dict(schemas_by_key: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
100+
for (_, v) in schemas_by_key.iter() {
101+
gather_schema(v.downcast_exact()?, ctx)?;
102+
}
103+
Ok(())
104+
}
105+
106+
fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
107+
for v in schema_list.iter() {
108+
if let Ok(tup) = v.downcast_exact::<PyTuple>() {
109+
gather_schema(tup.get_item(0)?.downcast_exact()?, ctx)?;
110+
} else {
111+
gather_schema(v.downcast_exact()?, ctx)?;
112+
}
113+
}
114+
Ok(())
115+
}
116+
117+
fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
118+
for v in arguments.iter() {
119+
traverse_key_fn!("schema", gather_schema, v.downcast_exact::<PyDict>()?, ctx);
120+
}
121+
Ok(())
122+
}
123+
124+
// Has 100% coverage in Pydantic side. This is exclusively used there
125+
#[cfg_attr(has_coverage_attribute, coverage(off))]
126+
fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
127+
let type_ = get!(schema, "type");
128+
if type_.is_none() {
129+
return py_err!(PyKeyError; "Schema type missing");
130+
}
131+
match type_.unwrap().downcast_exact::<PyString>()?.to_str()? {
132+
"definition-ref" => gather_definition_ref(schema, ctx),
133+
"definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx),
134+
"list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx),
135+
"tuple" => traverse!("items_schema" => gather_list; schema, ctx),
136+
"dict" => traverse!("keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx),
137+
"union" => traverse!("choices" => gather_union_choices; schema, ctx),
138+
"tagged-union" => traverse!("choices" => gather_dict; schema, ctx),
139+
"chain" => traverse!("steps" => gather_list; schema, ctx),
140+
"lax-or-strict" => traverse!("lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx),
141+
"json-or-python" => traverse!("json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx),
142+
"model-fields" | "typed-dict" => traverse!(
143+
"extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx
144+
),
145+
"dataclass-args" => traverse!("computed_fields" => gather_list, "fields" => gather_list; schema, ctx),
146+
"arguments" => traverse!(
147+
"arguments_schema" => gather_arguments,
148+
"var_args_schema" => gather_schema,
149+
"var_kwargs_schema" => gather_schema;
150+
schema, ctx
151+
),
152+
"call" => traverse!("arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx),
153+
"computed-field" | "function-plain" => traverse!("return_schema" => gather_schema; schema, ctx),
154+
"function-wrap" => traverse!("return_schema" => gather_schema, "schema" => gather_schema; schema, ctx),
155+
_ => traverse!("schema" => gather_schema; schema, ctx),
156+
}
157+
}
158+
159+
/// Traversal state shared by the `gather_*` functions while walking a schema tree.
pub struct GatherCtx<'a, 'py> {
160+
/// Definitions supplied by the caller; consulted by `schema_ref` when a `definition-ref` is visited.
pub definitions_dict: &'a Bound<'py, PyDict>,
161+
/// Maps each `schema_ref` to the list of `definition-ref` schema dicts gathered for it.
pub def_refs: Bound<'py, PyDict>,
162+
/// Refs found to take part in a reference cycle.
pub recursive_def_refs: Bound<'py, PySet>,
163+
/// `(schema, discriminator)` tuples for schemas whose metadata holds the discriminator placeholder key.
pub discriminators: Bound<'py, PyList>,
164+
/// Refs whose definitions are currently being traversed; used to detect cycles.
recursively_seen_refs: HashSet<String>,
165+
}
166+
167+
#[pyfunction(signature = (schema, definitions))]
168+
pub fn gather_schemas_for_cleaning<'py>(
169+
schema: &Bound<'py, PyAny>,
170+
definitions: &Bound<'py, PyAny>,
171+
) -> PyResult<Bound<'py, PyDict>> {
172+
let py = schema.py();
173+
let mut ctx = GatherCtx {
174+
definitions_dict: definitions.downcast_exact()?,
175+
def_refs: PyDict::new_bound(py),
176+
recursive_def_refs: PySet::empty_bound(py)?,
177+
discriminators: PyList::empty_bound(py),
178+
recursively_seen_refs: HashSet::new(),
179+
};
180+
gather_schema(schema.downcast_exact::<PyDict>()?, &mut ctx)?;
181+
182+
let res = PyDict::new_bound(py);
183+
res.set_item(intern!(py, "definition_refs"), ctx.def_refs)?;
184+
res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?;
185+
res.set_item(intern!(py, "deferred_discriminators"), ctx.discriminators)?;
186+
Ok(res)
187+
}
Lines changed: 107 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,107 @@
1+
from pydantic_core import core_schema, gather_schemas_for_cleaning
2+
3+
def test_no_refs():
    """A schema without definition references produces three empty collections."""
    parameters = [core_schema.arguments_parameter(name, core_schema.int_schema()) for name in ('a', 'b')]
    gathered = gather_schemas_for_cleaning(core_schema.arguments_schema(parameters), definitions={})

    assert gathered['definition_refs'] == {}
    assert gathered['recursive_refs'] == set()
    assert gathered['deferred_discriminators'] == []
12+
13+
def test_simple_ref_schema():
    """A top-level definition reference is reported under its ref, as the very same object."""
    ref_schema = core_schema.definition_reference_schema('ref1')
    gathered = gather_schemas_for_cleaning(ref_schema, {'ref1': core_schema.int_schema(ref='ref1')})

    found = gathered['definition_refs']
    assert found == {'ref1': [ref_schema]}
    assert found['ref1'][0] is ref_schema
    assert gathered['recursive_refs'] == set()
    assert gathered['deferred_discriminators'] == []
22+
23+
def test_deep_ref_schema():
    """References nested in union choices, tuple items and model fields are all gathered."""

    class Model:
        pass

    first_ref1 = core_schema.definition_reference_schema('ref1')
    second_ref1 = core_schema.definition_reference_schema('ref1')
    only_ref2 = core_schema.definition_reference_schema('ref2')

    union_with_ref = core_schema.union_schema([core_schema.int_schema(), (first_ref1, 'ref_label')])
    tuple_with_ref = core_schema.tuple_schema([second_ref1, core_schema.str_schema()])
    fields = {
        'a': core_schema.model_field(union_with_ref),
        'b': core_schema.model_field(only_ref2),
        'c': core_schema.model_field(tuple_with_ref),
    }
    model = core_schema.model_schema(Model, core_schema.model_fields_schema(fields))
    definitions = {'ref1': core_schema.str_schema(ref='ref1'), 'ref2': core_schema.bytes_schema(ref='ref2')}

    gathered = gather_schemas_for_cleaning(model, definitions)

    found = gathered['definition_refs']
    assert found == {'ref1': [first_ref1, second_ref1], 'ref2': [only_ref2]}
    assert found['ref1'][0] is first_ref1 and found['ref1'][1] is second_ref1
    assert found['ref2'][0] is only_ref2
    assert gathered['recursive_refs'] == set()
    assert gathered['deferred_discriminators'] == []
48+
49+
def test_ref_in_serialization_schema():
    """A reference used as the return schema of a serializer is gathered as well."""
    return_ref = core_schema.definition_reference_schema('ref1')
    serializer = core_schema.plain_serializer_function_ser_schema(lambda v: v, return_schema=return_ref)
    schema = core_schema.str_schema(serialization=serializer)

    gathered = gather_schemas_for_cleaning(schema, definitions={'ref1': core_schema.str_schema()})

    found = gathered['definition_refs']
    assert found == {'ref1': [return_ref]}
    assert found['ref1'][0] is return_ref
    assert gathered['recursive_refs'] == set()
    assert gathered['deferred_discriminators'] == []
59+
60+
def test_recursive_ref_schema():
    """A reference whose definition is the reference itself is reported as recursive."""
    self_ref = core_schema.definition_reference_schema('ref1')
    gathered = gather_schemas_for_cleaning(self_ref, definitions={'ref1': self_ref})

    found = gathered['definition_refs']
    assert found == {'ref1': [self_ref]}
    assert found['ref1'][0] is self_ref
    assert gathered['recursive_refs'] == {'ref1'}
    assert gathered['deferred_discriminators'] == []
67+
68+
def test_deep_recursive_ref_schema():
    """A cycle spanning three definitions marks every reference on the cycle as recursive."""
    ref1 = core_schema.definition_reference_schema('ref1')
    ref2 = core_schema.definition_reference_schema('ref2')
    ref3 = core_schema.definition_reference_schema('ref3')

    root = core_schema.union_schema([ref1, core_schema.int_schema()])
    # ref1 -> ref2 -> ref3 -> ref1
    definitions = {
        'ref1': core_schema.union_schema([core_schema.int_schema(), ref2]),
        'ref2': core_schema.union_schema([ref3, core_schema.float_schema()]),
        'ref3': core_schema.union_schema([ref1, core_schema.str_schema()]),
    }

    gathered = gather_schemas_for_cleaning(root, definitions=definitions)

    found = gathered['definition_refs']
    assert found == {'ref1': [ref1], 'ref2': [ref2], 'ref3': [ref3]}
    assert gathered['recursive_refs'] == {'ref1', 'ref2', 'ref3'}
    assert found['ref1'][0] is ref1
    assert found['ref2'][0] is ref2
    assert found['ref3'][0] is ref3
    assert gathered['deferred_discriminators'] == []
88+
89+
def test_discriminator_meta():
    """Schemas carrying the discriminator placeholder in their metadata are gathered, including referenced ones."""

    class Model:
        pass

    ref_to_second = core_schema.definition_reference_schema('ref1')

    first_field = core_schema.model_field(core_schema.str_schema())
    first_field['metadata'] = {'pydantic.internal.union_discriminator': 'foobar1'}

    second_field = core_schema.model_field(core_schema.int_schema())
    second_field['metadata'] = {'pydantic.internal.union_discriminator': 'foobar2'}

    fields_schema = core_schema.model_fields_schema({'a': first_field, 'b': ref_to_second})
    gathered = gather_schemas_for_cleaning(
        core_schema.model_schema(Model, fields_schema), definitions={'ref1': second_field}
    )

    found = gathered['definition_refs']
    assert found == {'ref1': [ref_to_second]}
    assert found['ref1'][0] is ref_to_second
    assert gathered['recursive_refs'] == set()

    deferred = gathered['deferred_discriminators']
    assert deferred == [(first_field, 'foobar1'), (second_field, 'foobar2')]
    assert deferred[0][0] is first_field and deferred[1][0] is second_field

0 commit comments

Comments
 (0)