Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/validators/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ impl Validator for FunctionWrapValidator {
let handler = Bound::new(py, handler)?;
#[allow(clippy::used_underscore_items)]
let result = self._validate(handler.as_any(), py, input, state);
state.exactness = handler.borrow_mut().validator.exactness;
let handler = handler.borrow();
state.exactness = handler.validator.exactness;
state.fields_set_count = handler.validator.fields_set_count;
result
}

Expand Down
10 changes: 8 additions & 2 deletions src/validators/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ pub struct InternalValidator {
self_instance: Option<PyObject>,
recursion_guard: RecursionState,
pub(crate) exactness: Option<Exactness>,
pub(crate) fields_set_count: Option<usize>,
validation_mode: InputType,
hide_input_in_errors: bool,
validation_error_cause: bool,
Expand Down Expand Up @@ -256,6 +257,7 @@ impl InternalValidator {
self_instance: extra.self_instance.map(|d| d.clone().unbind()),
recursion_guard: state.recursion_guard.clone(),
exactness: state.exactness,
fields_set_count: state.fields_set_count,
validation_mode: extra.input_type,
hide_input_in_errors,
validation_error_cause,
Expand Down Expand Up @@ -284,7 +286,8 @@ impl InternalValidator {
by_name: None,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard, false.into());
state.exactness = self.exactness;
// state.exactness = self.exactness;
// state.fields_set_count = self.fields_set_count;
let result = self
.validator
.validate_assignment(py, model, field_name, field_value, &mut state)
Expand All @@ -299,7 +302,8 @@ impl InternalValidator {
self.validation_error_cause,
)
});
self.exactness = state.exactness;
// self.exactness = state.exactness;
// self.fields_set_count = state.fields_set_count;
result
}

Expand All @@ -323,6 +327,7 @@ impl InternalValidator {
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard, false.into());
state.exactness = self.exactness;
state.fields_set_count = self.fields_set_count;
let result = self.validator.validate(py, input, &mut state).map_err(|e| {
ValidationError::from_val_error(
py,
Expand All @@ -335,6 +340,7 @@ impl InternalValidator {
)
});
self.exactness = state.exactness;
self.fields_set_count = state.fields_set_count;
result
}
}
Expand Down
80 changes: 80 additions & 0 deletions tests/validators/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,3 +1358,83 @@ class Model:

assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo)
assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar)


def test_smart_union_wrap_validator_should_not_change_nested_model_field_counts() -> None:
"""Adding a wrap validator on a union member should not affect smart union behavior"""

class SubModel:
x: str = 'x'

class ModelA:
type: str = 'A'
sub: SubModel

class ModelB:
type: str = 'B'
sub: SubModel

submodel_schema = core_schema.model_schema(
SubModel,
core_schema.model_fields_schema(fields={'x': core_schema.model_field(core_schema.str_schema())}),
)

wrapped_submodel_schema = core_schema.no_info_wrap_validator_function(
lambda v, handler: handler(v), submodel_schema
)

model_a_schema = core_schema.model_schema(
ModelA,
core_schema.model_fields_schema(
fields={
'type': core_schema.model_field(
core_schema.with_default_schema(core_schema.literal_schema(['A']), default='A'),
),
'sub': core_schema.model_field(wrapped_submodel_schema),
},
),
)

model_b_schema = core_schema.model_schema(
ModelB,
core_schema.model_fields_schema(
fields={
'type': core_schema.model_field(
core_schema.with_default_schema(core_schema.literal_schema(['B']), default='B'),
),
'sub': core_schema.model_field(submodel_schema),
},
),
)

for choices in permute_choices([model_a_schema, model_b_schema]):
schema = core_schema.union_schema(choices)
validator = SchemaValidator(schema)

assert isinstance(validator.validate_python({'type': 'A', 'sub': {'x': 'x'}}), ModelA)
assert isinstance(validator.validate_python({'type': 'B', 'sub': {'x': 'x'}}), ModelB)

# defaults to leftmost choice if there's a tie
assert isinstance(validator.validate_python({'sub': {'x': 'x'}}), choices[0]['cls'])

# test validate_assignment
class RootModel:
ab: Union[ModelA, ModelB]

root_model = core_schema.model_schema(
RootModel,
core_schema.model_fields_schema(
fields={'ab': core_schema.model_field(core_schema.union_schema([model_a_schema, model_b_schema]))}
),
)

validator = SchemaValidator(root_model)
m = validator.validate_python({'ab': {'type': 'B', 'sub': {'x': 'x'}}})
assert isinstance(m, RootModel)
assert isinstance(m.ab, ModelB)
assert m.ab.sub.x == 'x'

m = validator.validate_assignment(m, 'ab', {'sub': {'x': 'y'}})
assert isinstance(m, RootModel)
assert isinstance(m.ab, ModelA)
assert m.ab.sub.x == 'y'
Loading