diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 0449a95da..73589d806 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -116,11 +116,15 @@ impl Validator for EnumValidator { }, input, )); - } else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? { - state.floor_exactness(Exactness::Lax); + } + + state.floor_exactness(Exactness::Lax); + + if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? { return Ok(v); + } else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) { + return Ok(res); } else if let Some(ref missing) = self.missing { - state.floor_exactness(Exactness::Lax); let enum_value = missing.bind(py).call1((input.to_object(py),)).map_err(|_| { ValError::new( ErrorType::Enum { @@ -146,6 +150,7 @@ impl Validator for EnumValidator { return Err(type_error.into()); } } + Err(ValError::new( ErrorType::Enum { expected: self.expected_repr.clone(), diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index a4fb3c5c2..83e286417 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -1,5 +1,6 @@ import re import sys +from decimal import Decimal from enum import Enum, IntEnum, IntFlag import pytest @@ -344,3 +345,145 @@ class ColorEnum(IntEnum): assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN assert v.validate_python(1 << 63) is ColorEnum.GREEN + + +@pytest.mark.parametrize( + 'value', + [-1, 0, 1], +) +def test_enum_int_validation_should_succeed_for_decimal(value: int): + class MyEnum(Enum): + VALUE = value + + class MyIntEnum(IntEnum): + VALUE = value + + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + v_int = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())), + default=MyIntEnum.VALUE, + ) + ) + + assert v.validate_python(Decimal(value)) is MyEnum.VALUE + assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE + assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE + assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE + + +@pytest.mark.skipif( + sys.version_info >= (3, 13), + reason='Python 3.13+ enum initialization is different, see https://github.com/python/cpython/blob/ec610069637d56101896803a70d418a89afe0b4b/Lib/enum.py#L1159-L1163', +) +def test_enum_int_validation_should_succeed_for_custom_type(): + class AnyWrapper: + def __init__(self, value): + self.value = value + + def __eq__(self, other: object) -> bool: + return self.value == other + + class MyEnum(Enum): + VALUE = 999 + SECOND_VALUE = 1000000 + THIRD_VALUE = 'Py03' + + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE + assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE + assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE + + +def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value(): + class MyEnum(Enum): + VALUE = '1' + + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + with pytest.raises(ValidationError): + v.validate_python(Decimal(1)) + + +def test_enum_int_validation_should_fail_for_incorrect_decimal_value(): + class MyEnum(Enum): + VALUE = 1 + + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + with pytest.raises(ValidationError): + v.validate_python(Decimal(2)) + + with pytest.raises(ValidationError): + v.validate_python((1, 2)) + + with pytest.raises(ValidationError): + v.validate_python(Decimal(1.1)) + + +def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking(): + class MyEnum(Enum): + VALUE = 1 + + class MyClass: + def __init__(self, value): + self.value = value + + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + with pytest.raises(ValidationError): + v.validate_python(MyClass(1)) + + +def support_custom_new_method() -> None: + """Demonstrates support for custom new methods, as well as conceptually, multi-value enums without dependency on a 3rd party lib for testing.""" + + class Animal(Enum): + CAT = 'cat', 'meow' + DOG = 'dog', 'woof' + + def __new__(cls, species: str, sound: str): + obj = object.__new__(cls) + + obj._value_ = species + obj._all_values = (species, sound) + + obj.species = species + obj.sound = sound + + cls._value2member_map_[sound] = obj + + return obj + + v = SchemaValidator(core_schema.enum_schema(Animal, list(Animal.__members__.values()))) + assert v.validate_python('cat') is Animal.CAT + assert v.validate_python('meow') is Animal.CAT + assert v.validate_python('dog') is Animal.DOG + assert v.validate_python('woof') is Animal.DOG diff --git a/tests/validators/test_model.py b/tests/validators/test_model.py index bdcfbb2e9..b6a60f55c 100644 --- a/tests/validators/test_model.py +++ b/tests/validators/test_model.py @@ -1,5 +1,7 @@ import re +import sys from copy import deepcopy +from decimal import Decimal from typing import Any, Callable, Dict, List, Set, Tuple import pytest @@ -1312,3 +1314,66 @@ class OtherModel: 'ctx': {'class_name': 'MyModel'}, } ] + + +@pytest.mark.skipif( + sys.version_info >= (3, 13), + reason='Python 3.13+ enum initialization is different, see https://github.com/python/cpython/blob/ec610069637d56101896803a70d418a89afe0b4b/Lib/enum.py#L1159-L1163', +) +def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks(): + # GIVEN + from enum import Enum + + class EnumClass(Enum): + enum_value = 1 + enum_value_2 = 2 + enum_value_3 = 3 + + class IntWrappable: + def __init__(self, value: int): + self.value = value + + def __eq__(self, other: object) -> bool: + return self.value == other + + class MyModel: + __slots__ = ( + '__dict__', + '__pydantic_fields_set__', + '__pydantic_extra__', + '__pydantic_private__', + ) + enum_field: EnumClass + + # WHEN + v = SchemaValidator( + core_schema.model_schema( + MyModel, + core_schema.model_fields_schema( + { + 'enum_field': core_schema.model_field( + core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values())) + ), + 'enum_field_2': core_schema.model_field( + core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values())) + ), + 'enum_field_3': core_schema.model_field( + core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values())) + ), + } + ), + ) + ) + + # THEN + v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}') + m = v.validate_python( + { + 'enum_field': Decimal(1), + 'enum_field_2': Decimal(2), + 'enum_field_3': IntWrappable(3), + } + ) + v.validate_assignment(m, 'enum_field', Decimal(1)) + v.validate_assignment(m, 'enum_field_2', Decimal(2)) + v.validate_assignment(m, 'enum_field_3', IntWrappable(3))