Skip to content

Commit c0f01df

Browse files
committed
Pass field_name to validators through ValidationState so functions are called with correct ValidationInfo.field_name
1 parent 9a25aa6 commit c0f01df

File tree

7 files changed

+146
-15
lines changed

7 files changed

+146
-15
lines changed

src/validators/dataclass.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuild
2424
struct Field {
2525
kw_only: bool,
2626
name: String,
27-
py_name: Py<PyString>,
27+
name_py: Py<PyString>,
2828
init: bool,
2929
init_only: bool,
3030
lookup_key_collection: LookupKeyCollection,
@@ -72,8 +72,8 @@ impl BuildValidator for DataclassArgsValidator {
7272
for field in fields_schema {
7373
let field = field.downcast::<PyDict>()?;
7474

75-
let py_name: Bound<'_, PyString> = field.get_as_req(intern!(py, "name"))?;
76-
let name: String = py_name.extract()?;
75+
let name_py: Bound<'_, PyString> = field.get_as_req(intern!(py, "name"))?;
76+
let name: String = name_py.extract()?;
7777

7878
let schema = field.get_as_req(intern!(py, "schema"))?;
7979

@@ -99,7 +99,7 @@ impl BuildValidator for DataclassArgsValidator {
9999
fields.push(Field {
100100
kw_only,
101101
name,
102-
py_name: py_name.into(),
102+
name_py: name_py.into(),
103103
lookup_key_collection,
104104
validator,
105105
init: field.get_as(intern!(py, "init"))?.unwrap_or(true),
@@ -163,13 +163,13 @@ impl Validator for DataclassArgsValidator {
163163

164164
macro_rules! set_item {
165165
($field:ident, $value:expr) => {{
166-
let py_name = $field.py_name.bind(py);
166+
let name_py = $field.name_py.bind(py);
167167
if $field.init_only {
168168
if let Some(ref mut init_only_args) = init_only_args {
169169
init_only_args.push($value);
170170
}
171171
} else {
172-
output_dict.set_item(py_name, $value)?;
172+
output_dict.set_item(name_py, $value)?;
173173
}
174174
}};
175175
}
@@ -214,6 +214,8 @@ impl Validator for DataclassArgsValidator {
214214
}
215215
let kw_value = kw_value.as_ref().map(|(path, value)| (path, value.borrow_input()));
216216

217+
let state = &mut state.rebind_extra(|extra| extra.field_name = Some(field.name_py.bind(py).clone()));
218+
217219
match (pos_value, kw_value) {
218220
// found both positional and keyword arguments, error
219221
(Some(_), Some((_, kw_value))) => {
@@ -404,11 +406,12 @@ impl Validator for DataclassArgsValidator {
404406
}
405407
}
406408

407-
match field.validator.validate(
408-
py,
409-
field_value,
410-
&mut state.rebind_extra(|extra| extra.data = Some(data_dict.clone())),
411-
) {
409+
let state = &mut state.rebind_extra(|extra| {
410+
extra.data = Some(data_dict.clone());
411+
extra.field_name = Some(field.name_py.bind(py).clone());
412+
});
413+
414+
match field.validator.validate(py, field_value, state) {
412415
Ok(output) => ok(output),
413416
Err(ValError::LineErrors(line_errors)) => {
414417
let errors = line_errors

src/validators/function.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,13 @@ impl FunctionBeforeValidator {
100100
state: &'s mut ValidationState<'_, 'py>,
101101
) -> ValResult<PyObject> {
102102
let r = if self.info_arg {
103-
let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone());
103+
let field_name = state
104+
.extra()
105+
.field_name
106+
.clone()
107+
.map(Bound::unbind)
108+
.or(self.field_name.clone());
109+
let info = ValidationInfo::new(py, state.extra(), &self.config, field_name);
104110
self.func.call1(py, (input.to_object(py)?, info))
105111
} else {
106112
self.func.call1(py, (input.to_object(py)?,))
@@ -169,7 +175,13 @@ impl FunctionAfterValidator {
169175
) -> ValResult<PyObject> {
170176
let v = call(input, state)?;
171177
let r = if self.info_arg {
172-
let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone());
178+
let field_name = state
179+
.extra()
180+
.field_name
181+
.clone()
182+
.map(Bound::unbind)
183+
.or(self.field_name.clone());
184+
let info = ValidationInfo::new(py, state.extra(), &self.config, field_name);
173185
self.func.call1(py, (v, info))
174186
} else {
175187
self.func.call1(py, (v,))
@@ -258,7 +270,13 @@ impl Validator for FunctionPlainValidator {
258270
state: &mut ValidationState<'_, 'py>,
259271
) -> ValResult<PyObject> {
260272
let r = if self.info_arg {
261-
let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone());
273+
let field_name = state
274+
.extra()
275+
.field_name
276+
.clone()
277+
.map(Bound::unbind)
278+
.or(self.field_name.clone());
279+
let info = ValidationInfo::new(py, state.extra(), &self.config, field_name);
262280
self.func.call1(py, (input.to_object(py)?, info))
263281
} else {
264282
self.func.call1(py, (input.to_object(py)?,))
@@ -322,7 +340,13 @@ impl FunctionWrapValidator {
322340
state: &mut ValidationState<'_, 'py>,
323341
) -> ValResult<PyObject> {
324342
let r = if self.info_arg {
325-
let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone());
343+
let field_name = state
344+
.extra()
345+
.field_name
346+
.clone()
347+
.map(Bound::unbind)
348+
.or(self.field_name.clone());
349+
let info = ValidationInfo::new(py, state.extra(), &self.config, field_name);
326350
self.func.call1(py, (input.to_object(py)?, handler, info))
327351
} else {
328352
self.func.call1(py, (input.to_object(py)?, handler))

src/validators/generator.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ impl InternalValidator {
276276
data: self.data.as_ref().map(|data| data.bind(py).clone()),
277277
strict: self.strict,
278278
from_attributes: self.from_attributes,
279+
field_name: Some(PyString::new(py, field_name)),
279280
context: self.context.as_ref().map(|data| data.bind(py)),
280281
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
281282
cache_str: self.cache_str,
@@ -313,6 +314,7 @@ impl InternalValidator {
313314
data: self.data.as_ref().map(|data| data.bind(py).clone()),
314315
strict: self.strict,
315316
from_attributes: self.from_attributes,
317+
field_name: None,
316318
context: self.context.as_ref().map(|data| data.bind(py)),
317319
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
318320
cache_str: self.cache_str,

src/validators/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ impl SchemaValidator {
311311
data: None,
312312
strict,
313313
from_attributes,
314+
field_name: Some(PyString::new(py, field_name)),
314315
context,
315316
self_instance: None,
316317
cache_str: self.cache_str,
@@ -337,6 +338,7 @@ impl SchemaValidator {
337338
data: None,
338339
strict,
339340
from_attributes: None,
341+
field_name: None,
340342
context,
341343
self_instance: None,
342344
cache_str: self.cache_str,
@@ -678,6 +680,8 @@ pub struct Extra<'a, 'py> {
678680
pub from_attributes: Option<bool>,
679681
/// context used in validator functions
680682
pub context: Option<&'a Bound<'py, PyAny>>,
683+
/// The name of the field being validated, if applicable
684+
pub field_name: Option<Bound<'py, PyString>>,
681685
/// This is an instance of the model or dataclass being validated, when validation is performed from `__init__`
682686
self_instance: Option<&'a Bound<'py, PyAny>>,
683687
/// Whether to use a cache of short strings to accelerate python string construction
@@ -705,6 +709,7 @@ impl<'a, 'py> Extra<'a, 'py> {
705709
data: None,
706710
strict,
707711
from_attributes,
712+
field_name: None,
708713
context,
709714
self_instance,
710715
cache_str,
@@ -721,6 +726,7 @@ impl Extra<'_, '_> {
721726
data: self.data.clone(),
722727
strict: Some(true),
723728
from_attributes: self.from_attributes,
729+
field_name: self.field_name.clone(),
724730
context: self.context,
725731
self_instance: self.self_instance,
726732
cache_str: self.cache_str,

src/validators/model_fields.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ impl Validator for ModelFieldsValidator {
197197
// extra logic either way
198198
used_keys.insert(lookup_path.first_key());
199199
}
200+
201+
let state =
202+
&mut state.rebind_extra(|extra| extra.field_name = Some(field.name_py.bind(py).clone()));
203+
200204
match field.validator.validate(py, value.borrow_input(), state) {
201205
Ok(value) => {
202206
model_dict.set_item(&field.name_py, value)?;
@@ -422,6 +426,8 @@ impl Validator for ModelFieldsValidator {
422426
));
423427
}
424428

429+
let state = &mut state.rebind_extra(|extra| extra.field_name = Some(field.name_py.bind(py).clone()));
430+
425431
prepare_result(field.validator.validate(py, field_value, state))?
426432
} else {
427433
// Handle extra (unknown) field

src/validators/typed_dict.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ impl Validator for TypedDictValidator {
218218
true => allow_partial,
219219
false => false.into(),
220220
};
221+
let state =
222+
&mut state.rebind_extra(|extra| extra.field_name = Some(field.name_py.bind(py).clone()));
223+
221224
match field.validator.validate(py, value.borrow_input(), state) {
222225
Ok(value) => {
223226
output_dict.set_item(&field.name_py, value)?;

tests/validators/test_function.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import platform
33
import re
44
from copy import deepcopy
5+
from dataclasses import dataclass
56
from typing import Any
67

78
import pytest
@@ -662,6 +663,36 @@ def f(input_value: Any, info: core_schema.ValidationInfo) -> Any:
662663
assert v.validate_python({'x': b'foo'}).x == 'input: foo'
663664

664665

666+
def test_model_field_validator_reuse() -> None:
667+
class Model:
668+
x: str
669+
y: str
670+
671+
def f(input_value: Any, info: core_schema.ValidationInfo) -> Any:
672+
return f'{info.field_name}: {input_value}'
673+
674+
# When a type alias with a validator function is used on multiple fields,
675+
# its core schema is only generated once (with the first field_name) and reused.
676+
# See https://github.com/pydantic/pydantic/issues/11737
677+
validator = core_schema.with_info_plain_validator_function(f, field_name='x')
678+
679+
v = SchemaValidator(
680+
core_schema.model_schema(
681+
Model,
682+
core_schema.model_fields_schema(
683+
{
684+
'x': core_schema.model_field(validator),
685+
'y': core_schema.model_field(validator),
686+
}
687+
),
688+
)
689+
)
690+
691+
m = v.validate_python({'x': 'foo', 'y': 'bar'})
692+
assert m.x == 'x: foo'
693+
assert m.y == 'y: bar'
694+
695+
665696
def test_model_field_wrap_validator() -> None:
666697
class Model:
667698
x: str
@@ -821,6 +852,62 @@ def f(input_value: Any, info: core_schema.ValidationInfo) -> Any:
821852
assert info_stuff == {'field_name': 'c', 'data': {'a': 1}}
822853

823854

855+
def test_typed_dict_validator_reuse() -> None:
856+
def f(input_value: Any, info: core_schema.ValidationInfo) -> Any:
857+
return f'{info.field_name}: {input_value}'
858+
859+
# When a type alias with a validator function is used on multiple fields,
860+
# its core schema is only generated once (with the first field_name) and reused.
861+
# See https://github.com/pydantic/pydantic/issues/11737
862+
validator = core_schema.with_info_plain_validator_function(f, field_name='x')
863+
864+
v = SchemaValidator(
865+
core_schema.typed_dict_schema(
866+
{
867+
'x': core_schema.model_field(validator),
868+
'y': core_schema.model_field(validator),
869+
}
870+
)
871+
)
872+
873+
data = v.validate_python({'x': 'foo', 'y': 'bar'})
874+
assert data['x'] == 'x: foo'
875+
assert data['y'] == 'y: bar'
876+
877+
878+
def test_dataclass_validator_reuse() -> None:
879+
@dataclass
880+
class Model:
881+
x: str
882+
y: str
883+
884+
def f(input_value: Any, info: core_schema.ValidationInfo) -> Any:
885+
return f'{info.field_name}: {input_value}'
886+
887+
# When a type alias with a validator function is used on multiple fields,
888+
# its core schema is only generated once (with the first field_name) and reused.
889+
# See https://github.com/pydantic/pydantic/issues/11737
890+
validator = core_schema.with_info_plain_validator_function(f, field_name='x')
891+
892+
v = SchemaValidator(
893+
core_schema.dataclass_schema(
894+
Model,
895+
core_schema.dataclass_args_schema(
896+
'Model',
897+
[
898+
core_schema.dataclass_field(name='x', schema=validator),
899+
core_schema.dataclass_field(name='y', schema=validator),
900+
],
901+
),
902+
['x', 'y'],
903+
)
904+
)
905+
906+
m = v.validate_python({'x': 'foo', 'y': 'bar'})
907+
assert m.x == 'x: foo'
908+
assert m.y == 'y: bar'
909+
910+
824911
@pytest.mark.parametrize(
825912
'mode,calls1,calls2',
826913
[

0 commit comments

Comments
 (0)