Skip to content

Commit d16f4ff

Browse files
Optimize to only return possible to inline. Add bench
1 parent 5dc5b6d commit d16f4ff

File tree

4 files changed

+53
-30
lines changed

4 files changed

+53
-30
lines changed

python/pydantic_core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,6 @@ class MultiHostHost(_TypedDict):
146146
class GatherResult(_TypedDict):
147147
"""Internal result of gathering schemas for cleaning."""
148148

149-
definition_refs: dict[str, list[DefinitionReferenceSchema]]
149+
inlinable_def_refs: dict[str, DefinitionReferenceSchema]
150150
recursive_refs: set[str]
151151
schemas_with_meta_keys: dict[str, list[CoreSchema]] | None

src/schema_traverse.rs

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
use crate::tools::py_err;
22
use pyo3::exceptions::{PyException, PyKeyError};
33
use pyo3::prelude::*;
4-
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
4+
use pyo3::types::{PyDict, PyList, PyNone, PySet, PyString, PyTuple};
55
use pyo3::{create_exception, intern, Bound, PyResult};
6-
use std::collections::HashSet;
76

87
create_exception!(pydantic_core._pydantic_core, GatherInvalidDefinitionError, PyException);
98

9+
macro_rules! none {
10+
($py: expr) => {
11+
PyNone::get_bound($py)
12+
};
13+
}
14+
1015
macro_rules! get {
1116
($dict: expr, $key: expr) => {
1217
$dict.get_item(intern!($dict.py(), $key))?
@@ -47,28 +52,32 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
4752
return py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref");
4853
};
4954
let schema_ref = schema_ref.downcast_exact::<PyString>()?;
55+
let py = schema_ref_dict.py();
5056

5157
if !ctx.recursively_seen_refs.contains(schema_ref)? {
52-
// Check if we already gathered this definition ref instance.
53-
// No need to again process the same def ref instance if we already did it.
54-
if ctx.seen_ref_instances.insert(schema_ref_dict.as_ptr() as isize) {
58+
// Def ref in no longer consider as inlinable if its re-encountered. Then its used multiple times.
59+
// No need to retraverse it either if we already encountered this.
60+
if !ctx.inline_def_ref_candidates.contains(schema_ref)? {
5561
let Some(definition) = ctx.definitions.get_item(schema_ref)? else {
5662
return py_err!(GatherInvalidDefinitionError; "{}", schema_ref.to_str()?);
5763
};
5864

59-
defaultdict_list_append!(&ctx.def_refs, schema_ref, schema_ref_dict);
60-
65+
ctx.inline_def_ref_candidates.set_item(schema_ref, schema_ref_dict)?;
6166
ctx.recursively_seen_refs.add(schema_ref)?;
6267

6368
gather_schema(definition.downcast_exact()?, ctx)?;
6469
traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx);
6570
gather_meta(schema_ref_dict, ctx)?;
6671

6772
ctx.recursively_seen_refs.discard(schema_ref)?;
73+
} else {
74+
ctx.inline_def_ref_candidates.set_item(schema_ref, none!(py))?; // Mark not inlinable (used multiple times)
6875
}
6976
} else {
77+
ctx.inline_def_ref_candidates.set_item(schema_ref, none!(py))?; // Mark not inlinable (used in recursion)
7078
ctx.recursive_def_refs.add(schema_ref)?;
7179
for seen_ref in ctx.recursively_seen_refs.iter() {
80+
ctx.inline_def_ref_candidates.set_item(&seen_ref, none!(py))?; // Mark not inlinable (used in recursion)
7281
ctx.recursive_def_refs.add(seen_ref)?;
7382
}
7483
}
@@ -160,10 +169,9 @@ fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()
160169
struct GatherCtx<'a, 'py> {
161170
definitions: &'a Bound<'py, PyDict>,
162171
meta_with_keys: Option<(Bound<'py, PyDict>, &'a Bound<'py, PySet>)>,
163-
def_refs: Bound<'py, PyDict>,
172+
inline_def_ref_candidates: Bound<'py, PyDict>,
164173
recursive_def_refs: Bound<'py, PySet>,
165174
recursively_seen_refs: Bound<'py, PySet>,
166-
seen_ref_instances: HashSet<isize>,
167175
}
168176

169177
#[pyfunction(signature = (schema, definitions, find_meta_with_keys))]
@@ -179,15 +187,21 @@ pub fn gather_schemas_for_cleaning<'py>(
179187
true => None,
180188
false => Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::<PySet>()?)),
181189
},
182-
def_refs: PyDict::new_bound(py),
190+
inline_def_ref_candidates: PyDict::new_bound(py),
183191
recursive_def_refs: PySet::empty_bound(py)?,
184192
recursively_seen_refs: PySet::empty_bound(py)?,
185-
seen_ref_instances: HashSet::new(),
186193
};
187194
gather_schema(schema.downcast_exact()?, &mut ctx)?;
188195

