From d1f23434b0c72b84351b43d1cf4ab60e05b7a9ab Mon Sep 17 00:00:00 2001 From: Andres Date: Sun, 3 Dec 2023 20:09:39 -0300 Subject: [PATCH 1/2] add default comparison --- python/pydantic_core/core_schema.py | 3 +++ src/serializers/fields.rs | 7 ++--- src/serializers/shared.rs | 4 +++ .../type_serializers/with_default.rs | 22 ++++++++++++++-- tests/serializers/test_typed_dict.py | 26 ++++++++++++++++++- 5 files changed, 54 insertions(+), 8 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index daed22d48..ed78e181c 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -2186,6 +2186,7 @@ def with_default_schema( *, default: Any = PydanticUndefined, default_factory: Callable[[], Any] | None = None, + default_comparison: Callable[[Any, Any], bool] | None = None, on_error: Literal['raise', 'omit', 'default'] | None = None, validate_default: bool | None = None, strict: bool | None = None, @@ -2211,6 +2212,7 @@ def with_default_schema( schema: The schema to add a default value to default: The default value to use default_factory: A function that returns the default value to use + default_comparison: A function to compare the default value with any other given on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default' validate_default: Whether the default value should be validated strict: Whether the underlying schema should be validated with strict mode @@ -2222,6 +2224,7 @@ def with_default_schema( type='default', schema=schema, default_factory=default_factory, + default_comparison=default_comparison, on_error=on_error, validate_default=validate_default, strict=strict, diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index f48f8e0b9..05959c2a9 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -68,12 +68,9 @@ impl SerField { } fn exclude_default(value: &PyAny, extra: &Extra, serializer: &CombinedSerializer) -> PyResult { + let py = value.py(); if extra.exclude_defaults { - if let Some(default) = serializer.get_default(value.py())? { - if value.eq(default)? { - return Ok(true); - } - } + return serializer.compare_with_default(py, value); } Ok(false) } diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index cfccc748a..bd2213e40 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -301,6 +301,10 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug { fn get_default(&self, _py: Python) -> PyResult> { Ok(None) } + + fn compare_with_default(&self, _py: Python, _value: &PyAny) -> PyResult { + Ok(false) + } } pub(crate) struct PydanticSerializer<'py> { diff --git a/src/serializers/type_serializers/with_default.rs b/src/serializers/type_serializers/with_default.rs index d20c273a1..3dacb2a24 100644 --- a/src/serializers/type_serializers/with_default.rs +++ b/src/serializers/type_serializers/with_default.rs @@ -13,6 +13,7 @@ use super::{BuildSerializer, CombinedSerializer, Extra, TypeSerializer}; #[derive(Debug, Clone)] pub struct WithDefaultSerializer { default: DefaultType, + default_comparison: Option, serializer: Box, } @@ -26,11 +27,16 @@ impl BuildSerializer for WithDefaultSerializer { ) -> PyResult { let py = schema.py(); let default = DefaultType::new(schema)?; - + let default_comparison = schema.get_as(intern!(py, "default_comparison"))?; let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?; let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?); - Ok(Self { default, serializer }.into()) + Ok(Self { + default, + default_comparison, + serializer, + } + .into()) } } @@ -74,4 +80,16 @@ impl TypeSerializer for WithDefaultSerializer { fn get_default(&self, py: Python) -> PyResult> { self.default.default_value(py) } + + fn compare_with_default(&self, py: Python, value: &PyAny) -> PyResult { + if let Some(default) = self.get_default(py)? { + if let Some(default_comparison) = &self.default_comparison { + return default_comparison.call(py, (value, default), None)?.extract::(py); + } else if value.eq(default)? { + return Ok(true); + } + } + + Ok(false) + } } diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index df507a248..c285af3a9 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -150,6 +150,13 @@ def test_exclude_none(): def test_exclude_default(): + class TestComparison: + def __init__(self, val: Any): + self.val = val + + def __eq__(self, other): + return self.val == other.val + v = SchemaSerializer( core_schema.typed_dict_schema( { @@ -157,6 +164,13 @@ def test_exclude_default(): 'bar': core_schema.typed_dict_field( core_schema.with_default_schema(core_schema.bytes_schema(), default=b'[default]') ), + 'foobar': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.any_schema(), + default=TestComparison(val=1), + default_comparison=lambda value, default: value.val == -1 * default.val, + ) + ), } ) ) @@ -165,9 +179,19 @@ def test_exclude_default(): assert v.to_python({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True) == {'foo': 1} assert v.to_python({'foo': 1, 'bar': b'[default]'}, mode='json') == {'foo': 1, 'bar': '[default]'} assert v.to_python({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True, mode='json') == {'foo': 1} - assert v.to_json({'foo': 1, 'bar': b'[default]'}) == b'{"foo":1,"bar":"[default]"}' assert v.to_json({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True) == b'{"foo":1}' + # Note that due to the custom comparison operator foobar must be excluded + assert v.to_python({'foo': 1, 'bar': b'x', 'foobar': TestComparison(val=-1)}, exclude_defaults=True) == { + 'foo': 1, + 'bar': b'x', + } + # foobar here must be included + assert v.to_python({'foo': 1, 'bar': b'x', 'foobar': TestComparison(val=1)}, exclude_defaults=True) == { + 'foo': 1, + 'bar': b'x', + 'foobar': TestComparison(val=1), + } def test_function_plain_field_serializer_to_python(): From 06140112123e72b2bff7b5dc68633c7be7efbab7 Mon Sep 17 00:00:00 2001 From: Andres Date: Mon, 4 Dec 2023 11:44:40 -0300 Subject: [PATCH 2/2] update test --- tests/serializers/test_typed_dict.py | 67 +++++++++++++++++----------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index c285af3a9..2ed4cf980 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict +from typing import Any, Dict, List, Union import pytest from dirty_equals import IsStrictDict @@ -151,29 +151,37 @@ def test_exclude_none(): def test_exclude_default(): class TestComparison: - def __init__(self, val: Any): - self.val = val + def __init__(self, array: Union[List[int], List[float]]): + self.array = array def __eq__(self, other): - return self.val == other.val - - v = SchemaSerializer( - core_schema.typed_dict_schema( - { - 'foo': core_schema.typed_dict_field(core_schema.nullable_schema(core_schema.int_schema())), - 'bar': core_schema.typed_dict_field( - core_schema.with_default_schema(core_schema.bytes_schema(), default=b'[default]') - ), - 'foobar': core_schema.typed_dict_field( - core_schema.with_default_schema( - core_schema.any_schema(), - default=TestComparison(val=1), - default_comparison=lambda value, default: value.val == -1 * default.val, - ) - ), - } - ) + """Simple comparison arrays have to match also in the dtype""" + # Test case we just look at the first element to get the dtype of the array + self_is_integer = isinstance(self.array[0], int) + other_is_integer = isinstance(other.array[0], int) + return self_is_integer == other_is_integer and self.array == other.array + + def custom_comparison_operator(value: TestComparison, default: TestComparison): + """Will replace __eq__ in TestComparison omiting the dtype check""" + return value.array == default.array + + dict_schema = core_schema.typed_dict_schema( + { + 'foo': core_schema.typed_dict_field(core_schema.nullable_schema(core_schema.int_schema())), + 'bar': core_schema.typed_dict_field( + core_schema.with_default_schema(core_schema.bytes_schema(), default=b'[default]') + ), + 'foobar': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.any_schema(), + default=TestComparison(array=[1, 2, 3]), + default_comparison=custom_comparison_operator, + ) + ), + } ) + + v = SchemaSerializer(dict_schema) assert v.to_python({'foo': 1, 'bar': b'x'}) == {'foo': 1, 'bar': b'x'} assert v.to_python({'foo': 1, 'bar': b'[default]'}) == {'foo': 1, 'bar': b'[default]'} assert v.to_python({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True) == {'foo': 1} @@ -181,16 +189,23 @@ def __eq__(self, other): assert v.to_python({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True, mode='json') == {'foo': 1} assert v.to_json({'foo': 1, 'bar': b'[default]'}) == b'{"foo":1,"bar":"[default]"}' assert v.to_json({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True) == b'{"foo":1}' - # Note that due to the custom comparison operator foobar must be excluded - assert v.to_python({'foo': 1, 'bar': b'x', 'foobar': TestComparison(val=-1)}, exclude_defaults=True) == { + # Note that due to the custom comparison operator foobar must be excluded because this operator doesn't pay attention on the dtype + assert v.to_python( + {'foo': 1, 'bar': b'x', 'foobar': TestComparison(array=[1.0, 2.0, 3.0])}, exclude_defaults=True + ) == { 'foo': 1, 'bar': b'x', } - # foobar here must be included - assert v.to_python({'foo': 1, 'bar': b'x', 'foobar': TestComparison(val=1)}, exclude_defaults=True) == { + # Now removing custom comparison operator foobar must be included due that TestComparison.__eq__ checks the array dtype + # So TestComparison(array=[1.0, 2.0, 3.0]) is not equal to the default TestComparison(array=[1, 2, 3]) + del dict_schema['fields']['foobar']['schema']['default_comparison'] + v = SchemaSerializer(dict_schema) + assert v.to_python( + {'foo': 1, 'bar': b'x', 'foobar': TestComparison(array=[1.0, 2.0, 3.0])}, exclude_defaults=True + ) == { 'foo': 1, 'bar': b'x', - 'foobar': TestComparison(val=1), + 'foobar': TestComparison(array=[1.0, 2.0, 3.0]), }