Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3372,11 +3372,15 @@ def arguments_parameter(
return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias)


VarKwargsMode: TypeAlias = Literal['uniform', 'unpacked-typed-dict']


class ArgumentsSchema(TypedDict, total=False):
type: Required[Literal['arguments']]
arguments_schema: Required[List[ArgumentsParameter]]
populate_by_name: bool
var_args_schema: CoreSchema
var_kwargs_mode: VarKwargsMode
var_kwargs_schema: CoreSchema
ref: str
metadata: Dict[str, Any]
Expand All @@ -3388,6 +3392,7 @@ def arguments_schema(
*,
populate_by_name: bool | None = None,
var_args_schema: CoreSchema | None = None,
var_kwargs_mode: VarKwargsMode | None = None,
var_kwargs_schema: CoreSchema | None = None,
ref: str | None = None,
metadata: Dict[str, Any] | None = None,
Expand All @@ -3414,6 +3419,9 @@ def arguments_schema(
arguments: The arguments to use for the arguments schema
populate_by_name: Whether to populate by name
var_args_schema: The variable args schema to use for the arguments schema
var_kwargs_mode: The validation mode to use for variadic keyword arguments. If `'uniform'`, every value of the
keyword arguments will be validated against the `var_kwargs_schema` schema. If `'unpacked-typed-dict'`,
the `schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema]
var_kwargs_schema: The variable kwargs schema to use for the arguments schema
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
Expand All @@ -3424,6 +3432,7 @@ def arguments_schema(
arguments_schema=arguments,
populate_by_name=populate_by_name,
var_args_schema=var_args_schema,
var_kwargs_mode=var_kwargs_mode,
var_kwargs_schema=var_kwargs_schema,
ref=ref,
metadata=metadata,
Expand Down
107 changes: 86 additions & 21 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::str::FromStr;

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
Expand All @@ -15,6 +17,27 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, PartialEq)]
enum VarKwargsMode {
Uniform,
UnpackedTypedDict,
}

impl FromStr for VarKwargsMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"uniform" => Ok(Self::Uniform),
"unpacked-typed-dict" => Ok(Self::UnpackedTypedDict),
s => py_schema_err!(
"Invalid var_kwargs mode: `{}`, expected `uniform` or `unpacked-typed-dict`",
s
),
}
}
}

#[derive(Debug)]
struct Parameter {
positional: bool,
Expand All @@ -29,6 +52,7 @@ pub struct ArgumentsValidator {
parameters: Vec<Parameter>,
positional_params_count: usize,
var_args_validator: Option<Box<CombinedValidator>>,
var_kwargs_mode: VarKwargsMode,
var_kwargs_validator: Option<Box<CombinedValidator>>,
loc_by_alias: bool,
extra: ExtraBehavior,
Expand Down Expand Up @@ -117,17 +141,31 @@ impl BuildValidator for ArgumentsValidator {
});
}

let py_var_kwargs_mode: Bound<PyString> = schema
.get_as(intern!(py, "var_kwargs_mode"))?
.unwrap_or_else(|| PyString::new_bound(py, "single"));

let var_kwargs_mode = VarKwargsMode::from_str(py_var_kwargs_mode.to_str()?)?;
let var_kwargs_validator = match schema.get_item(intern!(py, "var_kwargs_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
};

if var_kwargs_mode == VarKwargsMode::UnpackedTypedDict && var_kwargs_validator.is_none() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should also check that var_kwargs_validator is a TypedDictValidator?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to, given that you have the conditional checks in pydantic.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making it a TypedDictValdiator statically (at build time) would allow for efficiency gains (by avoiding the dispatch via CombinedValidator).

But I wonder, are there cases where it can be wrapped in a function-after validator (e.g. model_validator)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making it a TypedDictValdiator statically (at build time) would allow for efficiency gains (by avoiding the dispatch via CombinedValidator).

The thing is (as described here) var_kwargs_validator is used for both uniform and unpacked-typed-dict modes.

But I wonder, are there cases where it can be wrapped in a function-after validator (e.g. model_validator)?

only config can be attached to typed dicts iirc, and it still results in a typed dict schema.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could have

#[derive(Debug, PartialEq)]
enum VarKwargsMode {
    Uniform(CombinedValidator),
    UnpackedTypedDict(TypedDictValidator),
}

... and replace the FromStr implementation with a bespoke function like from_string_and_validator or similar.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will leave this as an optimization for later, going to approve for now 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #1457

return py_schema_err!(
"`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
);
}

