diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index dc8f55b7c..1e4ba3df9 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -1917,6 +1917,68 @@ def gen() -> Iterator[int]: ) +class IterableSchema(TypedDict, total=False): + type: Required[Literal['iterable']] + items_schema: CoreSchema + min_length: int + max_length: int + lazy: bool + ref: str + metadata: dict[str, Any] + serialization: IncExSeqOrElseSerSchema + + +def iterable_schema( + items_schema: CoreSchema | None = None, + *, + min_length: int | None = None, + max_length: int | None = None, + lazy: bool | None = None, + ref: str | None = None, + metadata: dict[str, Any] | None = None, + serialization: IncExSeqOrElseSerSchema | None = None, +) -> IterableSchema: + """ + Returns a schema that matches an iterable value, e.g.: + + ```py + from typing import Iterator + from pydantic_core import SchemaValidator, core_schema + + def gen() -> Iterator[int]: + yield 1 + + schema = core_schema.iterable_schema(items_schema=core_schema.int_schema()) + v = SchemaValidator(schema) + v.validate_python(gen()) + ``` + + Lazy validation (the default) is equivalent to `generator_schema` for + backwards compatibility in Pydantic V2. + + When not using lazy validation, validated iterables will be collected into a list. + + Args: + items_schema: The value must be an iterable with items that match this schema + min_length: The value must be an iterable that yields at least this many items + max_length: The value must be an iterable that yields at most this many items + lazy: Whether to use lazy evaluation, defaults to True + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='iterable', + items_schema=items_schema, + min_length=min_length, + max_length=max_length, + lazy=lazy, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + IncExDict = set[Union[int, str]] diff --git a/src/input/mod.rs b/src/input/mod.rs index 19be9fad1..3ce37d628 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -23,7 +23,8 @@ pub(crate) use input_python::{downcast_python_input, input_as_python_instance}; pub(crate) use input_string::StringMapping; pub(crate) use return_enums::{ no_validator_iter_to_vec, py_string_str, validate_iter_to_set, validate_iter_to_vec, EitherBytes, EitherFloat, - EitherInt, EitherString, GenericIterator, Int, MaxLengthCheck, ValidationMatch, + EitherInt, EitherString, GenericIterator, GenericJsonIterator, GenericPyIterator, Int, MaxLengthCheck, + ValidationMatch, }; // Defined here as it's not exported by pyo3 diff --git a/src/validators/iterable.rs b/src/validators/iterable.rs new file mode 100644 index 000000000..63f614758 --- /dev/null +++ b/src/validators/iterable.rs @@ -0,0 +1,128 @@ +use std::sync::Arc; + +use jiter::JsonValue; +use pyo3::types::PyDict; +use pyo3::{intern, prelude::*, IntoPyObjectExt}; + +use crate::errors::ValResult; +use crate::input::{ + validate_iter_to_vec, GenericIterator, GenericJsonIterator, GenericPyIterator, Input, MaxLengthCheck, +}; +use crate::tools::SchemaDict; +use crate::validators::any::AnyValidator; +use crate::validators::generator::GeneratorValidator; +use crate::validators::list::min_length_check; + +use super::list::get_items_schema; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; + +#[derive(Debug, Clone)] +pub struct IterableValidator { + item_validator: Option>, + min_length: Option, + max_length: Option, + name: String, +} + +impl BuildValidator for IterableValidator { + const EXPECTED_TYPE: &'static str = "iterable"; + + fn build( + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, + definitions: &mut DefinitionsBuilder, + ) -> PyResult { + // TODO: in Pydantic V3 default will be lazy=False + let lazy_iterable: bool = schema.get_as(intern!(schema.py(), "lazy"))?.unwrap_or(true); + + if lazy_iterable { + // lazy iterable is equivalent to generator, for backwards compatibility + return GeneratorValidator::build(schema, config, definitions); + } + + let item_validator = get_items_schema(schema, config, definitions)?.map(Arc::new); + let name = match item_validator { + Some(ref v) => format!("{}[{}]", Self::EXPECTED_TYPE, v.get_name()), + None => format!("{}[any]", Self::EXPECTED_TYPE), + }; + Ok(Self { + item_validator, + name, + min_length: schema.get_as(pyo3::intern!(schema.py(), "min_length"))?, + max_length: schema.get_as(pyo3::intern!(schema.py(), "max_length"))?, + } + .into()) + } +} + +impl_py_gc_traverse!(IterableValidator { item_validator }); + +impl Validator for IterableValidator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult> { + // this validator does not yet support partial validation, disable it to avoid incorrect results + state.allow_partial = false.into(); + + let iterator = input.validate_iter()?; + + let item_validator = self + .item_validator + .as_deref() + .unwrap_or(&CombinedValidator::Any(AnyValidator)); + + let max_length_check = MaxLengthCheck::new(self.max_length, "Iterable", input, None); + let vec = match iterator { + GenericIterator::PyIterator(iter) => validate_iter_to_vec( + py, + IterWithPy { py, iter }, + 0, + max_length_check, + item_validator, + state, + false, + )?, + GenericIterator::JsonArray(iter) => validate_iter_to_vec( + py, + IterWithPy { py, iter }, + 0, + max_length_check, + item_validator, + state, + false, + )?, + }; + + min_length_check!(input, "Iterable", self.min_length, vec); + + vec.into_py_any(py).map_err(Into::into) + } + + fn get_name(&self) -> &str { + &self.name + } +} + +struct IterWithPy<'py, I> { + py: Python<'py>, + iter: I, +} + +impl<'py> Iterator for IterWithPy<'py, GenericPyIterator> { + type Item = PyResult>; + + fn next(&mut self) -> Option { + Some(self.iter.next(self.py).transpose()?.map(|(v, _)| v)) + } +} + +impl<'j> Iterator for IterWithPy<'_, GenericJsonIterator<'j>> { + type Item = PyResult>; + + fn next(&mut self) -> Option { + Some(self.iter.next(self.py).transpose()?.map(|(v, _)| v.clone())) + } +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 1b19b6235..ad5dd72ae 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -43,6 +43,7 @@ mod generator; mod int; mod is_instance; mod is_subclass; +mod iterable; mod json; mod json_or_python; mod lax_or_strict; @@ -645,6 +646,8 @@ fn build_validator_inner( json_or_python::JsonOrPython, // generator validators generator::GeneratorValidator, + // iterables + iterable::IterableValidator, // custom error custom_error::CustomErrorValidator, // json data @@ -822,6 +825,8 @@ pub enum CombinedValidator { LaxOrStrict(lax_or_strict::LaxOrStrictValidator), // generator validators Generator(generator::GeneratorValidator), + // iterables + Iterable(iterable::IterableValidator), // custom error CustomError(custom_error::CustomErrorValidator), // json data diff --git a/tests/validators/test_generator.py b/tests/validators/test_generator.py index 411b2202e..8041325ab 100644 --- a/tests/validators/test_generator.py +++ b/tests/validators/test_generator.py @@ -1,4 +1,5 @@ import re +from typing import Callable import pytest from dirty_equals import HasRepr, IsStr @@ -9,6 +10,17 @@ from ..conftest import Err, PyAndJson +@pytest.fixture(params=['generator', 'iterable']) +def schema_type(request): + # both generator and (lazy) iterable should behave the same + return request.param + + +@pytest.fixture(params=[cs.generator_schema, cs.iterable_schema]) +def schema_func(request): + return request.param + + @pytest.mark.parametrize( 'input_value,expected', [ @@ -21,8 +33,8 @@ ], ids=repr, ) -def test_generator_json_int(py_and_json: PyAndJson, input_value, expected): - v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}}) +def test_generator_json_int(schema_type: str, py_and_json: PyAndJson, input_value, expected): + v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}}) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): list(v.validate_test(input_value)) @@ -39,8 +51,8 @@ def test_generator_json_int(py_and_json: PyAndJson, input_value, expected): (CoreConfig(hide_input_in_errors=True), 'type=iterable_type'), ), ) -def test_generator_json_hide_input(py_and_json: PyAndJson, config, input_str): - v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}}, config) +def test_generator_json_hide_input(schema_type: str, py_and_json: PyAndJson, config, input_str): + v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}}, config) with pytest.raises(ValidationError, match=re.escape(f'[{input_str}]')): list(v.validate_test(5)) @@ -57,8 +69,8 @@ def test_generator_json_hide_input(py_and_json: PyAndJson, config, input_str): ], ids=repr, ) -def test_generator_json_any(py_and_json: PyAndJson, input_value, expected): - v = py_and_json({'type': 'generator'}) +def test_generator_json_any(schema_type: str, py_and_json: PyAndJson, input_value, expected): + v = py_and_json({'type': schema_type}) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): list(v.validate_test(input_value)) @@ -67,8 +79,8 @@ def test_generator_json_any(py_and_json: PyAndJson, input_value, expected): assert list(v.validate_test(input_value)) == expected -def test_error_index(py_and_json: PyAndJson): - v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}}) +def test_error_index(schema_type: str, py_and_json: PyAndJson): + v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}}) gen = v.validate_test(['wrong']) assert gen.index == 0 with pytest.raises(ValidationError) as exc_info: @@ -108,8 +120,8 @@ def test_error_index(py_and_json: PyAndJson): assert gen.index == 5 -def test_too_long(py_and_json: PyAndJson): - v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}, 'max_length': 2}) +def test_too_long(schema_type: str, py_and_json: PyAndJson): + v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}, 'max_length': 2}) assert list(v.validate_test([1])) == [1] assert list(v.validate_test([1, 2])) == [1, 2] with pytest.raises(ValidationError) as exc_info: @@ -126,8 +138,8 @@ def test_too_long(py_and_json: PyAndJson): ] -def test_too_short(py_and_json: PyAndJson): - v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}, 'min_length': 2}) +def test_too_short(schema_type: str, py_and_json: PyAndJson): + v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}, 'min_length': 2}) assert list(v.validate_test([1, 2, 3])) == [1, 2, 3] assert list(v.validate_test([1, 2])) == [1, 2] with pytest.raises(ValidationError) as exc_info: @@ -150,8 +162,8 @@ def gen(): yield 3 -def test_generator_too_long(): - v = SchemaValidator(cs.generator_schema(items_schema=cs.int_schema(), max_length=2)) +def test_generator_too_long(schema_func: Callable): + v = SchemaValidator(schema_func(items_schema=cs.int_schema(), max_length=2)) validating_iterator = v.validate_python(gen()) @@ -174,8 +186,8 @@ def test_generator_too_long(): ] -def test_generator_too_short(): - v = SchemaValidator(cs.generator_schema(items_schema=cs.int_schema(), min_length=4)) +def test_generator_too_short(schema_func: Callable): + v = SchemaValidator(schema_func(items_schema=cs.int_schema(), min_length=4)) validating_iterator = v.validate_python(gen()) diff --git a/tests/validators/test_iterable.py b/tests/validators/test_iterable.py new file mode 100644 index 000000000..916bfd2d3 --- /dev/null +++ b/tests/validators/test_iterable.py @@ -0,0 +1,176 @@ +import re + +import pytest +from dirty_equals import HasRepr, IsStr + +from pydantic_core import CoreConfig, SchemaValidator, ValidationError +from pydantic_core import core_schema as cs + +from ..conftest import Err, PyAndJson + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + ([1, 2, 3], [1, 2, 3]), + ([1, 2, '3'], [1, 2, 3]), + ({1: 2, 3: 4}, [1, 3]), + ('123', [1, 2, 3]), + (5, Err('[type=iterable_type, input_value=5, input_type=int]')), + ([1, 'wrong'], Err("[type=int_parsing, input_value='wrong', input_type=str]")), + ], + ids=repr, +) +def test_iterable_json_int(py_and_json: PyAndJson, input_value, expected): + v = py_and_json(cs.iterable_schema(lazy=False, items_schema=cs.int_schema())) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + list(v.validate_test(input_value)) + + else: + assert list(v.validate_test(input_value)) == expected + + +@pytest.mark.parametrize( + 'config,input_str', + ( + (CoreConfig(), 'type=iterable_type, input_value=5, input_type=int'), + (CoreConfig(hide_input_in_errors=False), 'type=iterable_type, input_value=5, input_type=int'), + (CoreConfig(hide_input_in_errors=True), 'type=iterable_type'), + ), +) +def test_iterable_json_hide_input(py_and_json: PyAndJson, config, input_str): + v = py_and_json(cs.iterable_schema(lazy=False, items_schema=cs.int_schema()), config) + with pytest.raises(ValidationError, match=re.escape(f'[{input_str}]')): + list(v.validate_test(5)) + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + ([1, 2, 3], [1, 2, 3]), + ([1, 2, '3'], [1, 2, '3']), + ({'1': 2, '3': 4}, ['1', '3']), + ('123', ['1', '2', '3']), + (5, Err('[type=iterable_type, input_value=5, input_type=int]')), + ([1, 'wrong'], [1, 'wrong']), + ], + ids=repr, +) +def test_iterable_json_any(py_and_json: PyAndJson, input_value, expected): + v = py_and_json(cs.iterable_schema(lazy=False)) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + list(v.validate_test(input_value)) + + else: + assert list(v.validate_test(input_value)) == expected + + +def test_error_index(py_and_json: PyAndJson): + v = py_and_json(cs.iterable_schema(lazy=False, items_schema=cs.int_schema())) + with pytest.raises(ValidationError) as exc_info: + v.validate_test(['wrong']) + # insert_assert(exc_info.value.errors(include_url=False)) + assert exc_info.value.title == 'iterable[int]' + assert str(exc_info.value).startswith('1 validation error for iterable[int]\n') + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'int_parsing', + 'loc': (0,), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ] + + with pytest.raises(ValidationError) as exc_info: + v.validate_test([1, 2, 3, 'wrong', 4]) + + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'int_parsing', + 'loc': (3,), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ] + + +def test_too_long(py_and_json: PyAndJson): + v = py_and_json(cs.iterable_schema(lazy=False, items_schema=cs.int_schema(), max_length=2)) + assert list(v.validate_test([1])) == [1] + assert list(v.validate_test([1, 2])) == [1, 2] + with pytest.raises(ValidationError) as exc_info: + list(v.validate_test([1, 2, 3])) + # insert_assert(exc_info.value.errors(include_url=False)) + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'too_long', + 'loc': (), + 'msg': 'Iterable should have at most 2 items after validation, not more', + 'input': [1, 2, 3], + 'ctx': {'field_type': 'Iterable', 'max_length': 2, 'actual_length': None}, + } + ] + + +def test_too_short(py_and_json: PyAndJson): + v = py_and_json(cs.iterable_schema(lazy=False, items_schema=cs.int_schema(), min_length=2)) + assert list(v.validate_test([1, 2, 3])) == [1, 2, 3] + assert list(v.validate_test([1, 2])) == [1, 2] + with pytest.raises(ValidationError) as exc_info: + list(v.validate_test([1])) + # insert_assert(exc_info.value.errors(include_url=False)) + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'too_short', + 'loc': (), + 'msg': 'Iterable should have at least 2 items after validation, not 1', + 'input': [1], + 'ctx': {'field_type': 'Iterable', 'min_length': 2, 'actual_length': 1}, + } + ] + + +def gen(): + yield 1 + yield 2 + yield 3 + + +def test_generator_too_long(): + v = SchemaValidator(cs.iterable_schema(lazy=False, items_schema=cs.int_schema(), max_length=2)) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(gen()) + + errors = exc_info.value.errors(include_url=False) + # insert_assert(errors) + assert errors == [ + { + 'type': 'too_long', + 'loc': (), + 'input': HasRepr(IsStr(regex='')), + 'msg': 'Iterable should have at most 2 items after validation, not more', + 'ctx': {'field_type': 'Iterable', 'max_length': 2, 'actual_length': None}, + } + ] + + +def test_generator_too_short(): + v = SchemaValidator(cs.iterable_schema(lazy=False, items_schema=cs.int_schema(), min_length=4)) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(gen()) + + errors = exc_info.value.errors(include_url=False) + # insert_assert(errors) + assert errors == [ + { + 'type': 'too_short', + 'input': HasRepr(IsStr(regex='')), + 'loc': (), + 'msg': 'Iterable should have at least 4 items after validation, not 3', + 'ctx': {'field_type': 'Iterable', 'min_length': 4, 'actual_length': 3}, + } + ]