Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 4 additions & 27 deletions src/validators/enum_.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
// Validator for Enums, so named because "enum" is a reserved keyword in Rust.
use std::marker::PhantomData;

use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType};

use crate::build_tools::{is_strict, py_schema_err};
use crate::errors::{ErrorType, ValError, ValResult};
use crate::input::Input;
use crate::tools::{safe_repr, SchemaDict};
use crate::tools::SchemaDict;

use super::is_instance::class_repr;
use super::literal::{expected_repr_name, LiteralLookup};
Expand Down Expand Up @@ -119,33 +118,11 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
} else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
state.floor_exactness(Exactness::Lax);
return Ok(v);
} else if let Some(ref missing) = self.missing {
} else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) {
state.floor_exactness(Exactness::Lax);
let enum_value = missing.bind(py).call1((input.to_object(py),)).map_err(|_| {
ValError::new(
ErrorType::Enum {
expected: self.expected_repr.clone(),
context: None,
},
input,
)
})?;
// check enum_value is an instance of the class like
// https://github.com/python/cpython/blob/v3.12.2/Lib/enum.py#L1148
if enum_value.is_instance(class)? {
return Ok(enum_value.into());
} else if !enum_value.is(&py.None()) {
let type_error = PyTypeError::new_err(format!(
"error in {}._missing_: returned {} instead of None or a valid member",
class
.name()
.and_then(|name| name.extract::<String>())
.unwrap_or_else(|_| "<Unknown>".into()),
safe_repr(&enum_value)
));
return Err(type_error.into());
}
return Ok(res);
}

Err(ValError::new(
ErrorType::Enum {
expected: self.expected_repr.clone(),
Expand Down
128 changes: 128 additions & 0 deletions tests/validators/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import sys
from decimal import Decimal
from enum import Enum, IntEnum, IntFlag

import pytest
Expand Down Expand Up @@ -344,3 +345,130 @@ class ColorEnum(IntEnum):

assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
assert v.validate_python(1 << 63) is ColorEnum.GREEN


@pytest.mark.parametrize(
'value',
[-1, 0, 1],
)
def test_enum_int_validation_should_succeed_for_decimal(value: int):
# GIVEN
class MyEnum(Enum):
VALUE = value

class MyIntEnum(IntEnum):
VALUE = value

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_int = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
default=MyIntEnum.VALUE,
)
)

# THEN
assert v.validate_python(Decimal(value)) is MyEnum.VALUE
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE

assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE


def test_enum_int_validation_should_succeed_for_custom_type():
# GIVEN
class AnyWrapper:
def __init__(self, value):
self.value = value

def __eq__(self, other: object) -> bool:
return self.value == other

class MyEnum(Enum):
VALUE = 999
SECOND_VALUE = 1000000
THIRD_VALUE = 'Py03'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE
assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE


def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value():
# GIVEN
class MyEnum(Enum):
VALUE = '1'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(1))


def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
# GIVEN
class MyEnum(Enum):
VALUE = 1

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(2))

with pytest.raises(ValidationError):
v.validate_python((1, 2))

with pytest.raises(ValidationError):
v.validate_python(Decimal(1.1))


def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
# GIVEN
class MyEnum(Enum):
VALUE = 1

class MyClass:
def __init__(self, value):
self.value = value

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(MyClass(1))
60 changes: 60 additions & 0 deletions tests/validators/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from copy import deepcopy
from decimal import Decimal
from typing import Any, Callable, Dict, List, Set, Tuple

import pytest
Expand Down Expand Up @@ -1312,3 +1313,62 @@ class OtherModel:
'ctx': {'class_name': 'MyModel'},
}
]


def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks():
# GIVEN
from enum import Enum

class EnumClass(Enum):
enum_value = 1
enum_value_2 = 2
enum_value_3 = 3

class IntWrappable:
def __init__(self, value: int):
self.value = value

def __eq__(self, value: object) -> bool:
return self.value == value

class MyModel:
__slots__ = (
'__dict__',
'__pydantic_fields_set__',
'__pydantic_extra__',
'__pydantic_private__',
)
enum_field: EnumClass

# WHEN
v = SchemaValidator(
core_schema.model_schema(
MyModel,
core_schema.model_fields_schema(
{
'enum_field': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_2': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_3': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
}
),
)
)

# THEN
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}')
m = v.validate_python(
{
'enum_field': Decimal(1),
'enum_field_2': Decimal(2),
'enum_field_3': IntWrappable(3),
}
)
v.validate_assignment(m, 'enum_field', Decimal(1))
v.validate_assignment(m, 'enum_field_2', Decimal(2))
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))
Loading