Skip to content

Commit 095950b

Browse files
committed
Add model_type to validators
1 parent e7c5dc7 commit 095950b

File tree

2 files changed

+67
-7
lines changed

2 files changed

+67
-7
lines changed

python/pydantic_core/core_schema.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ class FieldSerializationInfo(SerializationInfo, Protocol):
163163
@property
164164
def field_name(self) -> str: ...
165165

166+
@property
167+
def model_type(self) -> type: ...
168+
166169

167170
class ValidationInfo(Protocol):
168171
"""
@@ -197,6 +200,14 @@ def field_name(self) -> str | None:
197200
"""
198201
...
199202

203+
@property
204+
def model_type(self) -> type | None:
205+
"""
206+
The type of the current model being validated if this validator is
207+
attached to a model field.
208+
"""
209+
...
210+
200211

201212
ExpectedSerializationTypes = Literal[
202213
'none',
@@ -1956,6 +1967,7 @@ class WithInfoValidatorFunctionSchema(TypedDict, total=False):
19561967
type: Required[Literal['with-info']]
19571968
function: Required[WithInfoValidatorFunction]
19581969
field_name: str
1970+
model_type: type
19591971

19601972

19611973
ValidationFunction = Union[NoInfoValidatorFunctionSchema, WithInfoValidatorFunctionSchema]
@@ -2025,6 +2037,7 @@ def with_info_before_validator_function(
20252037
schema: CoreSchema,
20262038
*,
20272039
field_name: str | None = None,
2040+
model_type: type | None = None,
20282041
ref: str | None = None,
20292042
json_schema_input_schema: CoreSchema | None = None,
20302043
metadata: Dict[str, Any] | None = None,
@@ -2062,7 +2075,7 @@ def fn(v: bytes, info: core_schema.ValidationInfo) -> str:
20622075
"""
20632076
return _dict_not_none(
20642077
type='function-before',
2065-
function=_dict_not_none(type='with-info', function=function, field_name=field_name),
2078+
function=_dict_not_none(type='with-info', function=function, field_name=field_name, model_type=model_type),
20662079
schema=schema,
20672080
ref=ref,
20682081
json_schema_input_schema=json_schema_input_schema,
@@ -2189,6 +2202,7 @@ class WithInfoWrapValidatorFunctionSchema(TypedDict, total=False):
21892202
type: Required[Literal['with-info']]
21902203
function: Required[WithInfoWrapValidatorFunction]
21912204
field_name: str
2205+
model_type: type
21922206

21932207

21942208
WrapValidatorFunction = Union[NoInfoWrapValidatorFunctionSchema, WithInfoWrapValidatorFunctionSchema]

src/validators/function.rs

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct FunctionInfo {
2424
/// The actual function object that will get called
2525
pub function: Py<PyAny>,
2626
pub field_name: Option<Py<PyString>>,
27+
pub model_type: Option<Py<PyAny>>,
2728
pub info_arg: bool,
2829
}
2930

@@ -37,9 +38,11 @@ fn destructure_function_schema(schema: &Bound<'_, PyDict>) -> PyResult<FunctionI
3738
_ => unreachable!(),
3839
};
3940
let field_name = func_dict.get_as(intern!(schema.py(), "field_name"))?;
41+
let model_type = func_dict.get_as(intern!(schema.py(), "model_type"))?;
4042
Ok(FunctionInfo {
4143
function,
4244
field_name,
45+
model_type,
4346
info_arg,
4447
})
4548
}
@@ -71,6 +74,7 @@ macro_rules! impl_build {
7174
},
7275
name,
7376
field_name: func_info.field_name,
77+
model_type: func_info.model_type,
7478
info_arg: func_info.info_arg,
7579
}
7680
.into())
@@ -86,6 +90,7 @@ pub struct FunctionBeforeValidator {
8690
config: PyObject,
8791
name: String,
8892
field_name: Option<Py<PyString>>,
93+
model_type: Option<Py<PyAny>>,
8994
info_arg: bool,
9095
}
9196

@@ -100,7 +105,13 @@ impl FunctionBeforeValidator {
100105
state: &'s mut ValidationState<'_, 'py>,
101106
) -> ValResult<PyObject> {
102107
let r = if self.info_arg {
103-
let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone());
108+
let info = ValidationInfo::new(
109+
py,
110+
state.extra(),
111+
&self.config,
112+
self.field_name.clone(),
113+
self.model_type.clone(),
114+
);
104115
self.func.call1(py, (input.to_object(py), info))
105116
} else {
106117
self.func.call1(py, (input.to_object(py),))
@@ -154,6 +165,7 @@ pub struct FunctionAfterValidator {
154165
config: PyObject,
155166
name: String,
156167
field_name: Option<Py<PyString>>,
168+
model_type: Option<Py<PyAny>>,
157169
info_arg: bool,
158170
}
159171

@@ -169,7 +181,13 @@ impl FunctionAfterValidator {
169181
) -> ValResult<PyObject> {
170182
let v = call(input, state)?;
171183
let r = if self.info_arg {
172-
let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone());
184+
let info = ValidationInfo::new(
185+
py,
186+
state.extra(),
187+
&self.config,
188+
self.field_name.clone(),
189+
self.model_type.clone(),
190+
);
173191
self.func.call1(py, (v.to_object(py), info))
174192
} else {
175193
self.func.call1(py, (v.to_object(py),))
@@ -221,6 +239,7 @@ pub struct FunctionPlainValidator {
221239
config: PyObject,
222240
name: String,
223241
field_name: Option<Py<PyString>>,
242+
model_type: Option<Py<PyAny>>,
224243
info_arg: bool,
225244
}
226245

@@ -242,6 +261,7 @@ impl BuildValidator for FunctionPlainValidator {
242261
},
243262
name: format!("function-plain[{}()]", function_name(function_info.function.bind(py))?),
244263
field_name: function_info.field_name.clone(),
264+
model_type: function_info.model_type,
245265
info_arg: function_info.info_arg,
246266
}
247267
.into())
@@ -258,7 +278,13 @@ impl Validator for FunctionPlainValidator {
258278
state: &mut ValidationState<'_, 'py>,
259279
) -> ValResult<PyObject> {
260280
let r = if self.info_arg {
261-
let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone());
281+
let info = ValidationInfo::new(
282+
py,
283+
state.extra(),
284+
&self.config,
285+
self.field_name.clone(),
286+
self.model_type.clone(),
287+
);
262288
self.func.call1(py, (input.to_object(py), info))
263289
} else {
264290
self.func.call1(py, (input.to_object(py),))
@@ -278,6 +304,7 @@ pub struct FunctionWrapValidator {
278304
config: PyObject,
279305
name: String,
280306
field_name: Option<Py<PyString>>,
307+
model_type: Option<Py<PyAny>>,
281308
info_arg: bool,
282309
hide_input_in_errors: bool,
283310
validation_error_cause: bool,
@@ -305,6 +332,7 @@ impl BuildValidator for FunctionWrapValidator {
305332
},
306333
name: format!("function-wrap[{}()]", function_name(function_info.function.bind(py))?),
307334
field_name: function_info.field_name.clone(),
335+
model_type: function_info.model_type,
308336
info_arg: function_info.info_arg,
309337
hide_input_in_errors,
310338
validation_error_cause,
@@ -322,7 +350,13 @@ impl FunctionWrapValidator {
322350
state: &mut ValidationState<'_, 'py>,
323351
) -> ValResult<PyObject> {
324352
let r = if self.info_arg {
325-
let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone());
353+
let info = ValidationInfo::new(
354+
py,
355+
state.extra(),
356+
&self.config,
357+
self.field_name.clone(),
358+
self.model_type.clone(),
359+
);
326360
self.func.call1(py, (input.to_object(py), handler, info))
327361
} else {
328362
self.func.call1(py, (input.to_object(py), handler))
@@ -505,15 +539,23 @@ pub struct ValidationInfo {
505539
context: Option<PyObject>,
506540
data: Option<Py<PyDict>>,
507541
field_name: Option<Py<PyString>>,
542+
model_type: Option<Py<PyAny>>,
508543
mode: InputType,
509544
}
510545

511546
impl ValidationInfo {
512-
fn new(py: Python, extra: &Extra, config: &PyObject, field_name: Option<Py<PyString>>) -> Self {
547+
fn new(
548+
py: Python,
549+
extra: &Extra,
550+
config: &PyObject,
551+
field_name: Option<Py<PyString>>,
552+
model_type: Option<Py<PyAny>>,
553+
) -> Self {
513554
Self {
514555
config: config.clone_ref(py),
515556
context: extra.context.map(|ctx| ctx.clone().into()),
516557
field_name,
558+
model_type: model_type.map(|t| t.clone().into()),
517559
data: extra.data.as_ref().map(|data| data.clone().into()),
518560
mode: extra.input_type,
519561
}
@@ -548,8 +590,12 @@ impl ValidationInfo {
548590
Some(ref field_name) => safe_repr(field_name.bind(py)).to_string(),
549591
None => "None".into(),
550592
};
593+
let model_type = match self.model_type {
594+
Some(ref model_type) => safe_repr(model_type.bind(py)).to_string(),
595+
None => "None".into(),
596+
};
551597
Ok(format!(
552-
"ValidationInfo(config={config}, context={context}, data={data}, field_name={field_name})"
598+
"ValidationInfo(config={config}, context={context}, data={data}, field_name={field_name}, model_type={model_type})"
553599
))
554600
}
555601
}

0 commit comments

Comments
 (0)