Skip to content
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ docs/_build/
htmlcov/
node_modules/

.venv
/.benchmarks/
/.idea/
/.pytest_cache/
Expand Down
6 changes: 6 additions & 0 deletions src/validators/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,14 @@ impl EnumValidateValue for PlainEnumValidator {
} else if py_input.is_instance_of::<PyFloat>() {
return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py)));
}
if py_input.is_instance_of::<PyAny>() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_instance_of::<PyAny>() I think will always be true, will probably be optimized away by the compiler but also not necessary at all IMO.

Suggested change
if py_input.is_instance_of::<PyAny>() {

if let Ok(Some(res)) = lookup.try_validate_any(input) {
return Ok(Some(res.clone_ref(py)));
}
}
}
}

Ok(None)
}
}
Expand Down
47 changes: 47 additions & 0 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ impl<T: Debug> LiteralLookup<T> {
}
}
}

if let Some(expected_strings) = &self.expected_str {
let validation_result = if input.as_python().is_some() {
input.exact_str()
Expand Down Expand Up @@ -163,6 +164,7 @@ impl<T: Debug> LiteralLookup<T> {
}
}
};

Ok(None)
}

Expand Down Expand Up @@ -216,6 +218,51 @@ impl<T: Debug> LiteralLookup<T> {
}
Ok(None)
}

pub fn try_validate_any<'a, 'py, I: Input<'py> + ?Sized>(&self, input: &'a I) -> ValResult<Option<&T>> {
let Some(py_input) = input.as_python() else {
return Ok(None);
};

if let Some(expected_ints) = &self.expected_int {
let id = expected_ints
.iter()
.find(|(&k, _)| is_equal_to(py_input, k).unwrap_or(false));

if let Some((_, id)) = id {
return Ok(Some(&self.values[*id]));
}
};

let Some(expected_strings) = &self.expected_str else {
return Ok(None);
};

// try with raw strings
let id = expected_strings
.iter()
.find(|(k, _)| is_equal_to(py_input, k.as_str()).unwrap_or(false));

if let Some((_, id)) = id {
return Ok(Some(&self.values[*id]));
}

// try with converting to int
let id = expected_strings
.iter()
.filter_map(|(k, id)| k.parse::<i64>().ok().map(|k_as_int| (k_as_int, id)))
.find(|(k, _)| is_equal_to(py_input, *k).unwrap_or(false));

if let Some((_, id)) = id {
return Ok(Some(&self.values[*id]));
}
Ok(None)
}
}

fn is_equal_to<TValue: IntoPy<Py<PyAny>>>(input: &Bound<PyAny>, value: TValue) -> PyResult<bool> {
let equality = input.call_method1("__eq__", (value,))?;
equality.extract::<bool>()
}

impl<T: PyGcTraverse + Debug> PyGcTraverse for LiteralLookup<T> {
Expand Down
161 changes: 161 additions & 0 deletions tests/validators/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import sys
from decimal import Decimal
from enum import Enum, IntEnum, IntFlag

import pytest
Expand Down Expand Up @@ -344,3 +345,163 @@ class ColorEnum(IntEnum):

assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
assert v.validate_python(1 << 63) is ColorEnum.GREEN


@pytest.mark.parametrize(
'value',
[-1, 0, 1],
)
def test_enum_int_validation_should_succeed_for_decimal(value: int):
# GIVEN
class MyEnum(Enum):
VALUE = value

class MyIntEnum(IntEnum):
VALUE = value

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_int = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
default=MyIntEnum.VALUE,
)
)

# THEN
assert v.validate_python(Decimal(value)) is MyEnum.VALUE
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE

assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE


def test_enum_int_validation_should_succeed_for_custom_type():
# GIVEN
class AnyWrapper:
def __init__(self, value):
self.value = value

def __eq__(self, other: object) -> bool:
return self.value == other

class MyEnum(Enum):
VALUE = 999
SECOND_VALUE = 1000000
THIRD_VALUE = 'Py03'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE
assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE


def test_enum_str_validation_should_succeed_for_decimal_with_strict_disabled():
# GIVEN
class MyEnum(Enum):
VALUE = '1'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
assert v.validate_python(Decimal(1)) is MyEnum.VALUE


def test_enum_str_validation_should_fail_for_decimal_with_strict_enabled():
# GIVEN
class MyEnum(Enum):
VALUE = '1'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()), strict=True),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(1))


def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
# GIVEN
class MyEnum(Enum):
VALUE = 1

class MyStrEnum(Enum):
VALUE = '2'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_str = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyStrEnum, list(MyStrEnum.__members__.values())),
default=MyStrEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(2))

with pytest.raises(ValidationError):
v.validate_python((1, 2))

with pytest.raises(ValidationError):
v.validate_python(Decimal(1.1))

with pytest.raises(ValidationError):
v_str.validate_python(Decimal(1))

with pytest.raises(ValidationError):
v_str.validate_python(Decimal(2.1))


def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
# GIVEN
class MyEnum(Enum):
VALUE = 1
Comment on lines +456 to +458
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realise now that we probably also want to test this with e.g. MyEnum(IntEnum), which I think goes through a separate code pathway but probably was also broken when we moved enum validation to Rust?


class MyClass:
def __init__(self, value):
self.value = value

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(MyClass(1))
67 changes: 67 additions & 0 deletions tests/validators/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from copy import deepcopy
from decimal import Decimal
from typing import Any, Callable, Dict, List, Set, Tuple

import pytest
Expand Down Expand Up @@ -1312,3 +1313,69 @@ class OtherModel:
'ctx': {'class_name': 'MyModel'},
}
]


def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks():
# GIVEN
from enum import Enum

class EnumClass(Enum):
enum_value = 1
enum_value_2 = 2
enum_value_3 = 3
enum_value_4 = '4'

class IntWrappable:
def __init__(self, value: int):
self.value = value

def __eq__(self, value: object) -> bool:
return self.value == value

class MyModel:
__slots__ = (
'__dict__',
'__pydantic_fields_set__',
'__pydantic_extra__',
'__pydantic_private__',
)
enum_field: EnumClass

# WHEN
v = SchemaValidator(
core_schema.model_schema(
MyModel,
core_schema.model_fields_schema(
{
'enum_field': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_2': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_3': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_4': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
}
),
)
)

# THEN
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3, "enum_field_4": "4"}')
m = v.validate_python(
{
'enum_field': Decimal(1),
'enum_field_2': Decimal(2),
'enum_field_3': IntWrappable(3),
'enum_field_4': IntWrappable(4),
}
)
v.validate_assignment(m, 'enum_field', Decimal(1))
v.validate_assignment(m, 'enum_field_2', Decimal(2))
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))
v.validate_assignment(m, 'enum_field_4', Decimal(4))
v.validate_assignment(m, 'enum_field_4', IntWrappable(4))