Skip to content

Commit 26e3ec4

Browse files
committed
a
1 parent b65d178 commit 26e3ec4

File tree

4 files changed

+131
-49
lines changed

4 files changed

+131
-49
lines changed

python/pydantic_core/core_schema.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1440,18 +1440,24 @@ def uuid_schema(
14401440
class NestedModelSchema(TypedDict, total=False):
14411441
type: Required[Literal['nested-model']]
14421442
model: Required[Type[Any]]
1443+
# Should return `(CoreSchema, SchemaValidator, SchemaSerializer)` but this requires a forward ref
1444+
get_info: Required[Callable[[], Any]]
14431445
metadata: Any
1444-
1446+
serialization: SerSchema
14451447

14461448
def nested_model_schema(
14471449
*,
14481450
model: Type[Any],
1451+
get_info: Callable[[], Any],
14491452
metadata: Any = None,
1453+
serialization: SerSchema | None = None
14501454
) -> NestedModelSchema:
14511455
return _dict_not_none(
14521456
type='nested-model',
14531457
model=model,
1458+
get_info=get_info,
14541459
metadata=metadata,
1460+
serialization=serialization
14551461
)
14561462

14571463

src/py_gc.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::sync::Arc;
1+
use std::sync::{Arc, OnceLock};
22

33
use ahash::AHashMap;
44
use enum_dispatch::enum_dispatch;
@@ -58,6 +58,15 @@ impl<T: PyGcTraverse> PyGcTraverse for Option<T> {
5858
}
5959
}
6060

