Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ install:
pip install -U pip wheel pre-commit
pip install -r tests/requirements.txt
pip install -r tests/requirements-linting.txt
pip install -e .
pip install -v -e .
pre-commit install

.PHONY: install-rust-coverage
Expand Down
31 changes: 22 additions & 9 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,19 +422,32 @@ impl TaggedUnionSerializer {
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
let py = value.py();
let discriminator_value = match &self.discriminator {
Discriminator::LookupKey(lookup_key) => lookup_key
.simple_py_get_attr(value)
.ok()
.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))),
Discriminator::LookupKey(lookup_key) => {
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could
// be doing a discriminator lookup on a typed dict, and there's no good way to check that
// at this point. we could be more strict and only do this in lax mode...
let getattr_result = match value.is_instance_of::<PyDict>() {
true => {
let value_dict = value.downcast::<PyDict>().unwrap();
lookup_key.py_get_dict_item(value_dict).ok()
}
false => lookup_key.simple_py_get_attr(value).ok(),
};
getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py)))
}
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
};
if discriminator_value.is_none() {
let value_str = truncate_safe_repr(value, None);
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
if extra.check == SerCheck::None {
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);
}
}
discriminator_value
}
Expand Down
62 changes: 62 additions & 0 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,43 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
assert m in str(w[0].message)


def test_mixed_union_models_and_other_types() -> None:
s = SchemaSerializer(
core_schema.union_schema(
[
core_schema.tagged_union_schema(
discriminator='type_',
choices={
'cat': core_schema.model_schema(
cls=ModelCat,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
},
),
),
'dog': core_schema.model_schema(
cls=ModelDog,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
},
),
),
},
),
core_schema.str_schema(),
]
)
)

assert s.to_python(ModelCat(type_='cat'), warnings='error') == {'type_': 'cat'}
assert s.to_python(ModelDog(type_='dog'), warnings='error') == {'type_': 'dog'}
# note, this fails as ModelCat and ModelDog (discriminator warnings, etc), but the warnings
# don't bubble up to this level :)
assert s.to_python('a string', warnings='error') == 'a string'


@pytest.mark.parametrize(
'input,expected',
[
Expand Down Expand Up @@ -1000,3 +1037,28 @@ def test_union_of_unions_of_models_with_tagged_union_json_serialization(
)

assert s.to_json(input, warnings='error') == expected


def test_discriminated_union_ser_with_typed_dict() -> None:
v = SchemaSerializer(
core_schema.tagged_union_schema(
{
'a': core_schema.typed_dict_schema(
{
'type': core_schema.typed_dict_field(core_schema.literal_schema(['a'])),
'a': core_schema.typed_dict_field(core_schema.int_schema()),
}
),
'b': core_schema.typed_dict_schema(
{
'type': core_schema.typed_dict_field(core_schema.literal_schema(['b'])),
'b': core_schema.typed_dict_field(core_schema.str_schema()),
}
),
},
discriminator='type',
)
)

assert v.to_python({'type': 'a', 'a': 1}, warnings='error') == {'type': 'a', 'a': 1}
assert v.to_python({'type': 'b', 'b': 'foo'}, warnings='error') == {'type': 'b', 'b': 'foo'}
Loading