Skip to content

Commit c45e43c

Browse files
committed
introduce iterable schema
1 parent 0cd11fe commit c45e43c

File tree

6 files changed

+401
-17
lines changed

6 files changed

+401
-17
lines changed

python/pydantic_core/core_schema.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,6 +1917,68 @@ def gen() -> Iterator[int]:
19171917
)
19181918

19191919

1920+
class IterableSchema(TypedDict, total=False):
1921+
type: Required[Literal['iterable']]
1922+
items_schema: CoreSchema
1923+
min_length: int
1924+
max_length: int
1925+
lazy: bool
1926+
ref: str
1927+
metadata: dict[str, Any]
1928+
serialization: IncExSeqOrElseSerSchema
1929+
1930+
1931+
def iterable_schema(
1932+
items_schema: CoreSchema | None = None,
1933+
*,
1934+
min_length: int | None = None,
1935+
max_length: int | None = None,
1936+
lazy: bool | None = None,
1937+
ref: str | None = None,
1938+
metadata: dict[str, Any] | None = None,
1939+
serialization: IncExSeqOrElseSerSchema | None = None,
1940+
) -> IterableSchema:
1941+
"""
1942+
Returns a schema that matches an iterable value, e.g.:
1943+
1944+
```py
1945+
from typing import Iterator
1946+
from pydantic_core import SchemaValidator, core_schema
1947+
1948+
def gen() -> Iterator[int]:
1949+
yield 1
1950+
1951+
schema = core_schema.iterable_schema(items_schema=core_schema.int_schema())
1952+
v = SchemaValidator(schema)
1953+
v.validate_python(gen())
1954+
```
1955+
1956+
Lazy validation (the default) is equivalent to `generator_schema` for
1957+
backwards compatibility in Pydantic V2.
1958+
1959+
When not using lazy validation, validated iterables will be collected into a list.
1960+
1961+
Args:
1962+
items_schema: The value must be an iterable with items that match this schema
1963+
min_length: The value must be an iterable that yields at least this many items
1964+
max_length: The value must be an iterable that yields at most this many items
1965+
lazy: Whether to use lazy evaluation, defaults to True
1966+
ref: optional unique identifier of the schema, used to reference the schema in other places
1967+
metadata: Any other information you want to include with the schema, not used by pydantic-core
1968+
serialization: Custom serialization schema
1969+
"""
1970+
return _dict_not_none(
1971+
type='iterable',
1972+
items_schema=items_schema,
1973+
min_length=min_length,
1974+
max_length=max_length,
1975+
lazy=lazy,
1976+
ref=ref,
1977+
metadata=metadata,
1978+
serialization=serialization,
1979+
)
1980+
1981+
19201982
IncExDict = set[Union[int, str]]
19211983

19221984

src/input/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ pub(crate) use input_python::{downcast_python_input, input_as_python_instance};
2323
pub(crate) use input_string::StringMapping;
2424
pub(crate) use return_enums::{
2525
no_validator_iter_to_vec, py_string_str, validate_iter_to_set, validate_iter_to_vec, EitherBytes, EitherFloat,
26-
EitherInt, EitherString, GenericIterator, Int, MaxLengthCheck, ValidationMatch,
26+
EitherInt, EitherString, GenericIterator, GenericJsonIterator, GenericPyIterator, Int, MaxLengthCheck,
27+
ValidationMatch,
2728
};
2829

2930
// Defined here as it's not exported by pyo3