61+
impl<T: PyGcTraverse> PyGcTraverse for OnceLock<T> {
62+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
63+
match self.get() {
64+
Some(item) => T::py_gc_traverse(item, visit),
65+
None => Ok(()),
66+
}
67+
}
68+
}
69+
6170
/// A crude alternative to a "derive" macro to help with building PyGcTraverse implementations
6271
macro_rules! impl_py_gc_traverse {
6372
($name:ty { }) => {

src/serializers/type_serializers/nested_model.rs

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
use std::borrow::Cow;
1+
use std::{borrow::Cow, sync::OnceLock};
22

33
use pyo3::{
44
intern,
5-
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
5+
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
66
Bound, Py, PyAny, PyObject, PyResult, Python,
77
};
88

@@ -19,9 +19,15 @@ use crate::{
1919
pub struct NestedModelSerializer {
2020
model: Py<PyType>,
2121
name: String,
22+
get_serializer: Py<PyAny>,
23+
serializer: OnceLock<Py<SchemaSerializer>>,
2224
}
2325

24-
impl_py_gc_traverse!(NestedModelSerializer { model });
26+
impl_py_gc_traverse!(NestedModelSerializer {
27+
model,
28+
get_serializer,
29+
serializer
30+
});
2531

2632
impl BuildSerializer for NestedModelSerializer {
2733
const EXPECTED_TYPE: &'static str = "nested-model";
@@ -32,6 +38,12 @@ impl BuildSerializer for NestedModelSerializer {
3238
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
3339
) -> PyResult<CombinedSerializer> {
3440
let py = schema.py();
41+
42+
let get_serializer = schema
43+
.get_item(intern!(py, "get_info"))?
44+
.expect("Invalid core schema for `nested-model` type")
45+
.unbind();
46+
3547
let model = schema
3648
.get_item(intern!(py, "model"))?
3749
.expect("Invalid core schema for `nested-model` type")
@@ -44,29 +56,30 @@ impl BuildSerializer for NestedModelSerializer {
4456
Ok(CombinedSerializer::NestedModel(NestedModelSerializer {
4557
model: model.clone().unbind(),
4658
name,
59+
get_serializer,
60+
serializer: OnceLock::new(),
4761
}))
4862
}
4963
}
5064

5165
impl NestedModelSerializer {
52-
fn nested_serializer<'py>(&self, py: Python<'py>) -> Bound<'py, SchemaSerializer> {
53-
self.model
54-
.bind(py)
55-
.call_method(intern!(py, "model_rebuild"), (), None)
56-
.unwrap();
57-
58-
self.model
59-
.getattr(py, intern!(py, "__pydantic_serializer__"))
60-
.unwrap()
61-
.downcast_bound::<SchemaSerializer>(py)
62-
.unwrap()
66+
fn nested_serializer<'py>(&self, py: Python<'py>) -> Py<SchemaSerializer> {
67+
self.serializer
68+
.get_or_init(|| {
69+
self.get_serializer
70+
.bind(py)
71+
.call((), None)
72+
.expect("Invalid core schema for `nested-model`")
73+
.downcast::<PyTuple>()
74+
.expect("Invalid return value from `nested-model`'s `get_info` callable")
75+
.get_item(2)
76+
.expect("Invalid return value from `nested-model`'s `get_info` callable")
77+
.downcast::<SchemaSerializer>()
78+
.expect("Invalid return value from `nested-model`'s `get_info` callable")
79+
.clone()
80+
.unbind()
81+
})
6382
.clone()
64-
65-
// crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
66-
// .downcast_bound::<SchemaSerializer>(py)
67-
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
68-
// .expect("Cached validator was not a `SchemaSerializer`")
69-
// .clone()
7083
}
7184
}
7285

@@ -76,16 +89,23 @@ impl TypeSerializer for NestedModelSerializer {
7689
value: &Bound<'_, PyAny>,
7790
include: Option<&Bound<'_, PyAny>>,
7891
exclude: Option<&Bound<'_, PyAny>>,
79-
extra: &Extra,
92+
mut extra: &Extra,
8093
) -> PyResult<PyObject> {
94+
let mut guard = extra.recursion_guard(value, self.model.as_ptr() as usize)?;
95+
8196
self.nested_serializer(value.py())
97+
.bind(value.py())
8298
.get()
8399
.serializer
84-
.to_python(value, include, exclude, extra)
100+
.to_python(value, include, exclude, guard.state())
85101
}
86102

87103
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
88-
self.nested_serializer(key.py()).get().serializer.json_key(key, extra)
104+
self.nested_serializer(key.py())
105+
.bind(key.py())
106+
.get()
107+
.serializer
108+
.json_key(key, extra)
89109
}
90110

91111
fn serde_serialize<S: serde::ser::Serializer>(
@@ -94,12 +114,19 @@ impl TypeSerializer for NestedModelSerializer {
94114
serializer: S,
95115
include: Option<&Bound<'_, PyAny>>,
96116
exclude: Option<&Bound<'_, PyAny>>,
97-
extra: &Extra,
117+
mut extra: &Extra,
98118
) -> Result<S::Ok, S::Error> {
119+
use super::py_err_se_err;
120+
121+
let mut guard = extra
122+
.recursion_guard(value, self.model.as_ptr() as usize)
123+
.map_err(py_err_se_err)?;
124+
99125
self.nested_serializer(value.py())
126+
.bind(value.py())
100127
.get()
101128
.serializer
102-
.serde_serialize(value, serializer, include, exclude, extra)
129+
.serde_serialize(value, serializer, include, exclude, guard.state())
103130
}
104131

105132
fn get_name(&self) -> &str {

src/validators/nested_model.rs

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,33 @@
1+
use std::sync::OnceLock;
2+
13
use pyo3::{
24
intern,
3-
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
4-
Bound, Py, PyObject, PyResult, Python,
5+
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyTupleMethods, PyType},
6+
Bound, Py, PyAny, PyObject, PyResult, Python,
57
};
68

7-
use crate::{definitions::DefinitionsBuilder, errors::ValResult, input::Input};
9+
use crate::{
10+
definitions::DefinitionsBuilder,
11+
errors::{ErrorTypeDefaults, ValError, ValResult},
12+
input::Input,
13+
recursion_guard::RecursionGuard,
14+
};
815

916
use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};
1017

1118
#[derive(Debug, Clone)]
1219
pub struct NestedModelValidator {
1320
model: Py<PyType>,
1421
name: String,
22+
get_validator: Py<PyAny>,
23+
validator: OnceLock<Py<SchemaValidator>>,
1524
}
1625

17-
impl_py_gc_traverse!(NestedModelValidator { model });
26+
impl_py_gc_traverse!(NestedModelValidator {
27+
model,
28+
get_validator,
29+
validator
30+
});
1831

1932
impl BuildValidator for NestedModelValidator {
2033
const EXPECTED_TYPE: &'static str = "nested-model";
@@ -25,6 +38,12 @@ impl BuildValidator for NestedModelValidator {
2538
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
2639
) -> PyResult<super::CombinedValidator> {
2740
let py = schema.py();
41+
42+
let get_validator = schema
43+
.get_item(intern!(py, "get_info"))?
44+
.expect("Invalid core schema for `nested-model` type")
45+
.unbind();
46+
2847
let model = schema
2948
.get_item(intern!(py, "model"))?
3049
.expect("Invalid core schema for `nested-model` type")
@@ -37,40 +56,61 @@ impl BuildValidator for NestedModelValidator {
3756
Ok(CombinedValidator::NestedModel(NestedModelValidator {
3857
model: model.clone().unbind(),
3958
name,
59+
get_validator: get_validator,
60+
validator: OnceLock::new(),
4061
}))
4162
}
4263
}
4364

65+
impl NestedModelValidator {
66+
fn nested_validator<'py>(&self, py: Python<'py>) -> Py<SchemaValidator> {
67+
self.validator
68+
.get_or_init(|| {
69+
self.get_validator
70+
.bind(py)
71+
.call((), None)
72+
.expect("Invalid core schema for `nested-model`")
73+
.downcast::<PyTuple>()
74+
.expect("Invalid return value from `nested-model`'s `get_info` callable")
75+
.get_item(1)
76+
.expect("Invalid return value from `nested-model`'s `get_info` callable")
77+
.downcast::<SchemaValidator>()
78+
.expect("Invalid return value from `nested-model`'s `get_info` callable")
79+
.clone()
80+
.unbind()
81+
})
82+
.clone()
83+
}
84+
}
85+
4486
impl Validator for NestedModelValidator {
4587
fn validate<'py>(
4688
&self,
4789
py: Python<'py>,
4890
input: &(impl Input<'py> + ?Sized),
4991
state: &mut ValidationState<'_, 'py>,
5092
) -> ValResult<PyObject> {
51-
self.model
52-
.bind(py)
53-
.call_method(intern!(py, "model_rebuild"), (), None)
54-
.unwrap();
55-
56-
let validator = self
57-
.model
58-
.getattr(py, intern!(py, "__pydantic_validator__"))
59-
.unwrap()
60-
.downcast_bound::<SchemaValidator>(py)
61-
.unwrap()
62-
.clone();
93+
let Some(id) = input.as_python().map(py_identity) else {
94+
panic!("")
95+
};
6396

64-
// let validator = crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
65-
// .downcast_bound::<SchemaValidator>(py)
66-
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
67-
// .expect("Cached validator was not a `SchemaValidator`")
68-
// .clone();
97+
// Python objects can be cyclic, so need recursion guard
98+
let Ok(mut guard) = RecursionGuard::new(state, id, self.model.as_ptr() as usize) else {
99+
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
100+
};
69101

70-
validator.get().validator.validate(py, input, state)
102+
self.nested_validator(py)
103+
.bind(py)
104+
.get()
105+
.validator
106+
.validate(py, input, guard.state())
71107
}
72108

73109
fn get_name(&self) -> &str {
74110
&self.name
75111
}
76112
}
113+
114+
fn py_identity(obj: &Bound<'_, PyAny>) -> usize {
115+
obj.as_ptr() as usize
116+
}

0 commit comments

Comments
 (0)