Skip to content

Commit cb281b5

Browse files
committed
add polymorphic_serialization config for models and dataclasses
1 parent 20d576b commit cb281b5

File tree

5 files changed

+214
-2
lines changed

5 files changed

+214
-2
lines changed

src/serializers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ mod fields;
2424
mod filter;
2525
mod infer;
2626
mod ob_type;
27+
mod polymorphism_trampoline;
2728
mod prebuilt;
2829
pub mod ser;
2930
mod shared;
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
use std::{borrow::Cow, sync::Arc};
2+
3+
use pyo3::{prelude::*, types::PyType};
4+
5+
use crate::serializers::{
6+
errors::unwrap_ser_error,
7+
infer::call_pydantic_serializer,
8+
shared::{serialize_to_json, serialize_to_python, DoSerialize, TypeSerializer},
9+
CombinedSerializer, SerializationState,
10+
};
11+
12+
/// The polymorphism trampoline detects subclasses of its target type and dispatches to their
13+
/// `__pydantic_serializer__` serializer for serialization.
14+
///
15+
/// This exists as a separate structure to allow for cases such as model serializers where the
16+
/// inner serializer may just be a function serializer and so cannot handle polymorphism itself.
17+
pub struct PolymorphismTrampoline {
18+
class: Py<PyType>,
19+
/// Inner serializer used when the type is not a subclass (responsible for any fallback etc)
20+
serializer: Arc<CombinedSerializer>,
21+
/// Whether polymorphic serialization is enabled from config
22+
enabled_from_config: bool,
23+
}
24+
25+
impl_py_gc_traverse!(PolymorphismTrampoline { class, serializer });
26+
27+
impl std::fmt::Debug for PolymorphismTrampoline {
28+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29+
// to avoid breaking the repr from before we inserted the trampoline
30+
self.serializer.fmt(f)
31+
}
32+
}
33+
34+
impl PolymorphismTrampoline {
35+
pub fn new(class: Py<PyType>, serializer: Arc<CombinedSerializer>, enabled_from_config: bool) -> Self {
36+
Self {
37+
class,
38+
serializer,
39+
enabled_from_config,
40+
}
41+
}
42+
43+
fn is_subclass(&self, value: &Bound<'_, PyAny>) -> PyResult<bool> {
44+
Ok(!value.get_type().is(&self.class) && value.is_instance(self.class.bind(value.py()))?)
45+
}
46+
47+
fn serialize<'py, T, E: From<PyErr>>(
48+
&self,
49+
value: &Bound<'py, PyAny>,
50+
state: &mut SerializationState<'_, 'py>,
51+
do_serialize: impl DoSerialize<'py, T, E>,
52+
) -> Result<T, E> {
53+
// FIXME: allow per-call override of polymorphic serialization?
54+
if self.enabled_from_config && self.is_subclass(value)? {
55+
call_pydantic_serializer(value, state, do_serialize)
56+
} else {
57+
do_serialize.serialize_no_infer(&self.serializer, value, state)
58+
}
59+
}
60+
}
61+
62+
impl TypeSerializer for PolymorphismTrampoline {
63+
fn to_python<'py>(
64+
&self,
65+
value: &Bound<'py, PyAny>,
66+
state: &mut SerializationState<'_, 'py>,
67+
) -> PyResult<Py<PyAny>> {
68+
self.serialize(value, state, serialize_to_python())
69+
}
70+
71+
fn json_key<'a, 'py>(
72+
&self,
73+
key: &'a Bound<'py, PyAny>,
74+
state: &mut SerializationState<'_, 'py>,
75+
) -> PyResult<Cow<'a, str>> {
76+
// json key serialization for models and dataclasses was always polymorphic anyway
77+
// FIXME: make this consistent with the other cases?
78+
self.serializer.json_key(key, state)
79+
}
80+
81+
fn serde_serialize<'py, S: serde::ser::Serializer>(
82+
&self,
83+
value: &Bound<'py, PyAny>,
84+
serializer: S,
85+
state: &mut SerializationState<'_, 'py>,
86+
) -> Result<S::Ok, S::Error> {
87+
self.serialize(value, state, serialize_to_json(serializer))
88+
.map_err(unwrap_ser_error)
89+
}
90+
91+
fn get_name(&self) -> &str {
92+
self.serializer.get_name()
93+
}
94+
}

