Skip to content

Commit 866eb2d

Browse files
authored
Add lax_str and lax_int support for enum values not inherited from str/int (#1015)
1 parent 23d1065 commit 866eb2d

File tree

5 files changed

+65
-3
lines changed

5 files changed

+65
-3
lines changed

src/input/input_python.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ use super::datetime::{
2121
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
2222
EitherTime,
2323
};
24-
use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
24+
use super::shared::{
25+
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float,
26+
str_as_int,
27+
};
2528
use super::{
2629
py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
2730
GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
@@ -256,6 +259,8 @@ impl<'a> Input<'a> for PyAny {
256259
|| self.is_instance(decimal_type.as_ref(py)).unwrap_or_default()
257260
} {
258261
Ok(self.str()?.into())
262+
} else if let Some(enum_val) = maybe_as_enum(self) {
263+
Ok(enum_val.str()?.into())
259264
} else {
260265
Err(ValError::new(ErrorTypeDefaults::StringType, self))
261266
}
@@ -340,6 +345,8 @@ impl<'a> Input<'a> for PyAny {
340345
decimal_as_int(self.py(), self, decimal)
341346
} else if let Ok(float) = self.extract::<f64>() {
342347
float_as_int(self, float)
348+
} else if let Some(enum_val) = maybe_as_enum(self) {
349+
Ok(EitherInt::Py(enum_val))
343350
} else {
344351
Err(ValError::new(ErrorTypeDefaults::IntType, self))
345352
}
@@ -759,6 +766,18 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult<Option<Cow<
759766
}
760767
}
761768

769+
/// Utility for extracting an enum value, if possible.
770+
fn maybe_as_enum(v: &PyAny) -> Option<&PyAny> {
771+
let py = v.py();
772+
let enum_meta_object = get_enum_meta_object(py);
773+
let meta_type = v.get_type().get_type();
774+
if meta_type.is(&enum_meta_object) {
775+
v.getattr(intern!(py, "value")).ok()
776+
} else {
777+
None
778+
}
779+
}
780+
762781
#[cfg(PyPy)]
763782
static DICT_KEYS_TYPE: pyo3::once_cell::GILOnceCell<Py<PyType>> = pyo3::once_cell::GILOnceCell::new();
764783

src/input/shared.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
11
use num_bigint::BigInt;
2-
use pyo3::{intern, PyAny, Python};
2+
use pyo3::sync::GILOnceCell;
3+
use pyo3::{intern, Py, PyAny, Python, ToPyObject};
34

45
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
56

67
use super::parse_json::{JsonArray, JsonInput};
78
use super::{EitherFloat, EitherInt, Input};
89

10+
static ENUM_META_OBJECT: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
11+
12+
pub fn get_enum_meta_object(py: Python) -> Py<PyAny> {
13+
ENUM_META_OBJECT
14+
.get_or_init(py, || {
15+
py.import(intern!(py, "enum"))
16+
.and_then(|enum_module| enum_module.getattr(intern!(py, "EnumMeta")))
17+
.unwrap()
18+
.to_object(py)
19+
})
20+
.clone()
21+
}
22+
923
pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> {
1024
ValError::new(
1125
ErrorType::JsonInvalid {

src/serializers/ob_type.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,9 @@ impl ObTypeLookup {
259259
fn is_enum(&self, op_value: Option<&PyAny>, py_type: &PyType) -> bool {
260260
// only test on the type itself, not base types
261261
if op_value.is_some() {
262+
let enum_meta_type = self.enum_object.as_ref(py_type.py()).get_type();
262263
let meta_type = py_type.get_type();
263-
meta_type.is(&self.enum_object)
264+
meta_type.is(enum_meta_type)
264265
} else {
265266
false
266267
}

tests/validators/test_int.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,16 @@ def test_float_subclass() -> None:
459459
v_lax = v.validate_python(FloatSubclass(1))
460460
assert v_lax == 1
461461
assert type(v_lax) == int
462+
463+
464+
def test_int_subclass_plain_enum() -> None:
465+
v = SchemaValidator({'type': 'int'})
466+
467+
from enum import Enum
468+
469+
class PlainEnum(Enum):
470+
ONE = 1
471+
472+
v_lax = v.validate_python(PlainEnum.ONE)
473+
assert v_lax == 1
474+
assert type(v_lax) == int

tests/validators/test_string.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,21 @@ def test_lax_subclass(FruitEnum, kwargs):
249249
assert repr(p) == "'pear'"
250250

251251

252+
@pytest.mark.parametrize('kwargs', [{}, {'to_lower': True}], ids=repr)
253+
def test_lax_subclass_plain_enum(kwargs):
254+
v = SchemaValidator(core_schema.str_schema(**kwargs))
255+
256+
from enum import Enum
257+
258+
class PlainEnum(Enum):
259+
ONE = 'one'
260+
261+
p = v.validate_python(PlainEnum.ONE)
262+
assert p == 'one'
263+
assert type(p) is str
264+
assert repr(p) == "'one'"
265+
266+
252267
def test_subclass_preserved() -> None:
253268
class StrSubclass(str):
254269
pass

0 commit comments

Comments
 (0)