diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 3fc307ee3..0ab3dd947 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3013,6 +3013,7 @@ class ModelFieldsSchema(TypedDict, total=False): computed_fields: list[ComputedField] strict: bool extras_schema: CoreSchema + extras_keys_schema: CoreSchema extra_behavior: ExtraBehavior from_attributes: bool ref: str @@ -3027,6 +3028,7 @@ def model_fields_schema( computed_fields: list[ComputedField] | None = None, strict: bool | None = None, extras_schema: CoreSchema | None = None, + extras_keys_schema: CoreSchema | None = None, extra_behavior: ExtraBehavior | None = None, from_attributes: bool | None = None, ref: str | None = None, @@ -3034,7 +3036,7 @@ def model_fields_schema( serialization: SerSchema | None = None, ) -> ModelFieldsSchema: """ - Returns a schema that matches a typed dict, e.g.: + Returns a schema that matches the fields of a Pydantic model, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -3048,15 +3050,16 @@ def model_fields_schema( ``` Args: - fields: The fields to use for the typed dict + fields: The fields of the model model_name: The name of the model, used for error messages, defaults to "Model" computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model - strict: Whether the typed dict is strict - extras_schema: The extra validator to use for the typed dict + strict: Whether the model is strict + extras_schema: The schema to use when validating extra input data + extras_keys_schema: The schema to use when validating the keys of extra input data ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core - extra_behavior: The extra behavior to use for the typed dict - from_attributes: Whether the typed dict should be populated from attributes + extra_behavior: The extra behavior to use for the model fields + from_attributes: Whether the model fields should be populated from attributes serialization: Custom serialization schema """ return _dict_not_none( @@ -3066,6 +3069,7 @@ def model_fields_schema( computed_fields=computed_fields, strict=strict, extras_schema=extras_schema, + extras_keys_schema=extras_keys_schema, extra_behavior=extra_behavior, from_attributes=from_attributes, ref=ref, diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 1ef684890..ba1f7d1ba 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -34,6 +34,7 @@ pub struct ModelFieldsValidator { model_name: String, extra_behavior: ExtraBehavior, extras_validator: Option>, + extras_keys_validator: Option>, strict: bool, from_attributes: bool, loc_by_alias: bool, @@ -62,6 +63,11 @@ impl BuildValidator for ModelFieldsValidator { (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, }; + let extras_keys_validator = match (schema.get_item(intern!(py, "extras_keys_schema"))?, &extra_behavior) { + (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(&v, config, definitions)?)), + (Some(_), _) => return py_schema_err!("extras_keys_schema can only be used if extra_behavior=allow"), + (_, _) => None, + }; let model_name: String = schema .get_as(intern!(py, "model_name"))? .unwrap_or_else(|| "Model".to_string()); @@ -98,6 +104,7 @@ impl BuildValidator for ModelFieldsValidator { model_name, extra_behavior, extras_validator, + extras_keys_validator, strict, from_attributes, loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), @@ -244,6 +251,7 @@ impl Validator for ModelFieldsValidator { fields_set_vec: &'a mut Vec>, extra_behavior: ExtraBehavior, extras_validator: Option<&'a CombinedValidator>, + extras_keys_validator: Option<&'a CombinedValidator>, state: &'a mut ValidationState<'s, 'py>, } @@ -294,7 +302,22 @@ impl Validator for ModelFieldsValidator { } ExtraBehavior::Ignore => {} ExtraBehavior::Allow => { - let py_key = either_str.as_py_string(self.py, self.state.cache_str()); + let py_key = match self.extras_keys_validator { + Some(validator) => { + match validator.validate(self.py, raw_key.borrow_input(), self.state) { + Ok(value) => value.downcast_bound::(self.py)?.clone(), + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + self.errors.push(err.with_outer_location(raw_key.clone())); + } + continue; + } + Err(err) => return Err(err), + } + } + None => either_str.as_py_string(self.py, self.state.cache_str()), + }; + if let Some(validator) = self.extras_validator { match validator.validate(self.py, value, self.state) { Ok(value) => { @@ -326,6 +349,7 @@ impl Validator for ModelFieldsValidator { fields_set_vec: &mut fields_set_vec, extra_behavior: self.extra_behavior, extras_validator: self.extras_validator.as_deref(), + extras_keys_validator: self.extras_keys_validator.as_deref(), state, })??; diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index 22b8620a8..cc04f07c1 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -213,6 +213,13 @@ def test_allow_extra_invalid(): ) ) + with pytest.raises(SchemaError, match='extras_keys_schema can only be used if extra_behavior=allow'): + SchemaValidator( + schema=core_schema.model_fields_schema( + fields={}, extras_keys_schema=core_schema.int_schema(), extra_behavior='ignore' + ) + ) + def test_allow_extra_wrong(): with pytest.raises(SchemaError, match='Invalid extra_behavior: `wrong`'): @@ -1758,6 +1765,24 @@ def test_extra_behavior_ignore(config: Union[core_schema.CoreConfig, None], sche assert 'not_f' not in m +def test_extra_behavior_allow_keys_validation() -> None: + v = SchemaValidator( + core_schema.model_fields_schema( + {}, extra_behavior='allow', extras_keys_schema=core_schema.str_schema(max_length=3) + ) + ) + + m, model_extra, fields_set = v.validate_python({'ext': 123}) + assert m == {} + assert model_extra == {'ext': 123} + assert fields_set == {'ext'} + + with pytest.raises(ValidationError) as exc_info: + v.validate_python({'extra_too_long': 123}) + + assert exc_info.value.errors()[0]['type'] == 'string_too_long' + + @pytest.mark.parametrize('config_by_alias', [None, True, False]) @pytest.mark.parametrize('config_by_name', [None, True, False]) @pytest.mark.parametrize('runtime_by_alias', [None, True, False])