src/validators/iterable.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
use std::sync::Arc;
2+
3+
use jiter::JsonValue;
4+
use pyo3::types::PyDict;
5+
use pyo3::{intern, prelude::*, IntoPyObjectExt};
6+
7+
use crate::errors::ValResult;
8+
use crate::input::{
9+
validate_iter_to_vec, GenericIterator, GenericJsonIterator, GenericPyIterator, Input, MaxLengthCheck,
10+
};
11+
use crate::tools::SchemaDict;
12+
use crate::validators::any::AnyValidator;
13+
use crate::validators::generator::GeneratorValidator;
14+
use crate::validators::list::min_length_check;
15+
16+
use super::list::get_items_schema;
17+
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
18+
19+
#[derive(Debug, Clone)]
20+
pub struct IterableValidator {
21+
item_validator: Option<Arc<CombinedValidator>>,
22+
min_length: Option<usize>,
23+
max_length: Option<usize>,
24+
name: String,
25+
}
26+
27+
impl BuildValidator for IterableValidator {
28+
const EXPECTED_TYPE: &'static str = "iterable";
29+
30+
fn build(
31+
schema: &Bound<'_, PyDict>,
32+
config: Option<&Bound<'_, PyDict>>,
33+
definitions: &mut DefinitionsBuilder<CombinedValidator>,
34+
) -> PyResult<CombinedValidator> {
35+
// TODO: in Pydantic V3 default will be lazy=False
36+
let lazy_iterable: bool = schema.get_as(intern!(schema.py(), "lazy"))?.unwrap_or(true);
37+
38+
if lazy_iterable {
39+
// lazy iterable is equivalent to generator, for backwards compatibility
40+
return GeneratorValidator::build(schema, config, definitions);
41+
}
42+
43+
let item_validator = get_items_schema(schema, config, definitions)?.map(Arc::new);
44+
let name = match item_validator {
45+
Some(ref v) => format!("{}[{}]", Self::EXPECTED_TYPE, v.get_name()),
46+
None => format!("{}[any]", Self::EXPECTED_TYPE),
47+
};
48+
Ok(Self {
49+
item_validator,
50+
name,
51+
min_length: schema.get_as(pyo3::intern!(schema.py(), "min_length"))?,
52+
max_length: schema.get_as(pyo3::intern!(schema.py(), "max_length"))?,
53+
}
54+
.into())
55+
}
56+
}
57+
58+
impl_py_gc_traverse!(IterableValidator { item_validator });
59+
60+
impl Validator for IterableValidator {
61+
fn validate<'py>(
62+
&self,
63+
py: Python<'py>,
64+
input: &(impl Input<'py> + ?Sized),
65+
state: &mut ValidationState<'_, 'py>,
66+
) -> ValResult<Py<PyAny>> {
67+
// this validator does not yet support partial validation, disable it to avoid incorrect results
68+
state.allow_partial = false.into();
69+
70+
let iterator = input.validate_iter()?;
71+
72+
let item_validator = self
73+
.item_validator
74+
.as_deref()
75+
.unwrap_or(&CombinedValidator::Any(AnyValidator));
76+
77+
let max_length_check = MaxLengthCheck::new(self.max_length, "Iterable", input, None);
78+
let vec = match iterator {
79+
GenericIterator::PyIterator(iter) => validate_iter_to_vec(
80+
py,
81+
IterWithPy { py, iter },
82+
0,
83+
max_length_check,
84+
item_validator,
85+
state,
86+
false,
87+
)?,
88+
GenericIterator::JsonArray(iter) => validate_iter_to_vec(
89+
py,
90+
IterWithPy { py, iter },
91+
0,
92+
max_length_check,
93+
item_validator,
94+
state,
95+
false,
96+
)?,
97+
};
98+
99+
min_length_check!(input, "Iterable", self.min_length, vec);
100+
101+
vec.into_py_any(py).map_err(Into::into)
102+
}
103+
104+
fn get_name(&self) -> &str {
105+
&self.name
106+
}
107+
}
108+
109+
struct IterWithPy<'py, I> {
110+
py: Python<'py>,
111+
iter: I,
112+
}
113+
114+
impl<'py> Iterator for IterWithPy<'py, GenericPyIterator> {
115+
type Item = PyResult<Bound<'py, PyAny>>;
116+
117+
fn next(&mut self) -> Option<Self::Item> {
118+
Some(self.iter.next(self.py).transpose()?.map(|(v, _)| v))
119+
}
120+
}
121+
122+
impl<'j> Iterator for IterWithPy<'_, GenericJsonIterator<'j>> {
123+
type Item = PyResult<JsonValue<'j>>;
124+
125+
fn next(&mut self) -> Option<Self::Item> {
126+
Some(self.iter.next(self.py).transpose()?.map(|(v, _)| v.clone()))
127+
}
128+
}