src/serializers/shared.rs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use crate::build_tools::py_schema_error_type;
1919
use crate::definitions::DefinitionsBuilder;
2020
use crate::py_gc::PyGcTraverse;
2121
use crate::serializers::errors::WrappedSerError;
22+
use crate::serializers::polymorphism_trampoline::PolymorphismTrampoline;
2223
use crate::serializers::ser::PythonSerializer;
2324
use crate::serializers::type_serializers::any::AnySerializer;
2425
use crate::tools::{py_err, SchemaDict};
@@ -91,6 +92,9 @@ combined_serializer! {
9192
Fields: super::fields::GeneralFieldsSerializer;
9293
// prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum
9394
Prebuilt: super::prebuilt::PrebuiltSerializer;
95+
// polymorphism trampoline is manually constructed to wrap models and dataclasses with
96+
// polymorphic serialization
97+
PolymorphismTrampoline: super::polymorphism_trampoline::PolymorphismTrampoline;
9498
}
9599
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
96100
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
@@ -162,7 +166,8 @@ impl CombinedSerializer {
162166
config: Option<&Bound<'_, PyDict>>,
163167
definitions: &mut DefinitionsBuilder<Arc<CombinedSerializer>>,
164168
) -> PyResult<Arc<CombinedSerializer>> {
165-
Self::_build(schema, config, definitions, false)
169+
let serializer = Self::_build(schema, config, definitions, false)?;
170+
Self::maybe_wrap_in_polymorphism_trampoline(serializer, schema)
166171
}
167172

168173
fn _build(
@@ -230,6 +235,34 @@ impl CombinedSerializer {
230235
Self::find_serializer(type_, schema, config, definitions)
231236
}
232237

238+
fn maybe_wrap_in_polymorphism_trampoline(
239+
serializer: Arc<CombinedSerializer>,
240+
schema: &Bound<'_, PyDict>,
241+
) -> PyResult<Arc<CombinedSerializer>> {
242+
let py = schema.py();
243+
let type_: Bound<'_, PyString> = schema.get_as_req(intern!(py, "type"))?;
244+
let type_ = type_.to_str()?;
245+
246+
if type_ == "model" || type_ == "dataclass" {
247+
// Get polymorphic serialization config
248+
let config = schema.get_as::<Bound<'_, PyDict>>(intern!(py, "config"))?;
249+
let polymorphic_serialization: bool = config
250+
.and_then(|cfg| cfg.get_as(intern!(py, "polymorphic_serialization")).transpose())
251+
.unwrap_or(Ok(false))?;
252+
253+
Ok(Arc::new(
254+
PolymorphismTrampoline::new(
255+
schema.get_as_req(intern!(py, "cls"))?,
256+
serializer,
257+
polymorphic_serialization,
258+
)
259+
.into(),
260+
))
261+
} else {
262+
Ok(serializer)
263+
}
264+
}
265+
233266
/// Main recursive way to call serializers, supports possible recursive type inference by
234267
/// switching to type inference mode eagerly.
235268
pub fn to_python<'py>(
@@ -308,7 +341,8 @@ impl BuildSerializer for CombinedSerializer {
308341
config: Option<&Bound<'_, PyDict>>,
309342
definitions: &mut DefinitionsBuilder<Arc<CombinedSerializer>>,
310343
) -> PyResult<Arc<CombinedSerializer>> {
311-
Self::_build(schema, config, definitions, true)
344+
let serializer = Self::_build(schema, config, definitions, true)?;
345+
Self::maybe_wrap_in_polymorphism_trampoline(serializer, schema)
312346
}
313347
}
314348

@@ -356,6 +390,7 @@ impl PyGcTraverse for CombinedSerializer {
356390
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
357391
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
358392
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
393+
CombinedSerializer::PolymorphismTrampoline(inner) => inner.py_gc_traverse(visit),
359394
}
360395
}
361396
}

