Skip to content

Commit 819c8dd

Browse files
Move discriminator const away. Use generic finding
1 parent b8ad942 commit 819c8dd

File tree

4 files changed

+55
-38
lines changed

4 files changed

+55
-38
lines changed

python/pydantic_core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,4 @@ class GatherResult(_TypedDict):
146146

147147
definition_refs: dict[str, list[DefinitionReferenceSchema]]
148148
recursive_refs: set[str]
149-
deferred_discriminators: list[tuple[CoreSchema, _Any]]
149+
schemas_with_meta_keys: dict[str, list[CoreSchema]] | None

python/pydantic_core/_pydantic_core.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,10 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C
11661166
using pydantic-core directly.
11671167
"""
11681168

1169-
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
1169+
def gather_schemas_for_cleaning(
1170+
schema: CoreSchema,
1171+
definitions: dict[str, CoreSchema],
1172+
find_meta_with_keys: set[str] | None,
1173+
) -> GatherResult:
11701174
"""Used internally for schema cleaning when schemas are generated.
11711175
Gathers information from the schema tree for the cleaning."""

src/schema_traverse.rs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
55
use pyo3::{intern, Bound, PyResult};
66
use std::collections::HashSet;
77

8-
const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator";
9-
108
macro_rules! get {
119
($dict: expr, $key: expr) => {
1210
$dict.get_item(intern!($dict.py(), $key))?
@@ -52,7 +50,7 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
5250
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
5351

5452
// TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
55-
if let Some(definition) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
53+
if let Some(definition) = ctx.definitions.get_item(schema_ref_pystr)? {
5654
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
5755

5856
gather_schema(definition.downcast_exact::<PyDict>()?, ctx)?;
@@ -75,11 +73,13 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
7573
}
7674

7775
fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
78-
if let Some(meta) = get!(schema, "metadata") {
79-
let meta_dict = meta.downcast_exact::<PyDict>()?;
80-
if let Some(discriminator) = get!(meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY) {
81-
let schema_discriminator = PyTuple::new_bound(schema.py(), vec![schema.as_any(), &discriminator]);
82-
ctx.discriminators.append(schema_discriminator)?;
76+
if let Some((res, find_keys)) = &ctx.meta_with_keys {
77+
if let Some(meta) = get!(schema, "metadata") {
78+
for (k, _) in meta.downcast_exact::<PyDict>()?.iter() {
79+
if find_keys.contains(&k)? {
80+
defaultdict_list_append!(res, &k, schema);
81+
}
82+
}
8383
}
8484
}
8585
Ok(())
@@ -152,32 +152,38 @@ fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()
152152
}
153153
}
154154

155-
pub struct GatherCtx<'a, 'py> {
156-
pub definitions_dict: &'a Bound<'py, PyDict>,
157-
pub def_refs: Bound<'py, PyDict>,
158-
pub recursive_def_refs: Bound<'py, PySet>,
159-
pub discriminators: Bound<'py, PyList>,
155+
struct GatherCtx<'a, 'py> {
156+
definitions: &'a Bound<'py, PyDict>,
157+
meta_with_keys: Option<(Bound<'py, PyDict>, &'a Bound<'py, PySet>)>,
158+
def_refs: Bound<'py, PyDict>,
159+
recursive_def_refs: Bound<'py, PySet>,
160160
recursively_seen_refs: HashSet<String>,
161161
}
162162

163-
#[pyfunction(signature = (schema, definitions))]
163+
#[pyfunction(signature = (schema, definitions, find_meta_with_keys))]
164164
pub fn gather_schemas_for_cleaning<'py>(
165165
schema: &Bound<'py, PyAny>,
166166
definitions: &Bound<'py, PyAny>,
167+
find_meta_with_keys: &Bound<'py, PyAny>,
167168
) -> PyResult<Bound<'py, PyDict>> {
168169
let py = schema.py();
170+
let meta_with_keys = if find_meta_with_keys.is_none() {
171+
None
172+
} else {
173+
Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::<PySet>()?))
174+
};
169175
let mut ctx = GatherCtx {
170-
definitions_dict: definitions.downcast_exact()?,
176+
definitions: definitions.downcast_exact()?,
177+
meta_with_keys,
171178
def_refs: PyDict::new_bound(py),
172179
recursive_def_refs: PySet::empty_bound(py)?,
173-
discriminators: PyList::empty_bound(py),
174180
recursively_seen_refs: HashSet::new(),
175181
};
176-
gather_schema(schema.downcast_exact::<PyDict>()?, &mut ctx)?;
182+
gather_schema(schema.downcast_exact()?, &mut ctx)?;
177183

178184
let res = PyDict::new_bound(py);
179185
res.set_item(intern!(py, "definition_refs"), ctx.def_refs)?;
180186
res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?;
181-
res.set_item(intern!(py, "deferred_discriminators"), ctx.discriminators)?;
187+
res.set_item(intern!(py, "schemas_with_meta_keys"), ctx.meta_with_keys.map(|v| v.0))?;
182188
Ok(res)
183189
}

tests/test_gather_schemas_for_cleaning.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,20 @@ def test_no_refs():
55
p1 = core_schema.arguments_parameter('a', core_schema.int_schema())
66
p2 = core_schema.arguments_parameter('b', core_schema.int_schema())
77
schema = core_schema.arguments_schema([p1, p2])
8-
res = gather_schemas_for_cleaning(schema, definitions={})
8+
res = gather_schemas_for_cleaning(schema, definitions={}, find_meta_with_keys=None)
99
assert res['definition_refs'] == {}
1010
assert res['recursive_refs'] == set()
11-
assert res['deferred_discriminators'] == []
11+
assert res['schemas_with_meta_keys'] is None
1212

1313

1414
def test_simple_ref_schema():
1515
schema = core_schema.definition_reference_schema('ref1')
1616
definitions = {'ref1': core_schema.int_schema(ref='ref1')}
1717

18-
res = gather_schemas_for_cleaning(schema, definitions)
18+
res = gather_schemas_for_cleaning(schema, definitions, find_meta_with_keys=None)
1919
assert res['definition_refs'] == {'ref1': [schema]} and res['definition_refs']['ref1'][0] is schema
2020
assert res['recursive_refs'] == set()
21-
assert res['deferred_discriminators'] == []
21+
assert res['schemas_with_meta_keys'] is None
2222

2323

2424
def test_deep_ref_schema():
@@ -39,31 +39,31 @@ class Model:
3939
)
4040
definitions = {'ref1': core_schema.str_schema(ref='ref1'), 'ref2': core_schema.bytes_schema(ref='ref2')}
4141

42-
res = gather_schemas_for_cleaning(schema, definitions)
42+
res = gather_schemas_for_cleaning(schema, definitions, find_meta_with_keys=None)
4343
assert res['definition_refs'] == {'ref1': [ref11, ref12], 'ref2': [ref2]}
4444
assert res['definition_refs']['ref1'][0] is ref11 and res['definition_refs']['ref1'][1] is ref12
4545
assert res['definition_refs']['ref2'][0] is ref2
4646
assert res['recursive_refs'] == set()
47-
assert res['deferred_discriminators'] == []
47+
assert res['schemas_with_meta_keys'] is None
4848

4949

5050
def test_ref_in_serialization_schema():
5151
ref = core_schema.definition_reference_schema('ref1')
5252
schema = core_schema.str_schema(
5353
serialization=core_schema.plain_serializer_function_ser_schema(lambda v: v, return_schema=ref),
5454
)
55-
res = gather_schemas_for_cleaning(schema, definitions={'ref1': core_schema.str_schema()})
55+
res = gather_schemas_for_cleaning(schema, definitions={'ref1': core_schema.str_schema()}, find_meta_with_keys=None)
5656
assert res['definition_refs'] == {'ref1': [ref]} and res['definition_refs']['ref1'][0] is ref
5757
assert res['recursive_refs'] == set()
58-
assert res['deferred_discriminators'] == []
58+
assert res['schemas_with_meta_keys'] is None
5959

6060

6161
def test_recursive_ref_schema():
6262
ref1 = core_schema.definition_reference_schema('ref1')
63-
res = gather_schemas_for_cleaning(ref1, definitions={'ref1': ref1})
63+
res = gather_schemas_for_cleaning(ref1, definitions={'ref1': ref1}, find_meta_with_keys=None)
6464
assert res['definition_refs'] == {'ref1': [ref1]} and res['definition_refs']['ref1'][0] is ref1
6565
assert res['recursive_refs'] == {'ref1'}
66-
assert res['deferred_discriminators'] == []
66+
assert res['schemas_with_meta_keys'] is None
6767

6868

6969
def test_deep_recursive_ref_schema():
@@ -78,30 +78,37 @@ def test_deep_recursive_ref_schema():
7878
'ref2': core_schema.union_schema([ref3, core_schema.float_schema()]),
7979
'ref3': core_schema.union_schema([ref1, core_schema.str_schema()]),
8080
},
81+
find_meta_with_keys=None,
8182
)
8283
assert res['definition_refs'] == {'ref1': [ref1], 'ref2': [ref2], 'ref3': [ref3]}
8384
assert res['recursive_refs'] == {'ref1', 'ref2', 'ref3'}
8485
assert res['definition_refs']['ref1'][0] is ref1
8586
assert res['definition_refs']['ref2'][0] is ref2
8687
assert res['definition_refs']['ref3'][0] is ref3
87-
assert res['deferred_discriminators'] == []
88+
assert res['schemas_with_meta_keys'] is None
8889

8990

90-
def test_discriminator_meta():
91+
def test_find_meta():
9192
class Model:
9293
pass
9394

9495
ref1 = core_schema.definition_reference_schema('ref1')
9596

9697
field1 = core_schema.model_field(core_schema.str_schema())
97-
field1['metadata'] = {'pydantic.internal.union_discriminator': 'foobar1'}
98+
field1['metadata'] = {'find_meta1': 'foobar1', 'unknown': 'foobar2'}
9899

99100
field2 = core_schema.model_field(core_schema.int_schema())
100-
field2['metadata'] = {'pydantic.internal.union_discriminator': 'foobar2'}
101+
field2['metadata'] = {'find_meta1': 'foobar3', 'find_meta2': 'foobar4'}
101102

102-
schema = core_schema.model_schema(Model, core_schema.model_fields_schema({'a': field1, 'b': ref1}))
103-
res = gather_schemas_for_cleaning(schema, definitions={'ref1': field2})
103+
schema = core_schema.model_schema(
104+
Model, core_schema.model_fields_schema({'a': field1, 'b': ref1, 'c': core_schema.float_schema()})
105+
)
106+
res = gather_schemas_for_cleaning(
107+
schema, definitions={'ref1': field2}, find_meta_with_keys={'find_meta1', 'find_meta2'}
108+
)
104109
assert res['definition_refs'] == {'ref1': [ref1]} and res['definition_refs']['ref1'][0] is ref1
105110
assert res['recursive_refs'] == set()
106-
assert res['deferred_discriminators'] == [(field1, 'foobar1'), (field2, 'foobar2')]
107-
assert res['deferred_discriminators'][0][0] is field1 and res['deferred_discriminators'][1][0] is field2
111+
assert res['schemas_with_meta_keys'] == {'find_meta1': [field1, field2], 'find_meta2': [field2]}
112+
assert res['schemas_with_meta_keys']['find_meta1'][0] is field1
113+
assert res['schemas_with_meta_keys']['find_meta1'][1] is field2
114+
assert res['schemas_with_meta_keys']['find_meta2'][0] is field2

0 commit comments

Comments
 (0)