src/validators/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ mod generator;
4343
mod int;
4444
mod is_instance;
4545
mod is_subclass;
46+
mod iterable;
4647
mod json;
4748
mod json_or_python;
4849
mod lax_or_strict;
@@ -645,6 +646,8 @@ fn build_validator_inner(
645646
json_or_python::JsonOrPython,
646647
// generator validators
647648
generator::GeneratorValidator,
649+
// iterables
650+
iterable::IterableValidator,
648651
// custom error
649652
custom_error::CustomErrorValidator,
650653
// json data
@@ -822,6 +825,8 @@ pub enum CombinedValidator {
822825
LaxOrStrict(lax_or_strict::LaxOrStrictValidator),
823826
// generator validators
824827
Generator(generator::GeneratorValidator),
828+
// iterables
829+
Iterable(iterable::IterableValidator),
825830
// custom error
826831
CustomError(custom_error::CustomErrorValidator),
827832
// json data

tests/validators/test_generator.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from typing import Callable
23

34
import pytest
45
from dirty_equals import HasRepr, IsStr
@@ -9,6 +10,17 @@
910
from ..conftest import Err, PyAndJson
1011

1112

13+
@pytest.fixture(params=['generator', 'iterable'])
14+
def schema_type(request):
15+
# both generator and (lazy) iterable should behave the same
16+
return request.param
17+
18+
19+
@pytest.fixture(params=[cs.generator_schema, cs.iterable_schema])
20+
def schema_func(request):
21+
return request.param
22+
23+
1224
@pytest.mark.parametrize(
1325
'input_value,expected',
1426
[
@@ -21,8 +33,8 @@
2133
],
2234
ids=repr,
2335
)
24-
def test_generator_json_int(py_and_json: PyAndJson, input_value, expected):
25-
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}})
36+
def test_generator_json_int(schema_type: str, py_and_json: PyAndJson, input_value, expected):
37+
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}})
2638
if isinstance(expected, Err):
2739
with pytest.raises(ValidationError, match=re.escape(expected.message)):
2840
list(v.validate_test(input_value))
@@ -39,8 +51,8 @@ def test_generator_json_int(py_and_json: PyAndJson, input_value, expected):
3951
(CoreConfig(hide_input_in_errors=True), 'type=iterable_type'),
4052
),
4153
)
42-
def test_generator_json_hide_input(py_and_json: PyAndJson, config, input_str):
43-
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}}, config)
54+
def test_generator_json_hide_input(schema_type: str, py_and_json: PyAndJson, config, input_str):
55+
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}}, config)
4456
with pytest.raises(ValidationError, match=re.escape(f'[{input_str}]')):
4557
list(v.validate_test(5))
4658

@@ -57,8 +69,8 @@ def test_generator_json_hide_input(py_and_json: PyAndJson, config, input_str):
5769
],
5870
ids=repr,
5971
)
60-
def test_generator_json_any(py_and_json: PyAndJson, input_value, expected):
61-
v = py_and_json({'type': 'generator'})
72+
def test_generator_json_any(schema_type: str, py_and_json: PyAndJson, input_value, expected):
73+
v = py_and_json({'type': schema_type})
6274
if isinstance(expected, Err):
6375
with pytest.raises(ValidationError, match=re.escape(expected.message)):
6476
list(v.validate_test(input_value))
@@ -67,8 +79,8 @@ def test_generator_json_any(py_and_json: PyAndJson, input_value, expected):
6779
assert list(v.validate_test(input_value)) == expected
6880

6981

70-
def test_error_index(py_and_json: PyAndJson):
71-
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}})
82+
def test_error_index(schema_type: str, py_and_json: PyAndJson):
83+
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}})
7284
gen = v.validate_test(['wrong'])
7385
assert gen.index == 0
7486
with pytest.raises(ValidationError) as exc_info:
@@ -108,8 +120,8 @@ def test_error_index(py_and_json: PyAndJson):
108120
assert gen.index == 5
109121

110122

111-
def test_too_long(py_and_json: PyAndJson):
112-
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}, 'max_length': 2})
123+
def test_too_long(schema_type: str, py_and_json: PyAndJson):
124+
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}, 'max_length': 2})
113125
assert list(v.validate_test([1])) == [1]
114126
assert list(v.validate_test([1, 2])) == [1, 2]
115127
with pytest.raises(ValidationError) as exc_info:
@@ -126,8 +138,8 @@ def test_too_long(py_and_json: PyAndJson):
126138
]
127139

128140

129-
def test_too_short(py_and_json: PyAndJson):
130-
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}, 'min_length': 2})
141+
def test_too_short(schema_type: str, py_and_json: PyAndJson):
142+
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}, 'min_length': 2})
131143
assert list(v.validate_test([1, 2, 3])) == [1, 2, 3]
132144
assert list(v.validate_test([1, 2])) == [1, 2]
133145
with pytest.raises(ValidationError) as exc_info:
@@ -150,8 +162,8 @@ def gen():
150162
yield 3
151163

152164

153-
def test_generator_too_long():
154-
v = SchemaValidator(cs.generator_schema(items_schema=cs.int_schema(), max_length=2))
165+
def test_generator_too_long(schema_func: Callable):
166+
v = SchemaValidator(schema_func(items_schema=cs.int_schema(), max_length=2))
155167

156168
validating_iterator = v.validate_python(gen())
157169

@@ -174,8 +186,8 @@ def test_generator_too_long():
174186
]
175187

176188

177-
def test_generator_too_short():
178-
v = SchemaValidator(cs.generator_schema(items_schema=cs.int_schema(), min_length=4))
189+
def test_generator_too_short(schema_func: Callable):
190+
v = SchemaValidator(schema_func(items_schema=cs.int_schema(), min_length=4))
179191

180192
validating_iterator = v.validate_python(gen())
181193

0 commit comments

Comments
 (0)