Skip to content

Commit 18a29c4

Browse files
Add tests
1 parent 8c0d5d9 commit 18a29c4

File tree

2 files changed

+113
-3
lines changed

2 files changed

+113
-3
lines changed

src/schema_traverse.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::tools::py_err;
2-
use pyo3::exceptions::{PyKeyError, PyValueError};
2+
use pyo3::exceptions::PyKeyError;
33
use pyo3::prelude::*;
44
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
55
use pyo3::{intern, Bound, PyResult};
@@ -47,9 +47,10 @@ fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCt
4747
if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") {
4848
let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
4949
let schema_ref_str = schema_ref_pystr.to_str()?;
50-
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
5150

5251
if !ctx.recursively_seen_refs.contains(schema_ref_str) {
52+
defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);
53+
5354
// TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
5455
if let Some(definition) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
5556
ctx.recursively_seen_refs.insert(schema_ref_str.to_string());
@@ -124,10 +125,12 @@ fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyRes
124125
Ok(())
125126
}
126127

128+
// Has 100% coverage in Pydantic side. This is exclusively used there
129+
#[cfg_attr(has_coverage_attribute, coverage(off))]
127130
fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
128131
let type_ = get!(schema, "type");
129132
if type_.is_none() {
130-
return py_err!(PyValueError; "Schema type missing");
133+
return py_err!(PyKeyError; "Schema type missing");
131134
}
132135
match type_.unwrap().downcast_exact::<PyString>()?.to_str()? {
133136
"definition-ref" => gather_definition_ref(schema, ctx)?,
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from pydantic_core import core_schema, gather_schemas_for_cleaning
2+
3+
4+
def test_no_refs():
5+
p1 = core_schema.arguments_parameter('a', core_schema.int_schema())
6+
p2 = core_schema.arguments_parameter('b', core_schema.int_schema())
7+
schema = core_schema.arguments_schema([p1, p2])
8+
res = gather_schemas_for_cleaning(schema, definitions={})
9+
assert res['definition_refs'] == {}
10+
assert res['recursive_refs'] == set()
11+
assert res['deferred_discriminators'] == []
12+
13+
14+
def test_simple_ref_schema():
15+
schema = core_schema.definition_reference_schema('ref1')
16+
definitions = {'ref1': core_schema.int_schema(ref='ref1')}
17+
18+
res = gather_schemas_for_cleaning(schema, definitions)
19+
assert res['definition_refs'] == {'ref1': [schema]} and res['definition_refs']['ref1'][0] is schema
20+
assert res['recursive_refs'] == set()
21+
assert res['deferred_discriminators'] == []
22+
23+
24+
def test_deep_ref_schema():
25+
class Model:
26+
pass
27+
28+
ref11 = core_schema.definition_reference_schema('ref1')
29+
ref12 = core_schema.definition_reference_schema('ref1')
30+
ref2 = core_schema.definition_reference_schema('ref2')
31+
32+
union = core_schema.union_schema([core_schema.int_schema(), (ref11, 'ref_label')])
33+
tup = core_schema.tuple_schema([ref12, core_schema.str_schema()])
34+
schema = core_schema.model_schema(
35+
Model,
36+
core_schema.model_fields_schema(
37+
{'a': core_schema.model_field(union), 'b': core_schema.model_field(ref2), 'c': core_schema.model_field(tup)}
38+
),
39+
)
40+
definitions = {'ref1': core_schema.str_schema(ref='ref1'), 'ref2': core_schema.bytes_schema(ref='ref2')}
41+
42+
res = gather_schemas_for_cleaning(schema, definitions)
43+
assert res['definition_refs'] == {'ref1': [ref11, ref12], 'ref2': [ref2]}
44+
assert res['definition_refs']['ref1'][0] is ref11 and res['definition_refs']['ref1'][1] is ref12
45+
assert res['definition_refs']['ref2'][0] is ref2
46+
assert res['recursive_refs'] == set()
47+
assert res['deferred_discriminators'] == []
48+
49+
50+
def test_ref_in_serialization_schema():
51+
ref = core_schema.definition_reference_schema('ref1')
52+
schema = core_schema.str_schema(
53+
serialization=core_schema.plain_serializer_function_ser_schema(lambda v: v, return_schema=ref),
54+
)
55+
res = gather_schemas_for_cleaning(schema, definitions={'ref1': core_schema.str_schema()})
56+
assert res['definition_refs'] == {'ref1': [ref]} and res['definition_refs']['ref1'][0] is ref
57+
assert res['recursive_refs'] == set()
58+
assert res['deferred_discriminators'] == []
59+
60+
61+
def test_recursive_ref_schema():
62+
ref1 = core_schema.definition_reference_schema('ref1')
63+
res = gather_schemas_for_cleaning(ref1, definitions={'ref1': ref1})
64+
assert res['definition_refs'] == {'ref1': [ref1]} and res['definition_refs']['ref1'][0] is ref1
65+
assert res['recursive_refs'] == {'ref1'}
66+
assert res['deferred_discriminators'] == []
67+
68+
69+
def test_deep_recursive_ref_schema():
70+
ref1 = core_schema.definition_reference_schema('ref1')
71+
ref2 = core_schema.definition_reference_schema('ref2')
72+
ref3 = core_schema.definition_reference_schema('ref3')
73+
74+
res = gather_schemas_for_cleaning(
75+
core_schema.union_schema([ref1, core_schema.int_schema()]),
76+
definitions={
77+
'ref1': core_schema.union_schema([core_schema.int_schema(), ref2]),
78+
'ref2': core_schema.union_schema([ref3, core_schema.float_schema()]),
79+
'ref3': core_schema.union_schema([ref1, core_schema.str_schema()]),
80+
},
81+
)
82+
assert res['definition_refs'] == {'ref1': [ref1], 'ref2': [ref2], 'ref3': [ref3]}
83+
assert res['recursive_refs'] == {'ref1', 'ref2', 'ref3'}
84+
assert res['definition_refs']['ref1'][0] is ref1
85+
assert res['definition_refs']['ref2'][0] is ref2
86+
assert res['definition_refs']['ref3'][0] is ref3
87+
assert res['deferred_discriminators'] == []
88+
89+
90+
def test_discriminator_meta():
91+
class Model:
92+
pass
93+
94+
ref1 = core_schema.definition_reference_schema('ref1')
95+
96+
field1 = core_schema.model_field(core_schema.str_schema())
97+
field1['metadata'] = {'pydantic.internal.union_discriminator': 'foobar1'}
98+
99+
field2 = core_schema.model_field(core_schema.int_schema())
100+
field2['metadata'] = {'pydantic.internal.union_discriminator': 'foobar2'}
101+
102+
schema = core_schema.model_schema(Model, core_schema.model_fields_schema({'a': field1, 'b': ref1}))
103+
res = gather_schemas_for_cleaning(schema, definitions={'ref1': field2})
104+
assert res['definition_refs'] == {'ref1': [ref1]} and res['definition_refs']['ref1'][0] is ref1
105+
assert res['recursive_refs'] == set()
106+
assert res['deferred_discriminators'] == [(field1, 'foobar1'), (field2, 'foobar2')]
107+
assert res['deferred_discriminators'][0][0] is field1 and res['deferred_discriminators'][1][0] is field2

0 commit comments

Comments
 (0)