tests/serializers/test_dataclasses.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,46 @@ class Foo:
286286
)
287287
s = SchemaSerializer(schema)
288288
assert s.to_python(Foo(my_field='hello'), by_alias=runtime) == expected
289+
290+
291+
@pytest.mark.parametrize('config', [True, False])
292+
@pytest.mark.parametrize('runtime', [True, False])
293+
def test_polymorphic_serialization(config: bool, runtime: bool) -> None:
294+
@dataclasses.dataclass
295+
class ClassA:
296+
a: int
297+
298+
@dataclasses.dataclass
299+
class ClassB(ClassA):
300+
b: str
301+
302+
schema_a = core_schema.dataclass_schema(
303+
ClassA,
304+
core_schema.dataclass_args_schema(
305+
'ClassA', [core_schema.dataclass_field(name='a', schema=core_schema.int_schema())]
306+
),
307+
['a'],
308+
config=core_schema.CoreConfig(polymorphic_serialization=config),
309+
)
310+
311+
schema_b = core_schema.dataclass_schema(
312+
ClassB,
313+
core_schema.dataclass_args_schema(
314+
'ClassB',
315+
[
316+
core_schema.dataclass_field(name='a', schema=core_schema.int_schema()),
317+
core_schema.dataclass_field(name='b', schema=core_schema.str_schema()),
318+
],
319+
),
320+
['a', 'b'],
321+
)
322+
323+
ClassA.__pydantic_serializer__ = SchemaSerializer(schema_a)
324+
ClassB.__pydantic_serializer__ = SchemaSerializer(schema_b)
325+
326+
assert ClassA.__pydantic_serializer__.to_python(ClassA(123)) == {'a': 123}
327+
328+
if config:
329+
assert ClassA.__pydantic_serializer__.to_python(ClassB(123, 'test')) == {'a': 123, 'b': 'test'}
330+
else:
331+
assert ClassA.__pydantic_serializer__.to_python(ClassB(123, 'test')) == {'a': 123}

tests/serializers/test_model.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,3 +1284,42 @@ def __init__(self, my_field: int) -> None:
12841284
)
12851285
s = SchemaSerializer(schema)
12861286
assert s.to_python(Model(1), by_alias=runtime) == expected
1287+
1288+
1289+
@pytest.mark.parametrize('config', [True, False])
1290+
@pytest.mark.parametrize('runtime', [True, False])
1291+
def test_polymorphic_serialization(config: bool, runtime: bool) -> None:
1292+
class ModelA:
1293+
def __init__(self, a: int) -> None:
1294+
self.a = a
1295+
1296+
class ModelB(ModelA):
1297+
def __init__(self, a: int, b: str) -> None:
1298+
super().__init__(a)
1299+
self.b = b
1300+
1301+
schema_a = core_schema.model_schema(
1302+
ModelA,
1303+
core_schema.model_fields_schema({'a': core_schema.model_field(core_schema.int_schema())}),
1304+
config=core_schema.CoreConfig(polymorphic_serialization=config),
1305+
)
1306+
1307+
schema_b = core_schema.model_schema(
1308+
ModelB,
1309+
core_schema.model_fields_schema(
1310+
{
1311+
'a': core_schema.model_field(core_schema.int_schema()),
1312+
'b': core_schema.model_field(core_schema.str_schema()),
1313+
}
1314+
),
1315+
)
1316+
1317+
ModelA.__pydantic_serializer__ = SchemaSerializer(schema_a)
1318+
ModelB.__pydantic_serializer__ = SchemaSerializer(schema_b)
1319+
1320+
assert ModelA.__pydantic_serializer__.to_python(ModelA(123)) == {'a': 123}
1321+
1322+
if config:
1323+
assert ModelA.__pydantic_serializer__.to_python(ModelB(123, 'test')) == {'a': 123, 'b': 'test'}
1324+
else:
1325+
assert ModelA.__pydantic_serializer__.to_python(ModelB(123, 'test')) == {'a': 123}

0 commit comments

Comments
 (0)