Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ docs/_build/
htmlcov/
node_modules/

.venv
/.benchmarks/
/.idea/
/.pytest_cache/
Expand Down
2 changes: 2 additions & 0 deletions src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub(crate) use return_enums::{
EitherInt, EitherString, GenericIterator, Int, MaxLengthCheck, ValidationMatch,
};

pub(crate) use shared::decimal_as_int;

// Defined here as it's not exported by pyo3
pub fn py_error_on_minusone(py: Python<'_>, result: c_int) -> PyResult<()> {
if result != -1 {
Expand Down
26 changes: 25 additions & 1 deletion src/validators/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::errors::ErrorType;
use crate::errors::ValResult;
use crate::errors::{ErrorTypeDefaults, Number};
use crate::errors::{ToErrorValue, ValError};
use crate::input::Input;
use crate::input::{decimal_as_int, EitherInt, Input};
use crate::tools::SchemaDict;

use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
Expand Down Expand Up @@ -288,3 +288,27 @@ fn handle_decimal_new_error(input: impl ToErrorValue, error: PyErr, decimal_exce
ValError::InternalErr(error)
}
}

pub(crate) fn try_from_decimal_to_int<'a, 'py, I: Input<'py> + ?Sized>(
py: Python<'py>,
input: &'a I,
) -> ValResult<i64> {
let Some(py_input) = input.as_python() else {
return Err(ValError::new(ErrorTypeDefaults::DecimalType, input));
};

if let Ok(false) = py_input.is_instance(get_decimal_type(py)) {
return Err(ValError::new(ErrorTypeDefaults::DecimalType, input));
}

let dec_value = match decimal_as_int(input, py_input)? {
EitherInt::Py(value) => value,
_ => return Err(ValError::new(ErrorType::DecimalParsing { context: None }, input)),
};

let either_int = dec_value.exact_int()?;

let int = either_int.into_i64(py)?;

Ok(int)
}
2 changes: 1 addition & 1 deletion src/validators/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ impl EnumValidateValue for PlainEnumValidator {
lookup: &LiteralLookup<PyObject>,
strict: bool,
) -> ValResult<Option<PyObject>> {
match lookup.validate(py, input)? {
match lookup.validate(py, input, strict)? {
Some((_, v)) => Ok(Some(v.clone_ref(py))),
None => {
if !strict {
Expand Down
22 changes: 21 additions & 1 deletion src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::input::{Input, ValidationMatch};
use crate::py_gc::PyGcTraverse;
use crate::tools::SchemaDict;

use super::decimal::try_from_decimal_to_int;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone, Default)]
Expand Down Expand Up @@ -104,6 +105,7 @@ impl<T: Debug> LiteralLookup<T> {
&self,
py: Python<'py>,
input: &'a I,
strict: bool,
) -> ValResult<Option<(&'a I, &T)>> {
if let Some(expected_bool) = &self.expected_bool {
if let Ok(bool_value) = input.validate_bool(true) {
Expand All @@ -123,7 +125,15 @@ impl<T: Debug> LiteralLookup<T> {
return Ok(Some((input, &self.values[*id])));
}
}
// if the input is a Decimal type, we need to check if its value is in the expected_ints
if let Ok(value) = try_from_decimal_to_int(py, input) {
let Some(id) = expected_ints.get(&value) else {
return Ok(None);
};
return Ok(Some((input, &self.values[*id])));
}
}

if let Some(expected_strings) = &self.expected_str {
let validation_result = if input.as_python().is_some() {
input.exact_str()
Expand All @@ -142,6 +152,15 @@ impl<T: Debug> LiteralLookup<T> {
return Ok(Some((input, &self.values[*id])));
}
}
if !strict {
// if the input is a Decimal type, we need to check if its value is in the expected_ints
if let Ok(value) = try_from_decimal_to_int(py, input) {
let Some(id) = expected_strings.get(&value.to_string()) else {
return Ok(None);
};
return Ok(Some((input, &self.values[*id])));
}
}
}
if let Some(expected_py_dict) = &self.expected_py_dict {
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
Expand All @@ -163,6 +182,7 @@ impl<T: Debug> LiteralLookup<T> {
}
}
};

Ok(None)
}

Expand Down Expand Up @@ -269,7 +289,7 @@ impl Validator for LiteralValidator {
input: &(impl Input<'py> + ?Sized),
_state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
match self.lookup.validate(py, input)? {
match self.lookup.validate(py, input, _state.strict_or(false))? {
Some((_, v)) => Ok(v.clone()),
None => Err(ValError::new(
ErrorType::LiteralError {
Expand Down
2 changes: 1 addition & 1 deletion src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ impl TaggedUnionValidator {
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) {
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag, state.strict_or(false)) {
return match validator.validate(py, input, state) {
Ok(res) => Ok(res),
Err(err) => Err(err.with_outer_location(tag)),
Expand Down
105 changes: 105 additions & 0 deletions tests/validators/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import sys
from decimal import Decimal
from enum import Enum, IntEnum, IntFlag

import pytest
Expand Down Expand Up @@ -344,3 +345,107 @@ class ColorEnum(IntEnum):

assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
assert v.validate_python(1 << 63) is ColorEnum.GREEN


@pytest.mark.parametrize(
'value',
[-1, 0, 1],
)
def test_enum_int_validation_should_succeed_for_decimal(value: int):
# GIVEN
class MyEnum(Enum):
VALUE = value

class MyIntEnum(IntEnum):
VALUE = value

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_int = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
default=MyIntEnum.VALUE,
)
)

# THEN
assert v.validate_python(Decimal(value)) is MyEnum.VALUE
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE

assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE


def test_enum_str_validation_should_succeed_for_decimal_with_strict_disabled():
# GIVEN
class MyEnum(Enum):
VALUE = '1'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
assert v.validate_python(Decimal(1)) is MyEnum.VALUE


def test_enum_str_validation_should_fail_for_decimal_with_strict_enabled():
# GIVEN
class MyEnum(Enum):
VALUE = '1'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()), strict=True),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(1))


def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
# GIVEN
class MyEnum(Enum):
VALUE = 1

class MyStrEnum(Enum):
VALUE = '2'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_str = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyStrEnum, list(MyStrEnum.__members__.values())),
default=MyStrEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(2))

with pytest.raises(ValidationError):
v.validate_python((1, 2))

with pytest.raises(ValidationError):
v_str.validate_python(Decimal(1))
38 changes: 38 additions & 0 deletions tests/validators/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from copy import deepcopy
from decimal import Decimal
from typing import Any, Callable, Dict, List, Set, Tuple

import pytest
Expand Down Expand Up @@ -1312,3 +1313,40 @@ class OtherModel:
'ctx': {'class_name': 'MyModel'},
}
]


def test_model_with_enum_int_field_validation_should_succeed_for_decimal():
from enum import Enum

class EnumClass(Enum):
enum_value = 1
enum_value_2 = 2

class MyModel:
__slots__ = (
'__dict__',
'__pydantic_fields_set__',
'__pydantic_extra__',
'__pydantic_private__',
)
enum_field: EnumClass

v = SchemaValidator(
core_schema.model_schema(
MyModel,
core_schema.model_fields_schema(
{
'enum_field': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_2': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
}
),
)
)
v.validate_json('{"enum_field": 1, "enum_field_2": 2}')
m = v.validate_python({'enum_field': Decimal(1), 'enum_field_2': Decimal(2)})
v.validate_assignment(m, 'enum_field', Decimal(1))
v.validate_assignment(m, 'enum_field_2', Decimal(2))