Skip to content

Commit 043fce1

Browse files
committed
fix(union_serializer): do not raise warnings in nested unions
When unions of unions are used, nested unions now bubble up their errors to the parent union instead of warning immediately. If the top-level union finds no matching serializer among all of its choices, it emits the warnings as before. Signed-off-by: Luka Peschke <[email protected]>
1 parent 4cb82bf commit 043fce1

File tree

2 files changed

+252
-7
lines changed

2 files changed

+252
-7
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::build_tools::py_schema_err;
99
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
1010
use crate::definitions::DefinitionsBuilder;
1111
use crate::tools::{truncate_safe_repr, SchemaDict};
12+
use crate::PydanticSerializationUnexpectedValue;
1213

1314
use super::{
1415
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
@@ -89,7 +90,8 @@ fn to_python(
8990
}
9091
}
9192

92-
if retry_with_lax_check {
93+
// If extra.check is SerCheck::Strict, we're in a nested union: skip the lax retry here, since the
// parent union runs its own lax pass over its choices
94+
if extra.check != SerCheck::Strict && retry_with_lax_check {
9395
new_extra.check = SerCheck::Lax;
9496
for comb_serializer in choices {
9597
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
@@ -98,8 +100,17 @@ fn to_python(
98100
}
99101
}
100102

101-
for err in &errors {
102-
extra.warnings.custom_warning(err.to_string());
103+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
104+
if extra.check == SerCheck::None {
105+
for err in &errors {
106+
extra.warnings.custom_warning(err.to_string());
107+
}
108+
}
109+
// Otherwise, if we've encountered errors, return them to the parent union, which should take
110+
// care of the formatting for us
111+
else if !errors.is_empty() {
112+
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
113+
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
103114
}
104115

105116
infer_to_python(value, include, exclude, extra)
@@ -122,7 +133,8 @@ fn json_key<'a>(
122133
}
123134
}
124135

125-
if retry_with_lax_check {
136+
// If extra.check is SerCheck::Strict, we're in a nested union: skip the lax retry here, since the
// parent union runs its own lax pass over its choices
137+
if extra.check != SerCheck::Strict && retry_with_lax_check {
126138
new_extra.check = SerCheck::Lax;
127139
for comb_serializer in choices {
128140
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
@@ -131,10 +143,18 @@ fn json_key<'a>(
131143
}
132144
}
133145

134-
for err in &errors {
135-
extra.warnings.custom_warning(err.to_string());
146+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
147+
if extra.check == SerCheck::None {
148+
for err in &errors {
149+
extra.warnings.custom_warning(err.to_string());
150+
}
151+
}
152+
// Otherwise, if we've encountered errors, return them to the parent union, which should take
153+
// care of the formatting for us
154+
else if !errors.is_empty() {
155+
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
156+
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
136157
}
137-
138158
infer_json_key(key, extra)
139159
}
140160

tests/serializers/test_union.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,3 +778,228 @@ class ModelB:
778778
model_b = ModelB(field=1)
779779
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
780780
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}
781+
782+
783+
class ModelDog:
784+
def __init__(self, type_: Literal['dog']) -> None:
785+
self.type_ = 'dog'
786+
787+
788+
class ModelCat:
789+
def __init__(self, type_: Literal['cat']) -> None:
790+
self.type_ = 'cat'
791+
792+
793+
class ModelAlien:
794+
def __init__(self, type_: Literal['alien']) -> None:
795+
self.type_ = 'alien'
796+
797+
798+
@pytest.fixture
799+
def model_a_b_union_schema() -> core_schema.UnionSchema:
800+
return core_schema.union_schema(
801+
[
802+
core_schema.model_schema(
803+
cls=ModelA,
804+
schema=core_schema.model_fields_schema(
805+
fields={
806+
'a': core_schema.model_field(core_schema.str_schema()),
807+
'b': core_schema.model_field(core_schema.str_schema()),
808+
},
809+
),
810+
),
811+
core_schema.model_schema(
812+
cls=ModelB,
813+
schema=core_schema.model_fields_schema(
814+
fields={
815+
'c': core_schema.model_field(core_schema.str_schema()),
816+
'd': core_schema.model_field(core_schema.str_schema()),
817+
},
818+
),
819+
),
820+
]
821+
)
822+
823+
824+
@pytest.fixture
825+
def union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:
826+
return core_schema.union_schema(
827+
[
828+
model_a_b_union_schema,
829+
core_schema.union_schema(
830+
[
831+
core_schema.model_schema(
832+
cls=ModelCat,
833+
schema=core_schema.model_fields_schema(
834+
fields={
835+
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
836+
},
837+
),
838+
),
839+
core_schema.model_schema(
840+
cls=ModelDog,
841+
schema=core_schema.model_fields_schema(
842+
fields={
843+
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
844+
},
845+
),
846+
),
847+
]
848+
),
849+
]
850+
)
851+
852+
853+
@pytest.mark.parametrize(
854+
'input,expected',
855+
[
856+
(ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),
857+
(ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),
858+
(ModelCat(type_='cat'), {'type_': 'cat'}),
859+
(ModelDog(type_='dog'), {'type_': 'dog'}),
860+
],
861+
)
862+
def test_union_of_unions_of_models(union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any) -> None:
863+
s = SchemaSerializer(union_of_unions_schema)
864+
assert s.to_python(input, warnings='error') == expected
865+
866+
867+
def test_union_of_unions_of_models_invalid_variant(union_of_unions_schema: core_schema.UnionSchema) -> None:
868+
s = SchemaSerializer(union_of_unions_schema)
869+
# All warnings should be available
870+
messages = [
871+
'Expected `ModelA` but got `ModelAlien`',
872+
'Expected `ModelB` but got `ModelAlien`',
873+
'Expected `ModelCat` but got `ModelAlien`',
874+
'Expected `ModelDog` but got `ModelAlien`',
875+
]
876+
877+
with warnings.catch_warnings(record=True) as w:
878+
warnings.simplefilter('always')
879+
s.to_python(ModelAlien(type_='alien'))
880+
for m in messages:
881+
assert m in str(w[0].message)
882+
883+
884+
@pytest.fixture
885+
def tagged_union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:
886+
return core_schema.union_schema(
887+
[
888+
model_a_b_union_schema,
889+
core_schema.tagged_union_schema(
890+
discriminator='type_',
891+
choices={
892+
'cat': core_schema.model_schema(
893+
cls=ModelCat,
894+
schema=core_schema.model_fields_schema(
895+
fields={
896+
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
897+
},
898+
),
899+
),
900+
'dog': core_schema.model_schema(
901+
cls=ModelDog,
902+
schema=core_schema.model_fields_schema(
903+
fields={
904+
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
905+
},
906+
),
907+
),
908+
},
909+
),
910+
]
911+
)
912+
913+
914+
@pytest.mark.parametrize(
915+
'input,expected',
916+
[
917+
(ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),
918+
(ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),
919+
(ModelCat(type_='cat'), {'type_': 'cat'}),
920+
(ModelDog(type_='dog'), {'type_': 'dog'}),
921+
],
922+
)
923+
def test_union_of_unions_of_models_with_tagged_union(
924+
tagged_union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any
925+
) -> None:
926+
s = SchemaSerializer(tagged_union_of_unions_schema)
927+
assert s.to_python(input, warnings='error') == expected
928+
929+
930+
def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
931+
tagged_union_of_unions_schema: core_schema.UnionSchema,
932+
) -> None:
933+
s = SchemaSerializer(tagged_union_of_unions_schema)
934+
# All warnings should be available
935+
messages = [
936+
'Expected `ModelA` but got `ModelAlien`',
937+
'Expected `ModelB` but got `ModelAlien`',
938+
'Expected `ModelCat` but got `ModelAlien`',
939+
'Expected `ModelDog` but got `ModelAlien`',
940+
]
941+
942+
with warnings.catch_warnings(record=True) as w:
943+
warnings.simplefilter('always')
944+
s.to_python(ModelAlien(type_='alien'))
945+
for m in messages:
946+
assert m in str(w[0].message)
947+
948+
949+
@dataclasses.dataclass(frozen=True)
950+
class DataClassA:
951+
a: str
952+
953+
954+
@dataclasses.dataclass(frozen=True)
955+
class DataClassB:
956+
b: str
957+
958+
959+
@pytest.mark.parametrize(
960+
'input,expected',
961+
[
962+
({True: '1'}, b'{"true":"1"}'),
963+
({1: '1'}, b'{"1":"1"}'),
964+
({2.3: '1'}, b'{"2.3":"1"}'),
965+
({'a': 'b'}, b'{"a":"b"}'),
966+
],
967+
)
968+
def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(
969+
input: bool | int | float | str, expected: bytes
970+
) -> None:
971+
s = SchemaSerializer(
972+
core_schema.dict_schema(
973+
keys_schema=core_schema.union_schema(
974+
[
975+
core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]),
976+
core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]),
977+
]
978+
),
979+
values_schema=core_schema.str_schema(),
980+
)
981+
)
982+
983+
assert s.to_json(input, warnings='error') == expected
984+
985+
986+
def test_union_of_unions_of_models_with_tagged_union_json_serialization_invalid_variant(
987+
tagged_union_of_unions_schema: core_schema.UnionSchema,
988+
) -> None:
989+
s = SchemaSerializer(
990+
core_schema.dict_schema(keys_schema=tagged_union_of_unions_schema, values_schema=core_schema.str_schema())
991+
)
992+
993+
# All warnings should be available
994+
messages = [
995+
'Expected `ModelA` but got `ModelAlien`',
996+
'Expected `ModelB` but got `ModelAlien`',
997+
'Expected `ModelCat` but got `ModelAlien`',
998+
'Expected `ModelDog` but got `ModelAlien`',
999+
]
1000+
1001+
with warnings.catch_warnings(record=True) as w:
1002+
warnings.simplefilter('always')
1003+
s.to_python({ModelAlien(type_='alien'): 'coucou'})
1004+
for m in messages:
1005+
assert m in str(w[0].message)

0 commit comments

Comments
 (0)