Skip to content

Commit bc59bd9

Browse files
committed
fix default_factory which takes data on more types
1 parent e87ba01 commit bc59bd9

File tree

5 files changed

+208
-30
lines changed

5 files changed

+208
-30
lines changed

src/validators/dataclass.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ impl Validator for DataclassArgsValidator {
157157
let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len());
158158

159159
let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
160+
let state = &mut state.scoped_set(|state| &mut state.has_field_error, false);
161+
160162
let extra_behavior = state.extra_behavior_or(self.extra_behavior);
161163

162164
let validate_by_alias = state.validate_by_alias_or(self.validate_by_alias);
@@ -219,6 +221,8 @@ impl Validator for DataclassArgsValidator {
219221

220222
let state = &mut state.rebind_extra(|extra| extra.field_name = Some(field.name_py.bind(py).clone()));
221223

224+
// FIXME: need to add support for `has_field_error` here
225+
222226
match (pos_value, kw_value) {
223227
// found both positional and keyword arguments, error
224228
(Some(_), Some((_, kw_value))) => {
@@ -235,6 +239,7 @@ impl Validator for DataclassArgsValidator {
235239
fields_set_count += 1;
236240
}
237241
Err(ValError::LineErrors(line_errors)) => {
242+
state.has_field_error = true;
238243
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
239244
}
240245
Err(err) => return Err(err),
@@ -246,6 +251,7 @@ impl Validator for DataclassArgsValidator {
246251
fields_set_count += 1;
247252
}
248253
Err(ValError::LineErrors(line_errors)) => {
254+
state.has_field_error = true;
249255
errors.extend(
250256
line_errors
251257
.into_iter()
@@ -272,6 +278,7 @@ impl Validator for DataclassArgsValidator {
272278
}
273279
Err(ValError::Omit) => {}
274280
Err(ValError::LineErrors(line_errors)) => {
281+
state.has_field_error = true;
275282
for err in line_errors {
276283
// Note: this will always use the field name even if there is an alias
277284
// However, we don't mind so much because this error can only happen if the

src/validators/model_fields.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ impl Validator for ModelFieldsValidator {
179179

180180
{
181181
let state = &mut state.rebind_extra(|extra| extra.data = Some(model_dict.clone()));
182+
let state = &mut state.scoped_set(|state| &mut state.has_field_error, false);
182183

183184
for field in &self.fields {
184185
let lookup_key = field
@@ -242,6 +243,7 @@ impl Validator for ModelFieldsValidator {
242243
}
243244
Err(ValError::Omit) => {}
244245
Err(ValError::LineErrors(line_errors)) => {
246+
state.has_field_error = true;
245247
for err in line_errors {
246248
// Note: this will always use the field name even if there is an alias
247249
// However, we don't mind so much because this error can only happen if the

src/validators/typed_dict.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ impl Validator for TypedDictValidator {
185185

186186
{
187187
let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
188+
let state = &mut state.scoped_set(|state| &mut state.has_field_error, false);
188189

189190
let mut fields_set_count: usize = 0;
190191

@@ -229,15 +230,25 @@ impl Validator for TypedDictValidator {
229230
output_dict.set_item(&field.name_py, value)?;
230231
fields_set_count += 1;
231232
}
232-
Err(ValError::Omit) => continue,
233-
Err(ValError::LineErrors(line_errors)) => {
234-
if !is_last_partial || field.required {
235-
for err in line_errors {
236-
errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name));
233+
Err(e) => {
234+
// FIXME handle the different error types better?
235+
state.has_field_error = true;
236+
match e {
237+
ValError::Omit => {}
238+
ValError::LineErrors(line_errors) => {
239+
if !is_last_partial || field.required {
240+
for err in line_errors {
241+
errors.push(lookup_path.apply_error_loc(
242+
err,
243+
self.loc_by_alias,
244+
&field.name,
245+
));
246+
}
247+
}
237248
}
249+
err => return Err(err),
238250
}
239251
}
240-
Err(err) => return Err(err),
241252
}
242253
continue;
243254
}
@@ -260,6 +271,7 @@ impl Validator for TypedDictValidator {
260271
}
261272
Err(ValError::Omit) => {}
262273
Err(ValError::LineErrors(line_errors)) => {
274+
state.has_field_error = true;
263275
for err in line_errors {
264276
// Note: this will always use the field name even if there is an alias
265277
// However, we don't mind so much because this error can only happen if the

src/validators/validation_state.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::ops::{Deref, DerefMut};
2+
13
use pyo3::prelude::*;
24
use pyo3::types::PyString;
35

@@ -29,6 +31,9 @@ pub struct ValidationState<'a, 'py> {
2931
// Whether at least one field had a validation error. This is used in the context of structured types
3032
// (models, dataclasses, etc), where we need to know if a validation error occurred before calling
3133
// a default factory that takes the validated data.
34+
//
35+
// TODO: this should probably be moved directly into the structured types which need it, but that
36+
// requires some refactoring to give them knowledge of defaults (and default factories).
3237
pub has_field_error: bool,
3338
// deliberately make Extra readonly
3439
extra: Extra<'a, 'py>,
@@ -58,6 +63,26 @@ impl<'a, 'py> ValidationState<'a, 'py> {
5863
ValidationStateWithReboundExtra { state: self, old_extra }
5964
}
6065

66+
/// Temporarily rebinds a field of the state by calling `projector` to get a mutable reference to the field,
67+
/// and setting that field to `new_value`.
68+
///
69+
/// When `ScopedSetState` drops, the field is restored to its original value.
70+
pub fn scoped_set<'state, P, T>(
71+
&'state mut self,
72+
projector: P,
73+
new_value: T,
74+
) -> ScopedSetState<'state, 'a, 'py, P, T>
75+
where
76+
P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T,
77+
{
78+
let value = std::mem::replace((projector)(self), new_value);
79+
ScopedSetState {
80+
state: self,
81+
projector,
82+
value,
83+
}
84+
}
85+
6186
pub fn extra(&self) -> &'_ Extra<'a, 'py> {
6287
&self.extra
6388
}
@@ -176,3 +201,44 @@ impl<I: Iterator> Iterator for EnumerateLastPartial<I> {
176201
self.iter.size_hint()
177202
}
178203
}
204+
205+
pub struct ScopedSetState<'scope, 'a, 'py, P, T>
206+
where
207+
P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T,
208+
{
209+
/// The state which has been set for the scope.
210+
state: &'scope mut ValidationState<'a, 'py>,
211+
/// A function that projects from the state to the field that has been set.
212+
projector: P,
213+
/// The previous value of the field that has been set.
214+
value: T,
215+
}
216+
217+
impl<'a, 'py, P, T> Drop for ScopedSetState<'_, 'a, 'py, P, T>
218+
where
219+
P: for<'drop> Fn(&'drop mut ValidationState<'a, 'py>) -> &'drop mut T,
220+
{
221+
fn drop(&mut self) {
222+
std::mem::swap((self.projector)(self.state), &mut self.value);
223+
}
224+
}
225+
226+
impl<'a, 'py, P, T> Deref for ScopedSetState<'_, 'a, 'py, P, T>
227+
where
228+
P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T,
229+
{
230+
type Target = ValidationState<'a, 'py>;
231+
232+
fn deref(&self) -> &Self::Target {
233+
self.state
234+
}
235+
}
236+
237+
impl<'a, 'py, P, T> DerefMut for ScopedSetState<'_, 'a, 'py, P, T>
238+
where
239+
P: for<'p> Fn(&'p mut ValidationState<'a, 'py>) -> &'p mut T,
240+
{
241+
fn deref_mut(&mut self) -> &mut ValidationState<'a, 'py> {
242+
self.state
243+
}
244+
}

tests/validators/test_with_default.py

Lines changed: 115 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import weakref
44
from collections import deque
5+
from dataclasses import dataclass
56
from typing import Any, Callable, Union, cast
67

78
import pytest
@@ -16,6 +17,7 @@
1617
ValidationError,
1718
core_schema,
1819
)
20+
from pydantic_core._pydantic_core import SchemaSerializer
1921

2022
from ..conftest import PyAndJson, assert_gc
2123

@@ -822,31 +824,46 @@ def _raise(ex: Exception) -> None:
822824
assert exc_info.value.errors(include_url=False, include_context=False) == expected
823825

824826

825-
def test_default_factory_not_called_if_existing_error(pydantic_version) -> None:
826-
class Test:
827-
def __init__(self, a: int, b: int):
828-
self.a = a
829-
self.b = b
827+
@pytest.fixture(params=['model', 'typed_dict', 'dataclass', 'arguments_v3'])
828+
def container_schema_builder(
829+
request: pytest.FixtureRequest,
830+
) -> Callable[[dict[str, core_schema.CoreSchema]], core_schema.CoreSchema]:
831+
if request.param == 'model':
832+
return lambda fields: core_schema.model_schema(
833+
cls=type('Test', (), {}),
834+
schema=core_schema.model_fields_schema(
835+
fields={k: core_schema.model_field(schema=v) for k, v in fields.items()},
836+
),
837+
)
838+
elif request.param == 'typed_dict':
839+
return lambda fields: core_schema.typed_dict_schema(
840+
fields={k: core_schema.typed_dict_field(schema=v) for k, v in fields.items()}
841+
)
842+
elif request.param == 'dataclass':
843+
return lambda fields: core_schema.dataclass_schema(
844+
cls=dataclass(type('Test', (), {})),
845+
schema=core_schema.dataclass_args_schema(
846+
'Test',
847+
fields=[core_schema.dataclass_field(name=k, schema=v) for k, v in fields.items()],
848+
),
849+
fields=[k for k in fields.keys()],
850+
)
851+
elif request.param == 'arguments_v3':
852+
# TODO: open an issue for this
853+
raise pytest.xfail('arguments v3 does not yet support default_factory_takes_data properly')
854+
else:
855+
raise ValueError(f'Unknown container type {request.param}')
830856

831-
schema = core_schema.model_schema(
832-
cls=Test,
833-
schema=core_schema.model_fields_schema(
834-
computed_fields=[],
835-
fields={
836-
'a': core_schema.model_field(
837-
schema=core_schema.int_schema(),
838-
),
839-
'b': core_schema.model_field(
840-
schema=core_schema.with_default_schema(
841-
schema=core_schema.int_schema(),
842-
default_factory=lambda data: data['a'],
843-
default_factory_takes_data=True,
844-
),
845-
),
846-
},
847-
),
848-
)
849857

858+
def test_default_factory_not_called_if_existing_error(container_schema_builder, pydantic_version) -> None:
859+
schema = container_schema_builder(
860+
{
861+
'a': core_schema.int_schema(),
862+
'b': core_schema.with_default_schema(
863+
schema=core_schema.int_schema(), default_factory=lambda data: data['a'], default_factory_takes_data=True
864+
),
865+
}
866+
)
850867
v = SchemaValidator(schema)
851868
with pytest.raises(ValidationError) as e:
852869
v.validate_python({'a': 'not_an_int'})
@@ -868,11 +885,85 @@ def __init__(self, a: int, b: int):
868885

869886
assert (
870887
str(e.value)
871-
== f"""2 validation errors for Test
888+
== f"""2 validation errors for {v.title}
889+
a
890+
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not_an_int', input_type=str]
891+
For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing
892+
b
893+
The default factory uses validated data, but at least one validation error occurred [type=default_factory_not_called]
894+
For further information visit https://errors.pydantic.dev/{pydantic_version}/v/default_factory_not_called"""
895+
)
896+
897+
# repeat with the first field having a default which fails validation
898+
899+
schema = container_schema_builder(
900+
{
901+
'a': core_schema.with_default_schema(
902+
schema=core_schema.int_schema(), default='not_an_int', validate_default=True
903+
),
904+
'b': core_schema.with_default_schema(
905+
schema=core_schema.int_schema(), default_factory=lambda data: data['a'], default_factory_takes_data=True
906+
),
907+
}
908+
)
909+
v = SchemaValidator(schema)
910+
with pytest.raises(ValidationError) as e:
911+
v.validate_python({})
912+
913+
assert e.value.errors(include_url=False) == [
914+
{
915+
'type': 'int_parsing',
916+
'loc': ('a',),
917+
'msg': 'Input should be a valid integer, unable to parse string as an integer',
918+
'input': 'not_an_int',
919+
},
920+
{
921+
'input': PydanticUndefined,
922+
'loc': ('b',),
923+
'msg': 'The default factory uses validated data, but at least one validation error occurred',
924+
'type': 'default_factory_not_called',
925+
},
926+
]
927+
928+
assert (
929+
str(e.value)
930+
== f"""2 validation errors for {v.title}
872931
a
873932
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not_an_int', input_type=str]
874933
For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing
875934
b
876935
The default factory uses validated data, but at least one validation error occurred [type=default_factory_not_called]
877936
For further information visit https://errors.pydantic.dev/{pydantic_version}/v/default_factory_not_called"""
878937
)
938+
939+
940+
def test_default_factory_not_called_union_ok(container_schema_builder) -> None:
941+
schema_fail = container_schema_builder(
942+
{
943+
'a': core_schema.none_schema(),
944+
'b': core_schema.with_default_schema(
945+
schema=core_schema.int_schema(),
946+
default_factory=lambda data: data['a'],
947+
default_factory_takes_data=True,
948+
),
949+
}
950+
)
951+
952+
schema_ok = container_schema_builder(
953+
{
954+
'a': core_schema.int_schema(),
955+
'b': core_schema.with_default_schema(
956+
schema=core_schema.int_schema(),
957+
default_factory=lambda data: data['a'] + 1,
958+
default_factory_takes_data=True,
959+
),
960+
# this is used to show that this union member was selected
961+
'c': core_schema.with_default_schema(schema=core_schema.int_schema(), default=3),
962+
}
963+
)
964+
965+
schema = core_schema.union_schema([schema_fail, schema_ok])
966+
967+
v = SchemaValidator(schema)
968+
s = SchemaSerializer(schema)
969+
assert s.to_python(v.validate_python({'a': 1}), mode='json') == {'a': 1, 'b': 2, 'c': 3}

0 commit comments

Comments
 (0)