Skip to content

Commit 980fa20

Browse files
committed
serializer reuse logic as well
1 parent e5245b5 commit 980fa20

File tree

3 files changed

+110
-3
lines changed

3 files changed

+110
-3
lines changed

src/serializers/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::fmt::Debug;
22
use std::sync::atomic::{AtomicUsize, Ordering};
3+
use std::sync::Arc;
34

45
use pyo3::prelude::*;
56
use pyo3::types::{PyBytes, PyDict, PyTuple, PyType};
@@ -24,6 +25,7 @@ mod fields;
2425
mod filter;
2526
mod infer;
2627
mod ob_type;
28+
mod prebuilt;
2729
pub mod ser;
2830
mod shared;
2931
mod type_serializers;
@@ -37,7 +39,7 @@ pub enum WarningsArg {
3739
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
3840
#[derive(Debug)]
3941
pub struct SchemaSerializer {
40-
serializer: CombinedSerializer,
42+
serializer: Arc<CombinedSerializer>,
4143
definitions: Definitions<CombinedSerializer>,
4244
expected_json_size: AtomicUsize,
4345
config: SerializationConfig,
@@ -92,7 +94,7 @@ impl SchemaSerializer {
9294
let mut definitions_builder = DefinitionsBuilder::new();
9395
let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
9496
Ok(Self {
95-
serializer,
97+
serializer: Arc::new(serializer),
9698
definitions: definitions_builder.finish()?,
9799
expected_json_size: AtomicUsize::new(1024),
98100
config: SerializationConfig::from_config(config)?,

src/serializers/prebuilt.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
use std::borrow::Cow;
2+
use std::sync::Arc;
3+
4+
use pyo3::exceptions::PyValueError;
5+
use pyo3::intern;
6+
use pyo3::prelude::*;
7+
use pyo3::types::{PyBool, PyDict, PyType};
8+
9+
use crate::definitions::DefinitionsBuilder;
10+
use crate::tools::SchemaDict;
11+
use crate::SchemaSerializer;
12+
13+
use super::extra::Extra;
14+
use super::shared::{BuildSerializer, CombinedSerializer, TypeSerializer};
15+
16+
#[derive(Debug)]
17+
pub struct PrebuiltSerializer {
18+
serializer: Arc<CombinedSerializer>,
19+
}
20+
21+
impl BuildSerializer for PrebuiltSerializer {
22+
const EXPECTED_TYPE: &'static str = "prebuilt";
23+
24+
fn build(
25+
schema: &Bound<'_, PyDict>,
26+
_config: Option<&Bound<'_, PyDict>>,
27+
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
28+
) -> PyResult<CombinedSerializer> {
29+
let py = schema.py();
30+
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
31+
32+
// note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
33+
// because we don't want to fetch prebuilt serializers from parent classes
34+
let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?;
35+
36+
// Ensure the class has completed its Pydantic setup
37+
let is_complete: bool = class_dict
38+
.get_as_req::<Bound<'_, PyBool>>(intern!(py, "__pydantic_complete__"))
39+
.is_ok_and(|b| b.extract().unwrap_or(false));
40+
41+
if !is_complete {
42+
return Err(PyValueError::new_err("Prebuilt serializer not found."));
43+
}
44+
45+
// Retrieve the prebuilt validator if available
46+
let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_serializer__"))?;
47+
let schema_serializer: PyRef<SchemaSerializer> = prebuilt_serializer.extract()?;
48+
let combined_serializer: Arc<CombinedSerializer> = schema_serializer.serializer.clone();
49+
50+
Ok(Self {
51+
serializer: combined_serializer,
52+
}
53+
.into())
54+
}
55+
}
56+
57+
impl_py_gc_traverse!(PrebuiltSerializer { serializer });
58+
59+
impl TypeSerializer for PrebuiltSerializer {
60+
fn to_python(
61+
&self,
62+
value: &Bound<'_, PyAny>,
63+
include: Option<&Bound<'_, PyAny>>,
64+
exclude: Option<&Bound<'_, PyAny>>,
65+
extra: &Extra,
66+
) -> PyResult<PyObject> {
67+
self.serializer.to_python(value, include, exclude, extra)
68+
}
69+
70+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
71+
self.serializer.json_key(key, extra)
72+
}
73+
74+
fn serde_serialize<S: serde::ser::Serializer>(
75+
&self,
76+
value: &Bound<'_, PyAny>,
77+
serializer: S,
78+
include: Option<&Bound<'_, PyAny>>,
79+
exclude: Option<&Bound<'_, PyAny>>,
80+
extra: &Extra,
81+
) -> Result<S::Ok, S::Error> {
82+
self.serializer
83+
.serde_serialize(value, serializer, include, exclude, extra)
84+
}
85+
86+
fn get_name(&self) -> &str {
87+
self.serializer.get_name()
88+
}
89+
90+
fn retry_with_lax_check(&self) -> bool {
91+
self.serializer.retry_with_lax_check()
92+
}
93+
}

src/serializers/shared.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ combined_serializer! {
8484
Function: super::type_serializers::function::FunctionPlainSerializer;
8585
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
8686
Fields: super::fields::GeneralFieldsSerializer;
87+
// prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum
88+
Prebuilt: super::prebuilt::PrebuiltSerializer;
8789
}
8890
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
8991
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
@@ -195,7 +197,16 @@ impl CombinedSerializer {
195197
}
196198

197199
let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?;
198-
Self::find_serializer(type_.to_str()?, schema, config, definitions)
200+
let type_ = type_.to_str()?;
201+
202+
// if we have a SchemaValidator on the type already, use it
203+
if matches!(type_, "model" | "dataclass" | "typed-dict") {
204+
if let Ok(prebuilt_serializer) = super::prebuilt::PrebuiltSerializer::build(schema, config, definitions) {
205+
return Ok(prebuilt_serializer);
206+
}
207+
}
208+
209+
Self::find_serializer(type_, schema, config, definitions)
199210
}
200211
}
201212

@@ -219,6 +230,7 @@ impl PyGcTraverse for CombinedSerializer {
219230
CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit),
220231
CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit),
221232
CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit),
233+
CombinedSerializer::Prebuilt(inner) => inner.py_gc_traverse(visit),
222234
CombinedSerializer::None(inner) => inner.py_gc_traverse(visit),
223235
CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit),
224236
CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),

0 commit comments

Comments
 (0)