diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index dbd797f12..8d3f2adf8 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -157,6 +157,8 @@ impl Validator for DataclassArgsValidator { let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len()); let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone())); + let state = &mut state.scoped_set(|state| &mut state.has_field_error, false); + let extra_behavior = state.extra_behavior_or(self.extra_behavior); let validate_by_alias = state.validate_by_alias_or(self.validate_by_alias); @@ -235,6 +237,7 @@ impl Validator for DataclassArgsValidator { fields_set_count += 1; } Err(ValError::LineErrors(line_errors)) => { + state.has_field_error = true; errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index))); } Err(err) => return Err(err), @@ -246,6 +249,7 @@ impl Validator for DataclassArgsValidator { fields_set_count += 1; } Err(ValError::LineErrors(line_errors)) => { + state.has_field_error = true; errors.extend( line_errors .into_iter() @@ -272,6 +276,7 @@ impl Validator for DataclassArgsValidator { } Err(ValError::Omit) => {} Err(ValError::LineErrors(line_errors)) => { + state.has_field_error = true; for err in line_errors { // Note: this will always use the field name even if there is an alias // However, we don't mind so much because this error can only happen if the diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index be8ebb4b9..3a334f571 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -179,6 +179,7 @@ impl Validator for ModelFieldsValidator { { let state = &mut state.rebind_extra(|extra| extra.data = Some(model_dict.clone())); + let state = &mut state.scoped_set(|state| &mut state.has_field_error, false); for field in &self.fields { let lookup_key = field @@ -242,6 +243,7 @@ impl Validator for ModelFieldsValidator { } Err(ValError::Omit) 
=> {} Err(ValError::LineErrors(line_errors)) => { + state.has_field_error = true; for err in line_errors { // Note: this will always use the field name even if there is an alias // However, we don't mind so much because this error can only happen if the diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index fdec72143..26af98fd5 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -185,6 +185,7 @@ impl Validator for TypedDictValidator { { let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone())); + let state = &mut state.scoped_set(|state| &mut state.has_field_error, false); let mut fields_set_count: usize = 0; @@ -229,15 +230,24 @@ impl Validator for TypedDictValidator { output_dict.set_item(&field.name_py, value)?; fields_set_count += 1; } - Err(ValError::Omit) => continue, - Err(ValError::LineErrors(line_errors)) => { - if !is_last_partial || field.required { - for err in line_errors { - errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + Err(e) => { + state.has_field_error = true; + match e { + ValError::Omit => {} + ValError::LineErrors(line_errors) => { + if !is_last_partial || field.required { + for err in line_errors { + errors.push(lookup_path.apply_error_loc( + err, + self.loc_by_alias, + &field.name, + )); + } + } } + err => return Err(err), } } - Err(err) => return Err(err), } continue; } @@ -260,6 +270,7 @@ impl Validator for TypedDictValidator { } Err(ValError::Omit) => {} Err(ValError::LineErrors(line_errors)) => { + state.has_field_error = true; for err in line_errors { // Note: this will always use the field name even if there is an alias // However, we don't mind so much because this error can only happen if the diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index e5ec3a5fc..411a2b765 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -1,3 +1,5 @@ +use 
std::ops::{Deref, DerefMut}; + use pyo3::prelude::*; use pyo3::types::PyString; @@ -29,6 +31,9 @@ pub struct ValidationState<'a, 'py> { // Whether at least one field had a validation error. This is used in the context of structured types // (models, dataclasses, etc), where we need to know if a validation error occurred before calling // a default factory that takes the validated data. + // + // TODO: this should probably be moved directly into the structured types which need it, but that + // requires some refactoring to make them have knowledge of default (factories). pub has_field_error: bool, // deliberately make Extra readonly extra: Extra<'a, 'py>, @@ -58,6 +63,26 @@ impl<'a, 'py> ValidationState<'a, 'py> { ValidationStateWithReboundExtra { state: self, old_extra } } + /// Temporarily rebinds a field of the state by calling `projector` to get a mutable reference to the field, + /// and setting that field to `new_value`. + /// + /// When `ScopedSetState` drops, the field is restored to its original value. + pub fn scoped_set<'state, P, T>( + &'state mut self, + projector: P, + new_value: T, + ) -> ScopedSetState<'state, 'a, 'py, P, T> + where + P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T, + { + let value = std::mem::replace((projector)(self), new_value); + ScopedSetState { + state: self, + projector, + value, + } + } + pub fn extra(&self) -> &'_ Extra<'a, 'py> { &self.extra } @@ -176,3 +201,44 @@ impl Iterator for EnumerateLastPartial { self.iter.size_hint() } } + +pub struct ScopedSetState<'scope, 'a, 'py, P, T> +where + P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T, +{ + /// The state which has been set for the scope. + state: &'scope mut ValidationState<'a, 'py>, + /// A function that projects from the state to the field that has been set. + projector: P, + /// The previous value of the field that has been set.
+ value: T, +} + +impl<'a, 'py, P, T> Drop for ScopedSetState<'_, 'a, 'py, P, T> +where + P: for<'drop> Fn(&'drop mut ValidationState<'a, 'py>) -> &'drop mut T, +{ + fn drop(&mut self) { + std::mem::swap((self.projector)(self.state), &mut self.value); + } +} + +impl<'a, 'py, P, T> Deref for ScopedSetState<'_, 'a, 'py, P, T> +where + P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T, +{ + type Target = ValidationState<'a, 'py>; + + fn deref(&self) -> &Self::Target { + self.state + } +} + +impl<'a, 'py, P, T> DerefMut for ScopedSetState<'_, 'a, 'py, P, T> +where + P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T, +{ + fn deref_mut(&mut self) -> &mut ValidationState<'a, 'py> { + self.state + } +} diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 443a87fa6..2e2f36f36 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -2,6 +2,7 @@ import sys import weakref from collections import deque +from dataclasses import dataclass from typing import Any, Callable, Union, cast import pytest @@ -16,6 +17,7 @@ ValidationError, core_schema, ) +from pydantic_core._pydantic_core import SchemaSerializer from ..conftest import PyAndJson, assert_gc @@ -822,31 +824,46 @@ def _raise(ex: Exception) -> None: assert exc_info.value.errors(include_url=False, include_context=False) == expected -def test_default_factory_not_called_if_existing_error(pydantic_version) -> None: - class Test: - def __init__(self, a: int, b: int): - self.a = a - self.b = b +@pytest.fixture(params=['model', 'typed_dict', 'dataclass', 'arguments_v3']) +def container_schema_builder( + request: pytest.FixtureRequest, +) -> Callable[[dict[str, core_schema.CoreSchema]], core_schema.CoreSchema]: + if request.param == 'model': + return lambda fields: core_schema.model_schema( + cls=type('Test', (), {}), + schema=core_schema.model_fields_schema( + fields={k: core_schema.model_field(schema=v) for k, v in 
fields.items()}, + ), + ) + elif request.param == 'typed_dict': + return lambda fields: core_schema.typed_dict_schema( + fields={k: core_schema.typed_dict_field(schema=v) for k, v in fields.items()} + ) + elif request.param == 'dataclass': + return lambda fields: core_schema.dataclass_schema( + cls=dataclass(type('Test', (), {})), + schema=core_schema.dataclass_args_schema( + 'Test', + fields=[core_schema.dataclass_field(name=k, schema=v) for k, v in fields.items()], + ), + fields=[k for k in fields.keys()], + ) + elif request.param == 'arguments_v3': + # TODO: open an issue for this + raise pytest.xfail('arguments v3 does not yet support default_factory_takes_data properly') + else: + raise ValueError(f'Unknown container type {request.param}') - schema = core_schema.model_schema( - cls=Test, - schema=core_schema.model_fields_schema( - computed_fields=[], - fields={ - 'a': core_schema.model_field( - schema=core_schema.int_schema(), - ), - 'b': core_schema.model_field( - schema=core_schema.with_default_schema( - schema=core_schema.int_schema(), - default_factory=lambda data: data['a'], - default_factory_takes_data=True, - ), - ), - }, - ), - ) +def test_default_factory_not_called_if_existing_error(container_schema_builder, pydantic_version) -> None: + schema = container_schema_builder( + { + 'a': core_schema.int_schema(), + 'b': core_schema.with_default_schema( + schema=core_schema.int_schema(), default_factory=lambda data: data['a'], default_factory_takes_data=True + ), + } + ) v = SchemaValidator(schema) with pytest.raises(ValidationError) as e: v.validate_python({'a': 'not_an_int'}) @@ -868,7 +885,49 @@ def __init__(self, a: int, b: int): assert ( str(e.value) - == f"""2 validation errors for Test + == f"""2 validation errors for {v.title} +a + Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not_an_int', input_type=str] + For further information visit 
https://errors.pydantic.dev/{pydantic_version}/v/int_parsing +b + The default factory uses validated data, but at least one validation error occurred [type=default_factory_not_called] + For further information visit https://errors.pydantic.dev/{pydantic_version}/v/default_factory_not_called""" + ) + + # repeat with the first field being a default which validates incorrectly + + schema = container_schema_builder( + { + 'a': core_schema.with_default_schema( + schema=core_schema.int_schema(), default='not_an_int', validate_default=True + ), + 'b': core_schema.with_default_schema( + schema=core_schema.int_schema(), default_factory=lambda data: data['a'], default_factory_takes_data=True + ), + } + ) + v = SchemaValidator(schema) + with pytest.raises(ValidationError) as e: + v.validate_python({}) + + assert e.value.errors(include_url=False) == [ + { + 'type': 'int_parsing', + 'loc': ('a',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'not_an_int', + }, + { + 'input': PydanticUndefined, + 'loc': ('b',), + 'msg': 'The default factory uses validated data, but at least one validation error occurred', + 'type': 'default_factory_not_called', + }, + ] + + assert ( + str(e.value) + == f"""2 validation errors for {v.title} a Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not_an_int', input_type=str] For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing @@ -876,3 +935,35 @@ def __init__(self, a: int, b: int): The default factory uses validated data, but at least one validation error occurred [type=default_factory_not_called] For further information visit https://errors.pydantic.dev/{pydantic_version}/v/default_factory_not_called""" ) + + +def test_default_factory_not_called_union_ok(container_schema_builder) -> None: + schema_fail = container_schema_builder( + { + 'a': core_schema.none_schema(), + 'b': core_schema.with_default_schema( + 
schema=core_schema.int_schema(), + default_factory=lambda data: data['a'], + default_factory_takes_data=True, + ), + } + ) + + schema_ok = container_schema_builder( + { + 'a': core_schema.int_schema(), + 'b': core_schema.with_default_schema( + schema=core_schema.int_schema(), + default_factory=lambda data: data['a'] + 1, + default_factory_takes_data=True, + ), + # this is used to show that this union member was selected + 'c': core_schema.with_default_schema(schema=core_schema.int_schema(), default=3), + } + ) + + schema = core_schema.union_schema([schema_fail, schema_ok]) + + v = SchemaValidator(schema) + s = SchemaSerializer(schema) + assert s.to_python(v.validate_python({'a': 1}), mode='json') == {'a': 1, 'b': 2, 'c': 3}