Skip to content

Commit 0bb8a07

Browse files
committed
fix: submodel fields with wrap validator affect smart union selection
1 parent 3414703 commit 0bb8a07

File tree

3 files changed

+91
-3
lines changed

3 files changed

+91
-3
lines changed

src/validators/function.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,9 @@ impl Validator for FunctionWrapValidator {
380380
let handler = Bound::new(py, handler)?;
381381
#[allow(clippy::used_underscore_items)]
382382
let result = self._validate(handler.as_any(), py, input, state);
383-
state.exactness = handler.borrow_mut().validator.exactness;
383+
let handler = handler.borrow();
384+
state.exactness = handler.validator.exactness;
385+
state.fields_set_count = handler.validator.fields_set_count;
384386
result
385387
}
386388

src/validators/generator.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ pub struct InternalValidator {
225225
self_instance: Option<PyObject>,
226226
recursion_guard: RecursionState,
227227
pub(crate) exactness: Option<Exactness>,
228+
pub(crate) fields_set_count: Option<usize>,
228229
validation_mode: InputType,
229230
hide_input_in_errors: bool,
230231
validation_error_cause: bool,
@@ -256,6 +257,7 @@ impl InternalValidator {
256257
self_instance: extra.self_instance.map(|d| d.clone().unbind()),
257258
recursion_guard: state.recursion_guard.clone(),
258259
exactness: state.exactness,
260+
fields_set_count: state.fields_set_count,
259261
validation_mode: extra.input_type,
260262
hide_input_in_errors,
261263
validation_error_cause,
@@ -284,7 +286,8 @@ impl InternalValidator {
284286
by_name: None,
285287
};
286288
let mut state = ValidationState::new(extra, &mut self.recursion_guard, false.into());
287-
state.exactness = self.exactness;
289+
// state.exactness = self.exactness;
290+
// state.fields_set_count = self.fields_set_count;
288291
let result = self
289292
.validator
290293
.validate_assignment(py, model, field_name, field_value, &mut state)
@@ -299,7 +302,8 @@ impl InternalValidator {
299302
self.validation_error_cause,
300303
)
301304
});
302-
self.exactness = state.exactness;
305+
// self.exactness = state.exactness;
306+
// self.fields_set_count = state.fields_set_count;
303307
result
304308
}
305309

@@ -323,6 +327,7 @@ impl InternalValidator {
323327
};
324328
let mut state = ValidationState::new(extra, &mut self.recursion_guard, false.into());
325329
state.exactness = self.exactness;
330+
state.fields_set_count = self.fields_set_count;
326331
let result = self.validator.validate(py, input, &mut state).map_err(|e| {
327332
ValidationError::from_val_error(
328333
py,
@@ -335,6 +340,7 @@ impl InternalValidator {
335340
)
336341
});
337342
self.exactness = state.exactness;
343+
self.fields_set_count = state.fields_set_count;
338344
result
339345
}
340346
}

tests/validators/test_union.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,3 +1358,83 @@ class Model:
13581358

13591359
assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo)
13601360
assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar)
1361+
1362+
1363+
def test_smart_union_wrap_validator_should_not_change_nested_model_field_counts() -> None:
1364+
"""Adding a wrap validator on a union member should not affect smart union behavior"""
1365+
1366+
class SubModel:
1367+
x: str = 'x'
1368+
1369+
class ModelA:
1370+
type: str = 'A'
1371+
sub: SubModel
1372+
1373+
class ModelB:
1374+
type: str = 'B'
1375+
sub: SubModel
1376+
1377+
submodel_schema = core_schema.model_schema(
1378+
SubModel,
1379+
core_schema.model_fields_schema(fields={'x': core_schema.model_field(core_schema.str_schema())}),
1380+
)
1381+
1382+
wrapped_submodel_schema = core_schema.no_info_wrap_validator_function(
1383+
lambda v, handler: handler(v), submodel_schema
1384+
)
1385+
1386+
model_a_schema = core_schema.model_schema(
1387+
ModelA,
1388+
core_schema.model_fields_schema(
1389+
fields={
1390+
'type': core_schema.model_field(
1391+
core_schema.with_default_schema(core_schema.literal_schema(['A']), default='A'),
1392+
),
1393+
'sub': core_schema.model_field(wrapped_submodel_schema),
1394+
},
1395+
),
1396+
)
1397+
1398+
model_b_schema = core_schema.model_schema(
1399+
ModelB,
1400+
core_schema.model_fields_schema(
1401+
fields={
1402+
'type': core_schema.model_field(
1403+
core_schema.with_default_schema(core_schema.literal_schema(['B']), default='B'),
1404+
),
1405+
'sub': core_schema.model_field(submodel_schema),
1406+
},
1407+
),
1408+
)
1409+
1410+
for choices in permute_choices([model_a_schema, model_b_schema]):
1411+
schema = core_schema.union_schema(choices)
1412+
validator = SchemaValidator(schema)
1413+
1414+
assert isinstance(validator.validate_python({'type': 'A', 'sub': {'x': 'x'}}), ModelA)
1415+
assert isinstance(validator.validate_python({'type': 'B', 'sub': {'x': 'x'}}), ModelB)
1416+
1417+
# defaults to leftmost choice if there's a tie
1418+
assert isinstance(validator.validate_python({'sub': {'x': 'x'}}), choices[0]['cls'])
1419+
1420+
# test validate_assignment
1421+
class RootModel:
1422+
ab: Union[ModelA, ModelB]
1423+
1424+
root_model = core_schema.model_schema(
1425+
RootModel,
1426+
core_schema.model_fields_schema(
1427+
fields={'ab': core_schema.model_field(core_schema.union_schema([model_a_schema, model_b_schema]))}
1428+
),
1429+
)
1430+
1431+
validator = SchemaValidator(root_model)
1432+
m = validator.validate_python({'ab': {'type': 'B', 'sub': {'x': 'x'}}})
1433+
assert isinstance(m, RootModel)
1434+
assert isinstance(m.ab, ModelB)
1435+
assert m.ab.sub.x == 'x'
1436+
1437+
m = validator.validate_assignment(m, 'ab', {'sub': {'x': 'y'}})
1438+
assert isinstance(m, RootModel)
1439+
assert isinstance(m.ab, ModelA)
1440+
assert m.ab.sub.x == 'y'

0 commit comments

Comments
 (0)