Skip to content

Commit 470754d

Browse files
committed
move allow_partial to state, and set it correctly
1 parent 92e694f commit 470754d

File tree

7 files changed

+156
-91
lines changed

7 files changed

+156
-91
lines changed

src/input/return_enums.rs

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMappin
1717
use serde::{ser::Error, Serialize, Serializer};
1818

1919
use crate::errors::{
20-
py_err_string, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ToErrorValue, ValError, ValLineError, ValResult,
20+
py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError, ValLineError, ValResult,
2121
};
2222
use crate::py_gc::PyGcTraverse;
2323
use crate::tools::{extract_i64, extract_int, new_py_string, py_err};
@@ -128,8 +128,9 @@ pub(crate) fn validate_iter_to_vec<'py>(
128128
) -> ValResult<Vec<PyObject>> {
129129
let mut output: Vec<PyObject> = Vec::with_capacity(capacity);
130130
let mut errors: Vec<ValLineError> = Vec::new();
131-
let mut index = 0;
132-
for item_result in iter {
131+
132+
for (index, is_last_partial, item_result) in state.enumerate_last_partial(iter) {
133+
state.allow_partial = is_last_partial;
133134
let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?;
134135
match validator.validate(py, item.borrow_input(), state) {
135136
Ok(item) => {
@@ -138,41 +139,25 @@ pub(crate) fn validate_iter_to_vec<'py>(
138139
}
139140
Err(ValError::LineErrors(line_errors)) => {
140141
max_length_check.incr()?;
141-
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
142-
if fail_fast {
143-
return Err(ValError::LineErrors(errors));
142+
if !is_last_partial {
143+
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
144+
if fail_fast {
145+
return Err(ValError::LineErrors(errors));
146+
}
144147
}
145148
}
146149
Err(ValError::Omit) => (),
147150
Err(err) => return Err(err),
148151
}
149-
index += 1;
150152
}
151153

152-
if errors.is_empty() || sequence_valid_as_partial(state, index, &errors) {
154+
if errors.is_empty() {
153155
Ok(output)
154156
} else {
155157
Err(ValError::LineErrors(errors))
156158
}
157159
}
158160

159-
/// If we're in `allow_partial` mode, whether all errors occurred in the last element of the input.
160-
pub fn sequence_valid_as_partial(state: &ValidationState, input_length: usize, errors: &[ValLineError]) -> bool {
161-
if !state.extra().allow_partial {
162-
false
163-
} else {
164-
// for the error to be in the last element, the index of all errors must be `input_length - 1`
165-
let last_index = (input_length - 1) as i64;
166-
errors.iter().all(|error| {
167-
if let Some(LocItem::I(loc_index)) = error.first_loc_item() {
168-
*loc_index == last_index
169-
} else {
170-
false
171-
}
172-
})
173-
}
174-
}
175-
176161
pub trait BuildSet {
177162
fn build_add(&self, item: PyObject) -> PyResult<()>;
178163

@@ -216,8 +201,9 @@ pub(crate) fn validate_iter_to_set<'py>(
216201
fail_fast: bool,
217202
) -> ValResult<()> {
218203
let mut errors: Vec<ValLineError> = Vec::new();
219-
let mut index = 0;
220-
for item_result in iter {
204+
205+
for (index, is_last_partial, item_result) in state.enumerate_last_partial(iter) {
206+
state.allow_partial = is_last_partial;
221207
let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?;
222208
match validator.validate(py, item.borrow_input(), state) {
223209
Ok(item) => {
@@ -240,18 +226,19 @@ pub(crate) fn validate_iter_to_set<'py>(
240226
}
241227
}
242228
Err(ValError::LineErrors(line_errors)) => {
243-
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
229+
if !is_last_partial {
230+
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
231+
}
244232
}
245233
Err(ValError::Omit) => (),
246234
Err(err) => return Err(err),
247235
}
248236
if fail_fast && !errors.is_empty() {
249237
return Err(ValError::LineErrors(errors));
250238
}
251-
index += 1;
252239
}
253240

254-
if errors.is_empty() || sequence_valid_as_partial(state, index, &errors) {
241+
if errors.is_empty() {
255242
Ok(())
256243
} else {
257244
Err(ValError::LineErrors(errors))

src/validators/dict.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,9 @@ where
109109
fn consume_iterator(self, iterator: impl Iterator<Item = ValResult<(Key, Value)>>) -> ValResult<PyObject> {
110110
let output = PyDict::new_bound(self.py);
111111
let mut errors: Vec<ValLineError> = Vec::new();
112-
// this should only be set to if:
113-
// we get errors in a value, there are no previous errors, and no items come after that
114-
// e.g. if we get errors just in the last value
115-
let mut errors_in_last = false;
116112

117-
for item_result in iterator {
118-
errors_in_last = false;
113+
for (_, is_last_partial, item_result) in self.state.enumerate_last_partial(iterator) {
114+
self.state.allow_partial = false;
119115
let (key, value) = item_result?;
120116
let output_key = match self.key_validator.validate(self.py, key.borrow_input(), self.state) {
121117
Ok(value) => Some(value),
@@ -129,12 +125,12 @@ where
129125
Err(ValError::Omit) => continue,
130126
Err(err) => return Err(err),
131127
};
128+
self.state.allow_partial = is_last_partial;
132129
let output_value = match self.value_validator.validate(self.py, value.borrow_input(), self.state) {
133130
Ok(value) => value,
134131
Err(ValError::LineErrors(line_errors)) => {
135-
errors_in_last = errors.is_empty();
136-
for err in line_errors {
137-
errors.push(err.with_outer_location(key.clone()));
132+
if !is_last_partial {
133+
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(key.clone())));
138134
}
139135
continue;
140136
}
@@ -146,7 +142,7 @@ where
146142
}
147143
}
148144

149-
if errors.is_empty() || (self.state.extra().allow_partial && errors_in_last) {
145+
if errors.is_empty() {
150146
let input = self.input;
151147
length_check!(input, "Dictionary", self.min_length, self.max_length, output);
152148
Ok(output.into())

src/validators/generator.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ impl InternalValidator {
260260
hide_input_in_errors,
261261
validation_error_cause,
262262
cache_str: extra.cache_str,
263-
allow_partial: extra.allow_partial,
263+
allow_partial: state.allow_partial,
264264
}
265265
}
266266

@@ -280,9 +280,8 @@ impl InternalValidator {
280280
context: self.context.as_ref().map(|data| data.bind(py)),
281281
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
282282
cache_str: self.cache_str,
283-
allow_partial: self.allow_partial,
284283
};
285-
let mut state = ValidationState::new(extra, &mut self.recursion_guard);
284+
let mut state = ValidationState::new(extra, &mut self.recursion_guard, self.allow_partial);
286285
state.exactness = self.exactness;
287286
let result = self
288287
.validator
@@ -316,9 +315,8 @@ impl InternalValidator {
316315
context: self.context.as_ref().map(|data| data.bind(py)),
317316
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
318317
cache_str: self.cache_str,
319-
allow_partial: self.allow_partial,
320318
};
321-
let mut state = ValidationState::new(extra, &mut self.recursion_guard);
319+
let mut state = ValidationState::new(extra, &mut self.recursion_guard, self.allow_partial);
322320
state.exactness = self.exactness;
323321
let result = self.validator.validate(py, input, &mut state).map_err(|e| {
324322
ValidationError::from_val_error(

src/validators/mod.rs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,10 @@ impl SchemaValidator {
280280
context,
281281
self_instance: None,
282282
cache_str: self.cache_str,
283-
allow_partial: false,
284283
};
285284

286285
let guard = &mut RecursionState::default();
287-
let mut state = ValidationState::new(extra, guard);
286+
let mut state = ValidationState::new(extra, guard, false);
288287
self.validator
289288
.validate_assignment(py, &obj, field_name, &field_value, &mut state)
290289
.map_err(|e| self.prepare_validation_err(py, e, InputType::Python))
@@ -305,10 +304,9 @@ impl SchemaValidator {
305304
context,
306305
self_instance: None,
307306
cache_str: self.cache_str,
308-
allow_partial: false,
309307
};
310308
let recursion_guard = &mut RecursionState::default();
311-
let mut state = ValidationState::new(extra, recursion_guard);
309+
let mut state = ValidationState::new(extra, recursion_guard, false);
312310
let r = self.validator.default_value(py, None::<i64>, &mut state);
313311
match r {
314312
Ok(maybe_default) => match maybe_default {
@@ -365,9 +363,9 @@ impl SchemaValidator {
365363
self_instance,
366364
input_type,
367365
self.cache_str,
368-
allow_partial,
369366
),
370367
&mut recursion_guard,
368+
allow_partial,
371369
);
372370
self.validator.validate(py, input, &mut state)
373371
}
@@ -430,8 +428,9 @@ impl<'py> SelfValidator<'py> {
430428
let py = schema.py();
431429
let mut recursion_guard = RecursionState::default();
432430
let mut state = ValidationState::new(
433-
Extra::new(strict, None, None, None, InputType::Python, true.into(), false),
431+
Extra::new(strict, None, None, None, InputType::Python, true.into()),
434432
&mut recursion_guard,
433+
false,
435434
);
436435
match self.validator.validator.validate(py, schema, &mut state) {
437436
Ok(schema_obj) => Ok(schema_obj.into_bound(py)),
@@ -628,8 +627,6 @@ pub struct Extra<'a, 'py> {
628627
self_instance: Option<&'a Bound<'py, PyAny>>,
629628
/// Whether to use a cache of short strings to accelerate python string construction
630629
cache_str: StringCacheMode,
631-
/// Whether to allow validation of partial objects
632-
pub allow_partial: bool,
633630
}
634631

635632
impl<'a, 'py> Extra<'a, 'py> {
@@ -640,7 +637,6 @@ impl<'a, 'py> Extra<'a, 'py> {
640637
self_instance: Option<&'a Bound<'py, PyAny>>,
641638
input_type: InputType,
642639
cache_str: StringCacheMode,
643-
allow_partial: bool,
644640
) -> Self {
645641
Extra {
646642
input_type,
@@ -650,7 +646,6 @@ impl<'a, 'py> Extra<'a, 'py> {
650646
context,
651647
self_instance,
652648
cache_str,
653-
allow_partial,
654649
}
655650
}
656651
}
@@ -665,7 +660,6 @@ impl Extra<'_, '_> {
665660
context: self.context,
666661
self_instance: self.self_instance,
667662
cache_str: self.cache_str,
668-
allow_partial: self.allow_partial,
669663
}
670664
}
671665
}

0 commit comments

Comments
 (0)