Skip to content

Commit e12703a

Browse files
committed
WIP
1 parent bb67044 commit e12703a

File tree

9 files changed

+246
-1
lines changed

9 files changed

+246
-1
lines changed

python/pydantic_core/core_schema.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,24 @@ def uuid_schema(
13841384
)
13851385

13861386

1387+
class NestedModelSchema(TypedDict, total=False):
1388+
type: Required[Literal['nested-model']]
1389+
model: Required[Type[Any]]
1390+
metadata: Any
1391+
1392+
1393+
def nested_model_schema(
1394+
*,
1395+
model: Type[Any],
1396+
metadata: Any = None,
1397+
) -> NestedModelSchema:
1398+
return _dict_not_none(
1399+
type='nested-model',
1400+
model=model,
1401+
metadata=metadata,
1402+
)
1403+
1404+
13871405
class IncExSeqSerSchema(TypedDict, total=False):
13881406
type: Required[Literal['include-exclude-sequence']]
13891407
include: Set[int]
@@ -3796,6 +3814,7 @@ def definition_reference_schema(
37963814
DefinitionsSchema,
37973815
DefinitionReferenceSchema,
37983816
UuidSchema,
3817+
NestedModelSchema,
37993818
]
38003819
elif False:
38013820
CoreSchema: TypeAlias = Mapping[str, Any]
@@ -3851,6 +3870,7 @@ def definition_reference_schema(
38513870
'definitions',
38523871
'definition-ref',
38533872
'uuid',
3873+
'nested-model',
38543874
]
38553875

38563876
CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod errors;
2121
mod input;
2222
mod lookup_key;
2323
mod recursion_guard;
24+
mod schema_cache;
2425
mod serializers;
2526
mod tools;
2627
mod url;
@@ -133,5 +134,7 @@ fn _pydantic_core(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
133134
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
134135
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
135136
m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?;
137+
m.add_function(wrap_pyfunction!(schema_cache::cache_built_schema, m)?)?;
138+
m.add_function(wrap_pyfunction!(schema_cache::retrieve_schema, m)?)?;
136139
Ok(())
137140
}

src/schema_cache.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
use std::sync::OnceLock;
2+
3+
use pyo3::{
4+
types::{PyAnyMethods, PyDict},
5+
Bound, Py, PyAny, Python,
6+
};
7+
8+
/// Intended to be a `WeakKeyDictionary[Type, CoreSchema]`
9+
static CORE_SCHEMA_CACHE: OnceLock<Py<PyAny>> = OnceLock::new();
10+
11+
fn get_cache(py: Python<'_>) -> Bound<'_, PyDict> {
12+
let dict = CORE_SCHEMA_CACHE
13+
.get_or_init(|| {
14+
py.eval_bound(
15+
"
16+
import weakref
17+
weakref.WeakKeyDictionary()
18+
",
19+
None,
20+
None,
21+
)
22+
.unwrap()
23+
.unbind()
24+
})
25+
.downcast_bound::<PyDict>(py)
26+
.unwrap();
27+
28+
dict.clone()
29+
}
30+
31+
#[pyo3::pyfunction]
32+
pub fn cache_built_schema(py: Python<'_>, ty: Py<PyAny>, schema: Py<PyDict>) {
33+
let cache = get_cache(py);
34+
cache.set_item(ty, schema).unwrap();
35+
}
36+
37+
#[pyo3::pyfunction]
38+
pub fn retrieve_schema(py: Python<'_>, ty: Py<PyAny>) -> Py<PyAny> {
39+
let cache = get_cache(py);
40+
cache.get_item(ty).unwrap().into()
41+
}

src/serializers/shared.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ combined_serializer! {
142142
Enum: super::type_serializers::enum_::EnumSerializer;
143143
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
144144
Tuple: super::type_serializers::tuple::TupleSerializer;
145+
NestedModel: super::type_serializers::nested_model::NestedModelSerializer;
145146
}
146147
}
147148

