Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ impl Validator for DataclassArgsValidator {
let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len());

let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
let state = &mut state.scoped_set(|state| &mut state.has_field_error, false);

let extra_behavior = state.extra_behavior_or(self.extra_behavior);

let validate_by_alias = state.validate_by_alias_or(self.validate_by_alias);
Expand Down Expand Up @@ -235,6 +237,7 @@ impl Validator for DataclassArgsValidator {
fields_set_count += 1;
}
Err(ValError::LineErrors(line_errors)) => {
state.has_field_error = true;
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
}
Err(err) => return Err(err),
Expand All @@ -246,6 +249,7 @@ impl Validator for DataclassArgsValidator {
fields_set_count += 1;
}
Err(ValError::LineErrors(line_errors)) => {
state.has_field_error = true;
errors.extend(
line_errors
.into_iter()
Expand All @@ -272,6 +276,7 @@ impl Validator for DataclassArgsValidator {
}
Err(ValError::Omit) => {}
Err(ValError::LineErrors(line_errors)) => {
state.has_field_error = true;
for err in line_errors {
// Note: this will always use the field name even if there is an alias
// However, we don't mind so much because this error can only happen if the
Expand Down
2 changes: 2 additions & 0 deletions src/validators/model_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ impl Validator for ModelFieldsValidator {

{
let state = &mut state.rebind_extra(|extra| extra.data = Some(model_dict.clone()));
let state = &mut state.scoped_set(|state| &mut state.has_field_error, false);

for field in &self.fields {
let lookup_key = field
Expand Down Expand Up @@ -242,6 +243,7 @@ impl Validator for ModelFieldsValidator {
}
Err(ValError::Omit) => {}
Err(ValError::LineErrors(line_errors)) => {
state.has_field_error = true;
for err in line_errors {
// Note: this will always use the field name even if there is an alias
// However, we don't mind so much because this error can only happen if the
Expand Down
23 changes: 17 additions & 6 deletions src/validators/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ impl Validator for TypedDictValidator {

{
let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
let state = &mut state.scoped_set(|state| &mut state.has_field_error, false);

let mut fields_set_count: usize = 0;

Expand Down Expand Up @@ -229,15 +230,24 @@ impl Validator for TypedDictValidator {
output_dict.set_item(&field.name_py, value)?;
fields_set_count += 1;
}
Err(ValError::Omit) => continue,
Err(ValError::LineErrors(line_errors)) => {
if !is_last_partial || field.required {
for err in line_errors {
errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name));
Err(e) => {
state.has_field_error = true;
match e {
ValError::Omit => {}
ValError::LineErrors(line_errors) => {
if !is_last_partial || field.required {
for err in line_errors {
errors.push(lookup_path.apply_error_loc(
err,
self.loc_by_alias,
&field.name,
));
}
}
}
err => return Err(err),
}
}
Err(err) => return Err(err),
}
continue;
}
Expand All @@ -260,6 +270,7 @@ impl Validator for TypedDictValidator {
}
Err(ValError::Omit) => {}
Err(ValError::LineErrors(line_errors)) => {
state.has_field_error = true;
for err in line_errors {
// Note: this will always use the field name even if there is an alias
// However, we don't mind so much because this error can only happen if the
Expand Down
66 changes: 66 additions & 0 deletions src/validators/validation_state.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::{Deref, DerefMut};

use pyo3::prelude::*;
use pyo3::types::PyString;

Expand Down Expand Up @@ -29,6 +31,9 @@ pub struct ValidationState<'a, 'py> {
// Whether at least one field had a validation error. This is used in the context of structured types
// (models, dataclasses, etc), where we need to know if a validation error occurred before calling
// a default factory that takes the validated data.
//
// TODO: this should probably be moved directly into the structured types which need it, but that
// requires some refactoring to make them have knowledge of default (factories).
pub has_field_error: bool,
// deliberately make Extra readonly
extra: Extra<'a, 'py>,
Expand Down Expand Up @@ -58,6 +63,26 @@ impl<'a, 'py> ValidationState<'a, 'py> {
ValidationStateWithReboundExtra { state: self, old_extra }
}

/// Temporarily rebinds a field of the state by calling `projector` to get a mutable reference to the field,
/// and setting that field to `value`.
///
/// When `ScopedSetState` drops, the field is restored to its original value.
pub fn scoped_set<'state, P, T>(
&'state mut self,
projector: P,
new_value: T,
) -> ScopedSetState<'state, 'a, 'py, P, T>
where
P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T,
{
let value = std::mem::replace((projector)(self), new_value);
ScopedSetState {
state: self,
projector,
value,
}
}

pub fn extra(&self) -> &'_ Extra<'a, 'py> {
&self.extra
}
Expand Down Expand Up @@ -176,3 +201,44 @@ impl<I: Iterator> Iterator for EnumerateLastPartial<I> {
self.iter.size_hint()
}
}

pub struct ScopedSetState<'scope, 'a, 'py, P, T>
where
P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T,
{
/// The state which has been set for the scope.
state: &'scope mut ValidationState<'a, 'py>,
/// A function that projects from the state to the field that has been set.
projector: P,
/// The previous value of the field that has been set.
value: T,
}

impl<'a, 'py, P, T> Drop for ScopedSetState<'_, 'a, 'py, P, T>
where
P: for<'drop> Fn(&'drop mut ValidationState<'a, 'py>) -> &'drop mut T,
{
fn drop(&mut self) {
std::mem::swap((self.projector)(self.state), &mut self.value);
}
}

impl<'a, 'py, P, T> Deref for ScopedSetState<'_, 'a, 'py, P, T>
where
P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T,
{
type Target = ValidationState<'a, 'py>;

fn deref(&self) -> &Self::Target {
self.state
}
}

impl<'a, 'py, P, T> DerefMut for ScopedSetState<'_, 'a, 'py, P, T>
where
P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T,
{
fn deref_mut(&mut self) -> &mut ValidationState<'a, 'py> {
self.state
}
}
139 changes: 115 additions & 24 deletions tests/validators/test_with_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import weakref
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable, Union, cast

import pytest
Expand All @@ -16,6 +17,7 @@
ValidationError,
core_schema,
)
from pydantic_core._pydantic_core import SchemaSerializer

from ..conftest import PyAndJson, assert_gc

Expand Down Expand Up @@ -822,31 +824,46 @@ def _raise(ex: Exception) -> None:
assert exc_info.value.errors(include_url=False, include_context=False) == expected


def test_default_factory_not_called_if_existing_error(pydantic_version) -> None:
class Test:
def __init__(self, a: int, b: int):
self.a = a
self.b = b
@pytest.fixture(params=['model', 'typed_dict', 'dataclass', 'arguments_v3'])
def container_schema_builder(
request: pytest.FixtureRequest,
) -> Callable[[dict[str, core_schema.CoreSchema]], core_schema.CoreSchema]:
if request.param == 'model':
return lambda fields: core_schema.model_schema(
cls=type('Test', (), {}),
schema=core_schema.model_fields_schema(
fields={k: core_schema.model_field(schema=v) for k, v in fields.items()},
),
)
elif request.param == 'typed_dict':
return lambda fields: core_schema.typed_dict_schema(
fields={k: core_schema.typed_dict_field(schema=v) for k, v in fields.items()}
)
elif request.param == 'dataclass':
return lambda fields: core_schema.dataclass_schema(
cls=dataclass(type('Test', (), {})),
schema=core_schema.dataclass_args_schema(
'Test',
fields=[core_schema.dataclass_field(name=k, schema=v) for k, v in fields.items()],
),
fields=[k for k in fields.keys()],
)
elif request.param == 'arguments_v3':
# TODO: open an issue for this
raise pytest.xfail('arguments v3 does not yet support default_factory_takes_data properly')
else:
raise ValueError(f'Unknown container type {request.param}')

schema = core_schema.model_schema(
cls=Test,
schema=core_schema.model_fields_schema(
computed_fields=[],
fields={
'a': core_schema.model_field(
schema=core_schema.int_schema(),
),
'b': core_schema.model_field(
schema=core_schema.with_default_schema(
schema=core_schema.int_schema(),
default_factory=lambda data: data['a'],
default_factory_takes_data=True,
),
),
},
),
)

def test_default_factory_not_called_if_existing_error(container_schema_builder, pydantic_version) -> None:
schema = container_schema_builder(
{
'a': core_schema.int_schema(),
'b': core_schema.with_default_schema(
schema=core_schema.int_schema(), default_factory=lambda data: data['a'], default_factory_takes_data=True
),
}
)
v = SchemaValidator(schema)
with pytest.raises(ValidationError) as e:
v.validate_python({'a': 'not_an_int'})
Expand All @@ -868,11 +885,85 @@ def __init__(self, a: int, b: int):

assert (
str(e.value)
== f"""2 validation errors for Test
== f"""2 validation errors for {v.title}
a
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not_an_int', input_type=str]
For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing
b
The default factory uses validated data, but at least one validation error occurred [type=default_factory_not_called]
For further information visit https://errors.pydantic.dev/{pydantic_version}/v/default_factory_not_called"""
)

# repeat with the first field being a default which validates incorrectly

schema = container_schema_builder(
{
'a': core_schema.with_default_schema(
schema=core_schema.int_schema(), default='not_an_int', validate_default=True
),
'b': core_schema.with_default_schema(
schema=core_schema.int_schema(), default_factory=lambda data: data['a'], default_factory_takes_data=True
),
}
)
v = SchemaValidator(schema)
with pytest.raises(ValidationError) as e:
v.validate_python({})

assert e.value.errors(include_url=False) == [
{
'type': 'int_parsing',
'loc': ('a',),
'msg': 'Input should be a valid integer, unable to parse string as an integer',
'input': 'not_an_int',
},
{
'input': PydanticUndefined,
'loc': ('b',),
'msg': 'The default factory uses validated data, but at least one validation error occurred',
'type': 'default_factory_not_called',
},
]

assert (
str(e.value)
== f"""2 validation errors for {v.title}
a
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not_an_int', input_type=str]
For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing
b
The default factory uses validated data, but at least one validation error occurred [type=default_factory_not_called]
For further information visit https://errors.pydantic.dev/{pydantic_version}/v/default_factory_not_called"""
)


def test_default_factory_not_called_union_ok(container_schema_builder) -> None:
schema_fail = container_schema_builder(
{
'a': core_schema.none_schema(),
'b': core_schema.with_default_schema(
schema=core_schema.int_schema(),
default_factory=lambda data: data['a'],
default_factory_takes_data=True,
),
}
)

schema_ok = container_schema_builder(
{
'a': core_schema.int_schema(),
'b': core_schema.with_default_schema(
schema=core_schema.int_schema(),
default_factory=lambda data: data['a'] + 1,
default_factory_takes_data=True,
),
# this is used to show that this union member was selected
'c': core_schema.with_default_schema(schema=core_schema.int_schema(), default=3),
}
)

schema = core_schema.union_schema([schema_fail, schema_ok])

v = SchemaValidator(schema)
s = SchemaSerializer(schema)
assert s.to_python(v.validate_python({'a': 1}), mode='json') == {'a': 1, 'b': 2, 'c': 3}
Loading