diff --git a/.gitignore b/.gitignore index efffcbf69..bd5a6acbd 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ docs/_build/ htmlcov/ node_modules/ +.venv /.benchmarks/ /.idea/ /.pytest_cache/ diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 087d0ca5f..6f964ab6a 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -79,6 +79,7 @@ pub trait EnumValidateValue: std::fmt::Debug + Clone + Send + Sync { py: Python<'py>, input: &I, lookup: &LiteralLookup, + class: &Py, strict: bool, ) -> ValResult>; } @@ -116,7 +117,7 @@ impl Validator for EnumValidator { }, input, )); - } else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? { + } else if let Some(v) = T::validate_value(py, input, &self.lookup, &self.class, strict)? { state.floor_exactness(Exactness::Lax); return Ok(v); } else if let Some(ref missing) = self.missing { @@ -167,6 +168,7 @@ impl EnumValidateValue for PlainEnumValidator { py: Python<'py>, input: &I, lookup: &LiteralLookup, + class: &Py, strict: bool, ) -> ValResult> { match lookup.validate(py, input)? { @@ -183,8 +185,14 @@ impl EnumValidateValue for PlainEnumValidator { } else if py_input.is_instance_of::() { return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py))); } + if py_input.is_instance_of::() { + if let Ok(res) = class.call1(py, (py_input,)) { + return Ok(Some(res)); + } + } } } + Ok(None) } } @@ -201,6 +209,7 @@ impl EnumValidateValue for IntEnumValidator { py: Python<'py>, input: &I, lookup: &LiteralLookup, + _class: &Py, strict: bool, ) -> ValResult> { Ok(lookup.validate_int(py, input, strict)?.map(|v| v.clone_ref(py))) @@ -217,6 +226,7 @@ impl EnumValidateValue for StrEnumValidator { py: Python, input: &I, lookup: &LiteralLookup, + _class: &Py, strict: bool, ) -> ValResult> { Ok(lookup.validate_str(input, strict)?.map(|v| v.clone_ref(py))) @@ -233,6 +243,7 @@ impl EnumValidateValue for FloatEnumValidator { py: Python<'py>, input: &I, lookup: &LiteralLookup, + _class: &Py, strict: bool, ) -> ValResult> { Ok(lookup.validate_float(py, input, strict)?.map(|v| v.clone_ref(py))) diff --git a/src/validators/generator.rs b/src/validators/generator.rs index de9949a8c..8d2c5651d 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -86,6 +86,7 @@ impl Validator for GeneratorValidator { hide_input_in_errors: self.hide_input_in_errors, validation_error_cause: self.validation_error_cause, }; + Ok(v_iterator.into_py(py)) } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 16eef090d..728b9b4d0 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -124,6 +124,7 @@ impl LiteralLookup { } } } + if let Some(expected_strings) = &self.expected_str { let validation_result = if input.as_python().is_some() { input.exact_str() @@ -163,6 +164,7 @@ impl LiteralLookup { } } }; + Ok(None) } diff --git a/src/validators/model.rs b/src/validators/model.rs index 741b6f9f0..bf4ef5bad 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -270,7 +270,6 @@ impl ModelValidator { .map_err(|e| convert_err(py, e, input)); } } - let output = self.validator.validate(py, input, state)?; let instance = create_class(self.class.bind(py))?; diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index a4fb3c5c2..09b48e5ae 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -1,5 +1,6 @@ import re import sys +from decimal import Decimal from enum import Enum, IntEnum, IntFlag import pytest @@ -344,3 +345,130 @@ class ColorEnum(IntEnum): assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN assert v.validate_python(1 << 63) is ColorEnum.GREEN + + +@pytest.mark.parametrize( + 'value', + [-1, 0, 1], +) +def test_enum_int_validation_should_succeed_for_decimal(value: int): + # GIVEN + class MyEnum(Enum): + VALUE = value + + class MyIntEnum(IntEnum): + VALUE = value + + # WHEN + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + v_int = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())), + default=MyIntEnum.VALUE, + ) + ) + + # THEN + assert v.validate_python(Decimal(value)) is MyEnum.VALUE + assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE + + assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE + assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE + + +def test_enum_int_validation_should_succeed_for_custom_type(): + # GIVEN + class AnyWrapper: + def __init__(self, value): + self.value = value + + def __eq__(self, other: object) -> bool: + return self.value == other + + class MyEnum(Enum): + VALUE = 999 + SECOND_VALUE = 1000000 + THIRD_VALUE = 'Py03' + + # WHEN + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + # THEN + assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE + assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE + assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE + + +def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value(): + # GIVEN + class MyEnum(Enum): + VALUE = '1' + + # WHEN + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + # THEN + with pytest.raises(ValidationError): + v.validate_python(Decimal(1)) + + +def test_enum_int_validation_should_fail_for_incorrect_decimal_value(): + # GIVEN + class MyEnum(Enum): + VALUE = 1 + + # WHEN + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + # THEN + with pytest.raises(ValidationError): + v.validate_python(Decimal(2)) + + with pytest.raises(ValidationError): + v.validate_python((1, 2)) + + with pytest.raises(ValidationError): + v.validate_python(Decimal(1.1)) + + +def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking(): + # GIVEN + class MyEnum(Enum): + VALUE = 1 + + class MyClass: + def __init__(self, value): + self.value = value + + # WHEN + v = SchemaValidator( + core_schema.with_default_schema( + schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())), + default=MyEnum.VALUE, + ) + ) + + # THEN + with pytest.raises(ValidationError): + v.validate_python(MyClass(1)) diff --git a/tests/validators/test_model.py b/tests/validators/test_model.py index bdcfbb2e9..d8a3dc03e 100644 --- a/tests/validators/test_model.py +++ b/tests/validators/test_model.py @@ -1,5 +1,6 @@ import re from copy import deepcopy +from decimal import Decimal from typing import Any, Callable, Dict, List, Set, Tuple import pytest @@ -1312,3 +1313,62 @@ class OtherModel: 'ctx': {'class_name': 'MyModel'}, } ] + + +def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks(): + # GIVEN + from enum import Enum + + class EnumClass(Enum): + enum_value = 1 + enum_value_2 = 2 + enum_value_3 = 3 + + class IntWrappable: + def __init__(self, value: int): + self.value = value + + def __eq__(self, value: object) -> bool: + return self.value == value + + class MyModel: + __slots__ = ( + '__dict__', + '__pydantic_fields_set__', + '__pydantic_extra__', + '__pydantic_private__', + ) + enum_field: EnumClass + + # WHEN + v = SchemaValidator( + core_schema.model_schema( + MyModel, + core_schema.model_fields_schema( + { + 'enum_field': core_schema.model_field( + core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values())) + ), + 'enum_field_2': core_schema.model_field( + core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values())) + ), + 'enum_field_3': core_schema.model_field( + core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values())) + ), + } + ), + ) + ) + + # THEN + v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}') + m = v.validate_python( + { + 'enum_field': Decimal(1), + 'enum_field_2': Decimal(2), + 'enum_field_3': IntWrappable(3), + } + ) + v.validate_assignment(m, 'enum_field', Decimal(1)) + v.validate_assignment(m, 'enum_field_2', Decimal(2)) + v.validate_assignment(m, 'enum_field_3', IntWrappable(3))