Ok(Self {
parameters,
positional_params_count,
var_args_validator: match schema.get_item(intern!(py, "var_args_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
},
var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
},
var_kwargs_mode,
var_kwargs_validator,
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
}
Expand Down Expand Up @@ -258,6 +296,8 @@ impl Validator for ArgumentsValidator {
// if there are kwargs check any that haven't been processed yet
if let Some(kwargs) = args.kwargs() {
if kwargs.len() > used_kwargs.len() {
let remaining_kwargs = PyDict::new_bound(py);

for result in kwargs.iter() {
let (raw_key, value) = result?;
let either_str = match raw_key
Expand All @@ -278,30 +318,55 @@ impl Validator for ArgumentsValidator {
Err(err) => return Err(err),
};
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
match self.var_kwargs_mode {
VarKwargsMode::Uniform => match &self.var_kwargs_validator {
Some(validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_kwargs
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
}
}
Err(err) => return Err(err),
},
None => {
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.clone(),
));
}
}
Err(err) => return Err(err),
},
None => {
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.clone(),
));
}
VarKwargsMode::UnpackedTypedDict => {
// Save to the remaining kwargs, we will validate as a single dict:
remaining_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
}
}
}

if self.var_kwargs_mode == VarKwargsMode::UnpackedTypedDict {
// `var_kwargs_validator` is guaranteed to be `Some`:
match self
.var_kwargs_validator
.as_ref()
.unwrap()
.validate(py, remaining_kwargs.as_any(), state)
{
Ok(value) => {
output_kwargs.update(value.downcast_bound::<PyDict>(py).unwrap().as_mapping())?;
}
Err(ValError::LineErrors(line_errors)) => {
errors.extend(line_errors);
}
Err(err) => return Err(err),
}
}
}
}

Expand Down
57 changes: 56 additions & 1 deletion tests/validators/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,19 @@ def test_build_non_default_follows():
)


def test_build_missing_var_kwargs():
with pytest.raises(
SchemaError, match="`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
):
SchemaValidator(
{
'type': 'arguments',
'arguments_schema': [],
'var_kwargs_mode': 'unpacked-typed-dict',
}
)


@pytest.mark.parametrize(
'input_value,expected',
[
Expand All @@ -778,7 +791,7 @@ def test_build_non_default_follows():
],
ids=repr,
)
def test_kwargs(py_and_json: PyAndJson, input_value, expected):
def test_kwargs_uniform(py_and_json: PyAndJson, input_value, expected):
v = py_and_json(
{
'type': 'arguments',
Expand All @@ -796,6 +809,48 @@ def test_kwargs(py_and_json: PyAndJson, input_value, expected):
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize(
'input_value,expected',
[
[ArgsKwargs((), {'x': 1}), ((), {'x': 1})],
[ArgsKwargs((), {'x': 1.0}), Err('x\n Input should be a valid integer [type=int_type,')],
[ArgsKwargs((), {'x': 1, 'z': 'str'}), ((), {'x': 1, 'y': 'str'})],
[ArgsKwargs((), {'x': 1, 'y': 'str'}), Err('y\n Extra inputs are not permitted [type=extra_forbidden,')],
],
)
def test_kwargs_typed_dict(py_and_json: PyAndJson, input_value, expected):
v = py_and_json(
{
'type': 'arguments',
'arguments_schema': [],
'var_kwargs_mode': 'unpacked-typed-dict',
'var_kwargs_schema': {
'type': 'typed-dict',
'fields': {
'x': {
'type': 'typed-dict-field',
'schema': {'type': 'int', 'strict': True},
'required': True,
},
'y': {
'type': 'typed-dict-field',
'schema': {'type': 'str'},
'required': False,
'validation_alias': 'z',
},
},
'config': {'extra_fields_behavior': 'forbid'},
},
}
)

if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_test(input_value)
else:
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize(
'input_value,expected',
[
Expand Down
Loading