diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 434b43c53..5516f81e1 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -14,7 +14,7 @@ pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValu use extra::{CollectWarnings, SerRecursionState, WarningsMode}; pub(crate) use extra::{DuckTypingSerMode, Extra, SerMode, SerializationState}; pub use shared::CombinedSerializer; -use shared::{to_json_bytes, BuildSerializer, TypeSerializer}; +use shared::{to_json_bytes, TypeSerializer}; mod computed_fields; mod config; @@ -91,7 +91,7 @@ impl SchemaSerializer { #[pyo3(signature = (schema, config=None))] pub fn py_new(schema: Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); - let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; + let serializer = CombinedSerializer::build_base(schema.downcast()?, config, &mut definitions_builder)?; Ok(Self { serializer, definitions: definitions_builder.finish()?, diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index f7a018749..e28ae9cee 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -149,10 +149,21 @@ combined_serializer! { } impl CombinedSerializer { + // Used when creating the base serializer instance, to avoid reusing the instance + // when unpickling: + pub fn build_base( + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, + definitions: &mut DefinitionsBuilder, + ) -> PyResult { + Self::_build(schema, config, definitions, false) + } + fn _build( schema: &Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>, definitions: &mut DefinitionsBuilder, + use_prebuilt: bool, ) -> PyResult { let py = schema.py(); let type_key = intern!(py, "type"); @@ -199,9 +210,13 @@ impl CombinedSerializer { let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?; let type_ = type_.to_str()?; - // if we have a SchemaValidator on the type already, use it - if let Ok(Some(prebuilt_serializer)) = super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema) { - return Ok(prebuilt_serializer); + if use_prebuilt { + // if we have a SchemaValidator on the type already, use it + if let Ok(Some(prebuilt_serializer)) = + super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema) + { + return Ok(prebuilt_serializer); + } } Self::find_serializer(type_, schema, config, definitions) @@ -217,7 +232,7 @@ impl BuildSerializer for CombinedSerializer { config: Option<&Bound<'_, PyDict>>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - Self::_build(schema, config, definitions) + Self::_build(schema, config, definitions, true) } } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index f105e1854..524de10bb 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -127,7 +127,7 @@ impl SchemaValidator { pub fn py_new(py: Python, schema: &Bound<'_, PyAny>, config: Option<&Bound<'_, PyDict>>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); - let validator = build_validator(schema, config, &mut definitions_builder)?; + let validator = build_validator_base(schema, config, &mut definitions_builder)?; let definitions = definitions_builder.finish()?; let py_schema = schema.clone().unbind(); let py_config = match config { @@ -159,11 +159,6 @@ impl SchemaValidator { }) } - pub fn __reduce__<'py>(slf: &Bound<'py, Self>) -> PyResult<(Bound<'py, PyType>, Bound<'py, PyTuple>)> { - let init_args = (&slf.get().py_schema, &slf.get().py_config).into_pyobject(slf.py())?; - Ok((slf.get_type(), init_args)) - } - #[allow(clippy::too_many_arguments)] #[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None, allow_partial=PartialMode::Off, by_alias=None, by_name=None))] pub fn validate_python( @@ -355,6 +350,11 @@ impl SchemaValidator { } } + pub fn __reduce__<'py>(slf: &Bound<'py, Self>) -> PyResult<(Bound<'py, PyType>, Bound<'py, PyTuple>)> { + let init_args = (&slf.get().py_schema, &slf.get().py_config).into_pyobject(slf.py())?; + Ok((slf.get_type(), init_args)) + } + pub fn __repr__(&self, py: Python) -> String { format!( "SchemaValidator(title={:?}, validator={:#?}, definitions={:#?}, cache_strings={})", @@ -553,19 +553,40 @@ macro_rules! validator_match { }; } +// Used when creating the base validator instance, to avoid reusing the instance +// when unpickling: +pub fn build_validator_base( + schema: &Bound<'_, PyAny>, + config: Option<&Bound<'_, PyDict>>, + definitions: &mut DefinitionsBuilder, +) -> PyResult { + build_validator_inner(schema, config, definitions, false) +} + pub fn build_validator( schema: &Bound<'_, PyAny>, config: Option<&Bound<'_, PyDict>>, definitions: &mut DefinitionsBuilder, +) -> PyResult { + build_validator_inner(schema, config, definitions, true) +} + +fn build_validator_inner( + schema: &Bound<'_, PyAny>, + config: Option<&Bound<'_, PyDict>>, + definitions: &mut DefinitionsBuilder, + use_prebuilt: bool, ) -> PyResult { let dict = schema.downcast::()?; let py = schema.py(); let type_: Bound<'_, PyString> = dict.get_as_req(intern!(py, "type"))?; let type_ = type_.to_str()?; - // if we have a SchemaValidator on the type already, use it - if let Ok(Some(prebuilt_validator)) = prebuilt::PrebuiltValidator::try_get_from_schema(type_, dict) { - return Ok(prebuilt_validator); + if use_prebuilt { + // if we have a SchemaValidator on the type already, use it + if let Ok(Some(prebuilt_validator)) = prebuilt::PrebuiltValidator::try_get_from_schema(type_, dict) { + return Ok(prebuilt_validator); + } } validator_match!( diff --git a/tests/serializers/test_pickling.py b/tests/serializers/test_pickling.py index 2ca230313..208cafbd9 100644 --- a/tests/serializers/test_pickling.py +++ b/tests/serializers/test_pickling.py @@ -48,3 +48,26 @@ def test_schema_serializer_containing_config(): assert s.to_python(timedelta(seconds=4, microseconds=500_000)) == timedelta(seconds=4, microseconds=500_000) assert s.to_python(timedelta(seconds=4, microseconds=500_000), mode='json') == 4.5 assert s.to_json(timedelta(seconds=4, microseconds=500_000)) == b'4.5' + + +# Should be defined at the module level for pickling to work: +class Model: + __pydantic_serializer__: SchemaSerializer + __pydantic_complete__ = True + + +def test_schema_serializer_not_reused_when_unpickling() -> None: + s = SchemaSerializer( + core_schema.model_schema( + cls=Model, + schema=core_schema.model_fields_schema(fields={}, model_name='Model'), + config={'title': 'Model'}, + ref='Model:123', + ) + ) + + Model.__pydantic_serializer__ = s + assert 'Prebuilt' not in str(Model.__pydantic_serializer__) + + reconstructed = pickle.loads(pickle.dumps(Model.__pydantic_serializer__)) + assert 'Prebuilt' not in str(reconstructed) diff --git a/tests/validators/test_pickling.py b/tests/validators/test_pickling.py index 2037ab8c9..b46c57029 100644 --- a/tests/validators/test_pickling.py +++ b/tests/validators/test_pickling.py @@ -51,3 +51,26 @@ def test_schema_validator_tz_pickle() -> None: validated = v.validate_python('2022-06-08T12:13:14-12:15') assert validated == original assert pickle.loads(pickle.dumps(validated)) == validated == original + + +# Should be defined at the module level for pickling to work: +class Model: + __pydantic_validator__: SchemaValidator + __pydantic_complete__ = True + + +def test_schema_validator_not_reused_when_unpickling() -> None: + s = SchemaValidator( + core_schema.model_schema( + cls=Model, + schema=core_schema.model_fields_schema(fields={}, model_name='Model'), + config={'title': 'Model'}, + ref='Model:123', + ) + ) + + Model.__pydantic_validator__ = s + assert 'Prebuilt' not in str(Model.__pydantic_validator__) + + reconstructed = pickle.loads(pickle.dumps(Model.__pydantic_validator__)) + assert 'Prebuilt' not in str(reconstructed)