diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index ba47445b5..1900d4e83 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -48,6 +48,10 @@ impl DuckTypingSerMode { } } + pub fn is_need_inference(self) -> bool { + self == DuckTypingSerMode::NeedsInference + } + pub fn to_bool(self) -> bool { match self { DuckTypingSerMode::SchemaBased => false, diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 38fd7cea5..3d1ba2412 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -7,6 +7,7 @@ use pyo3::types::{PyDict, PyList}; use crate::definitions::DefinitionsBuilder; use crate::definitions::{DefinitionRef, RecursionSafeCache}; +use crate::serializers::DuckTypingSerMode; use crate::tools::SchemaDict; @@ -93,8 +94,12 @@ impl TypeSerializer for DefinitionRefSerializer { ) -> PyResult { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let mut guard = extra.recursion_guard(value, self.definition.id())?; - comb_serializer.to_python(value, include, exclude, guard.state()) + if extra.duck_typing_ser_mode == DuckTypingSerMode::NeedsInference { + comb_serializer.to_python(value, include, exclude, extra) + } else { + let mut guard = extra.recursion_guard(value, self.definition.id())?; + comb_serializer.to_python(value, include, exclude, guard.state()) + } }) } @@ -112,10 +117,14 @@ impl TypeSerializer for DefinitionRefSerializer { ) -> Result { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let mut guard = extra - .recursion_guard(value, self.definition.id()) - .map_err(py_err_se_err)?; - comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state()) + if extra.duck_typing_ser_mode.is_need_inference() { + comb_serializer.serde_serialize(value, serializer, include, exclude, extra) + } else { + let mut guard = extra + .recursion_guard(value, self.definition.id()) + .map_err(py_err_se_err)?; + comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state()) + } }) } diff --git a/tests/serializers/test_serialize_as_any.py b/tests/serializers/test_serialize_as_any.py index 1d1be238b..ea7e7fd67 100644 --- a/tests/serializers/test_serialize_as_any.py +++ b/tests/serializers/test_serialize_as_any.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional from typing_extensions import TypedDict @@ -155,3 +156,42 @@ class Other: 'x': 1, 'y': 'hopefully not a secret', } + + +def test_serialize_with_recursive_models() -> None: + class Node: + next: Optional['Node'] = None + value: int = 42 + + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('Node'), + [ + core_schema.model_schema( + Node, + core_schema.model_fields_schema( + { + 'value': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=42) + ), + 'next': core_schema.model_field( + core_schema.with_default_schema( + core_schema.nullable_schema(core_schema.definition_reference_schema('Node')), + default=None, + ) + ), + } + ), + ref='Node', + ) + ], + ) + + Node.__pydantic_core_schema__ = schema + Node.__pydantic_validator__ = SchemaValidator(Node.__pydantic_core_schema__) + Node.__pydantic_serializer__ = SchemaSerializer(Node.__pydantic_core_schema__) + other = Node.__pydantic_validator__.validate_python({'next': {'value': 4}}) + + assert Node.__pydantic_serializer__.to_python(other, serialize_as_any=True) == { + 'next': {'next': None, 'value': 4}, + 'value': 42, + }