@@ -251,6 +252,7 @@ impl PyGcTraverse for CombinedSerializer {
251252
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
252253
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
253254
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
255+
CombinedSerializer::NestedModel(inner) => inner.py_gc_traverse(visit),
254256
}
255257
}
256258
}

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub mod json_or_python;
1515
pub mod list;
1616
pub mod literal;
1717
pub mod model;
18+
pub mod nested_model;
1819
pub mod nullable;
1920
pub mod other;
2021
pub mod set_frozenset;
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
use std::borrow::Cow;
2+
3+
use pyo3::{
4+
intern,
5+
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
6+
Bound, Py, PyAny, PyObject, PyResult, Python,
7+
};
8+
9+
use crate::{
10+
definitions::DefinitionsBuilder,
11+
serializers::{
12+
shared::{BuildSerializer, TypeSerializer},
13+
CombinedSerializer, Extra,
14+
},
15+
SchemaSerializer,
16+
};
17+
18+
#[derive(Debug, Clone)]
19+
pub struct NestedModelSerializer {
20+
model: Py<PyType>,
21+
name: String,
22+
}
23+
24+
impl_py_gc_traverse!(NestedModelSerializer { model });
25+
26+
impl BuildSerializer for NestedModelSerializer {
27+
const EXPECTED_TYPE: &'static str = "nested-model";
28+
29+
fn build(
30+
schema: &Bound<'_, PyDict>,
31+
_config: Option<&Bound<'_, PyDict>>,
32+
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
33+
) -> PyResult<CombinedSerializer> {
34+
let py = schema.py();
35+
let model = schema
36+
.get_item(intern!(py, "model"))?
37+
.expect("Invalid core schema for `nested-model` type")
38+
.downcast::<PyType>()
39+
.expect("Invalid core schema for `nested-model` type")
40+
.clone();
41+
42+
let name = model.getattr(intern!(py, "__name__"))?.extract()?;
43+
44+
Ok(CombinedSerializer::NestedModel(NestedModelSerializer {
45+
model: model.clone().unbind(),
46+
name,
47+
}))
48+
}
49+
}
50+
51+
impl NestedModelSerializer {
52+
fn nested_serializer<'py>(&self, py: Python<'py>) -> Bound<'py, SchemaSerializer> {
53+
self.model
54+
.getattr(py, intern!(py, "__pydantic_serializer__"))
55+
.unwrap()
56+
.downcast_bound::<SchemaSerializer>(py)
57+
.unwrap()
58+
.clone()
59+
60+
// crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
61+
// .downcast_bound::<SchemaSerializer>(py)
62+
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
63+
// .expect("Cached validator was not a `SchemaSerializer`")
64+
// .clone()
65+
}
66+
}
67+
68+
impl TypeSerializer for NestedModelSerializer {
69+
fn to_python(
70+
&self,
71+
value: &Bound<'_, PyAny>,
72+
include: Option<&Bound<'_, PyAny>>,
73+
exclude: Option<&Bound<'_, PyAny>>,
74+
extra: &Extra,
75+
) -> PyResult<PyObject> {
76+
self.nested_serializer(value.py())
77+
.get()
78+
.serializer
79+
.to_python(value, include, exclude, extra)
80+
}
81+
82+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
83+
self.nested_serializer(key.py()).get().serializer.json_key(key, extra)
84+
}
85+
86+
fn serde_serialize<S: serde::ser::Serializer>(
87+
&self,
88+
value: &Bound<'_, PyAny>,
89+
serializer: S,
90+
include: Option<&Bound<'_, PyAny>>,
91+
exclude: Option<&Bound<'_, PyAny>>,
92+
extra: &Extra,
93+
) -> Result<S::Ok, S::Error> {
94+
self.nested_serializer(value.py())
95+
.get()
96+
.serializer
97+
.serde_serialize(value, serializer, include, exclude, extra)
98+
}
99+
100+
fn get_name(&self) -> &str {
101+
&self.name
102+
}
103+
}

src/validators/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ mod list;
4848
mod literal;
4949
mod model;
5050
mod model_fields;
51+
mod nested_model;
5152
mod none;
5253
mod nullable;
5354
mod set;
@@ -582,6 +583,7 @@ pub fn build_validator(
582583
// recursive (self-referencing) models
583584
definitions::DefinitionRefValidator,
584585
definitions::DefinitionsValidatorBuilder,
586+
nested_model::NestedModelValidator,
585587
)
586588
}
587589

@@ -735,6 +737,8 @@ pub enum CombinedValidator {
735737
DefinitionRef(definitions::DefinitionRefValidator),
736738
// input dependent
737739
JsonOrPython(json_or_python::JsonOrPython),
740+
// Schema for a model inside of another schema
741+
NestedModel(nested_model::NestedModelValidator),
738742
}
739743

740744
/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,

src/validators/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator {
7777

7878
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
7979
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
80-
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;
80+
let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?;
8181
let name = class.getattr(intern!(py, "__name__"))?.extract()?;
8282

8383
Ok(Self {

src/validators/nested_model.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use pyo3::{
2+
intern,
3+
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
4+
Bound, Py, PyObject, PyResult, Python,
5+
};
6+
7+
use crate::{definitions::DefinitionsBuilder, errors::ValResult, input::Input};
8+
9+
use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};
10+
11+
#[derive(Debug, Clone)]
12+
pub struct NestedModelValidator {
13+
model: Py<PyType>,
14+
name: String,
15+
}
16+
17+
impl_py_gc_traverse!(NestedModelValidator { model });
18+
19+
impl BuildValidator for NestedModelValidator {
20+
const EXPECTED_TYPE: &'static str = "nested-model";
21+
22+
fn build(
23+
schema: &Bound<'_, PyDict>,
24+
_config: Option<&Bound<'_, PyDict>>,
25+
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
26+
) -> PyResult<super::CombinedValidator> {
27+
let py = schema.py();
28+
let model = schema
29+
.get_item(intern!(py, "model"))?
30+
.expect("Invalid core schema for `nested-model` type")
31+
.downcast::<PyType>()
32+
.expect("Invalid core schema for `nested-model` type")
33+
.clone();
34+
35+
let name = model.getattr(intern!(py, "__name__"))?.extract()?;
36+
37+
Ok(CombinedValidator::NestedModel(NestedModelValidator {
38+
model: model.clone().unbind(),
39+
name,
40+
}))
41+
}
42+
}
43+
44+
impl Validator for NestedModelValidator {
45+
fn validate<'py>(
46+
&self,
47+
py: Python<'py>,
48+
input: &(impl Input<'py> + ?Sized),
49+
state: &mut ValidationState<'_, 'py>,
50+
) -> ValResult<PyObject> {
51+
let validator = self
52+
.model
53+
.getattr(py, intern!(py, "__pydantic_validator__"))
54+
.unwrap()
55+
.downcast_bound::<SchemaValidator>(py)
56+
.unwrap()
57+
.clone();
58+
59+
// let validator = crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
60+
// .downcast_bound::<SchemaValidator>(py)
61+
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
62+
// .expect("Cached validator was not a `SchemaValidator`")
63+
// .clone();
64+
65+
validator.get().validator.validate(py, input, state)
66+
}
67+
68+
fn get_name(&self) -> &str {
69+
&self.name
70+
}
71+
}

0 commit comments

Comments
 (0)