196+
let inlinable_def_refs = PyDict::new_bound(py);
197+
for (ref_str, def_ref_candidate) in ctx.inline_def_ref_candidates.iter() {
198+
if !def_ref_candidate.is_none() {
199+
inlinable_def_refs.set_item(ref_str, def_ref_candidate)?;
200+
}
201+
}
202+
189203
let res = PyDict::new_bound(py);
190-
res.set_item(intern!(py, "definition_refs"), ctx.def_refs)?;
204+
res.set_item(intern!(py, "inlinable_def_refs"), inlinable_def_refs)?;
191205
res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?;
192206
res.set_item(intern!(py, "schemas_with_meta_keys"), ctx.meta_with_keys.map(|v| v.0))?;
193207
Ok(res)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Callable
2+
3+
from python.pydantic_core import gather_schemas_for_cleaning
4+
5+
from .nested_schema import inlined_schema, schema_using_defs
6+
7+
8+
def test_nested_schema_using_defs(benchmark: Callable[..., None]) -> None:
9+
schema = schema_using_defs()
10+
definitions = {def_schema['ref']: def_schema for def_schema in schema['definitions']}
11+
schema = schema['schema']
12+
benchmark(gather_schemas_for_cleaning, schema, definitions, None)
13+
14+
15+
def test_nested_schema_inlined(benchmark: Callable[..., None]) -> None:
16+
schema = inlined_schema()
17+
benchmark(gather_schemas_for_cleaning, schema, {}, {'some_meta_key'})

tests/test_gather_schemas_for_cleaning.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def test_no_refs():
88
p2 = core_schema.arguments_parameter('b', core_schema.int_schema())
99
schema = core_schema.arguments_schema([p1, p2])
1010
res = gather_schemas_for_cleaning(schema, definitions={}, find_meta_with_keys=None)
11-
assert res['definition_refs'] == {}
11+
assert res['inlinable_def_refs'] == {}
1212
assert res['recursive_refs'] == set()
1313
assert res['schemas_with_meta_keys'] is None
1414

@@ -18,12 +18,12 @@ def test_simple_ref_schema():
1818
definitions = {'ref1': core_schema.int_schema(ref='ref1')}
1919

2020
res = gather_schemas_for_cleaning(schema, definitions, find_meta_with_keys=None)
21-
assert res['definition_refs'] == {'ref1': [schema]} and res['definition_refs']['ref1'][0] is schema
21+
assert res['inlinable_def_refs'] == {'ref1': schema} and res['inlinable_def_refs']['ref1'] is schema
2222
assert res['recursive_refs'] == set()
2323
assert res['schemas_with_meta_keys'] is None
2424

2525

26-
def test_deep_ref_schema():
26+
def test_deep_ref_schema_used_multiple_times():
2727
class Model:
2828
pass
2929

@@ -42,9 +42,7 @@ class Model:
4242
definitions = {'ref1': core_schema.str_schema(ref='ref1'), 'ref2': core_schema.bytes_schema(ref='ref2')}
4343

4444
res = gather_schemas_for_cleaning(schema, definitions, find_meta_with_keys=None)
45-
assert res['definition_refs'] == {'ref1': [ref11, ref12], 'ref2': [ref2]}
46-
assert res['definition_refs']['ref1'][0] is ref11 and res['definition_refs']['ref1'][1] is ref12
47-
assert res['definition_refs']['ref2'][0] is ref2
45+
assert res['inlinable_def_refs'] == {'ref2': ref2} and res['inlinable_def_refs']['ref2'] is ref2
4846
assert res['recursive_refs'] == set()
4947
assert res['schemas_with_meta_keys'] is None
5048

@@ -55,15 +53,15 @@ def test_ref_in_serialization_schema():
5553
serialization=core_schema.plain_serializer_function_ser_schema(lambda v: v, return_schema=ref),
5654
)
5755
res = gather_schemas_for_cleaning(schema, definitions={'ref1': core_schema.str_schema()}, find_meta_with_keys=None)
58-
assert res['definition_refs'] == {'ref1': [ref]} and res['definition_refs']['ref1'][0] is ref
56+
assert res['inlinable_def_refs'] == {'ref1': ref} and res['inlinable_def_refs']['ref1'] is ref
5957
assert res['recursive_refs'] == set()
6058
assert res['schemas_with_meta_keys'] is None
6159

6260

6361
def test_recursive_ref_schema():
6462
ref1 = core_schema.definition_reference_schema('ref1')
6563
res = gather_schemas_for_cleaning(ref1, definitions={'ref1': ref1}, find_meta_with_keys=None)
66-
assert res['definition_refs'] == {'ref1': [ref1]} and res['definition_refs']['ref1'][0] is ref1
64+
assert res['inlinable_def_refs'] == {}
6765
assert res['recursive_refs'] == {'ref1'}
6866
assert res['schemas_with_meta_keys'] is None
6967

@@ -82,11 +80,8 @@ def test_deep_recursive_ref_schema():
8280
},
8381
find_meta_with_keys=None,
8482
)
85-
assert res['definition_refs'] == {'ref1': [ref1], 'ref2': [ref2], 'ref3': [ref3]}
83+
assert res['inlinable_def_refs'] == {}
8684
assert res['recursive_refs'] == {'ref1', 'ref2', 'ref3'}
87-
assert res['definition_refs']['ref1'][0] is ref1
88-
assert res['definition_refs']['ref2'][0] is ref2
89-
assert res['definition_refs']['ref3'][0] is ref3
9085
assert res['schemas_with_meta_keys'] is None
9186

9287

@@ -108,7 +103,7 @@ class Model:
108103
res = gather_schemas_for_cleaning(
109104
schema, definitions={'ref1': field2}, find_meta_with_keys={'find_meta1', 'find_meta2'}
110105
)
111-
assert res['definition_refs'] == {'ref1': [ref1]} and res['definition_refs']['ref1'][0] is ref1
106+
assert res['inlinable_def_refs'] == {'ref1': ref1} and res['inlinable_def_refs']['ref1'] is ref1
112107
assert res['recursive_refs'] == set()
113108
assert res['schemas_with_meta_keys'] == {'find_meta1': [field1, field2], 'find_meta2': [field2]}
114109
assert res['schemas_with_meta_keys']['find_meta1'][0] is field1
@@ -134,9 +129,6 @@ def test_no_duplicate_ref_instances_gathered():
134129
definitions = {'ref1': schema1, 'ref2': schema2}
135130

136131
res = gather_schemas_for_cleaning(schema3, definitions=definitions, find_meta_with_keys=None)
137-
assert res['definition_refs'] == {
138-
'ref1': [schema2['items_schema'][0], schema2['items_schema'][1]],
139-
'ref2': [schema3['items_schema'][0], schema3['items_schema'][1]],
140-
}
132+
assert res['inlinable_def_refs'] == {}
141133
assert res['recursive_refs'] == set()
142134
assert res['schemas_with_meta_keys'] is None

0 commit comments

Comments
 (0)