@@ -948,6 +948,43 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
948
948
assert m in str (w [0 ].message )
949
949
950
950
951
+ def test_mixed_union_models_and_other_types () -> None :
952
+ s = SchemaSerializer (
953
+ core_schema .union_schema (
954
+ [
955
+ core_schema .tagged_union_schema (
956
+ discriminator = 'type_' ,
957
+ choices = {
958
+ 'cat' : core_schema .model_schema (
959
+ cls = ModelCat ,
960
+ schema = core_schema .model_fields_schema (
961
+ fields = {
962
+ 'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
963
+ },
964
+ ),
965
+ ),
966
+ 'dog' : core_schema .model_schema (
967
+ cls = ModelDog ,
968
+ schema = core_schema .model_fields_schema (
969
+ fields = {
970
+ 'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
971
+ },
972
+ ),
973
+ ),
974
+ },
975
+ ),
976
+ core_schema .str_schema (),
977
+ ]
978
+ )
979
+ )
980
+
981
+ assert s .to_python (ModelCat (type_ = 'cat' ), warnings = 'error' ) == {'type_' : 'cat' }
982
+ assert s .to_python (ModelDog (type_ = 'dog' ), warnings = 'error' ) == {'type_' : 'dog' }
983
+ # note, this fails as ModelCat and ModelDog (discriminator warnings, etc), but the warnings
984
+ # don't bubble up to this level :)
985
+ assert s .to_python ('a string' , warnings = 'error' ) == 'a string'
986
+
987
+
951
988
@pytest .mark .parametrize (
952
989
'input,expected' ,
953
990
[
@@ -1000,3 +1037,28 @@ def test_union_of_unions_of_models_with_tagged_union_json_serialization(
1000
1037
)
1001
1038
1002
1039
assert s .to_json (input , warnings = 'error' ) == expected
1040
+
1041
+
1042
+ def test_discriminated_union_ser_with_typed_dict () -> None :
1043
+ v = SchemaSerializer (
1044
+ core_schema .tagged_union_schema (
1045
+ {
1046
+ 'a' : core_schema .typed_dict_schema (
1047
+ {
1048
+ 'type' : core_schema .typed_dict_field (core_schema .literal_schema (['a' ])),
1049
+ 'a' : core_schema .typed_dict_field (core_schema .int_schema ()),
1050
+ }
1051
+ ),
1052
+ 'b' : core_schema .typed_dict_schema (
1053
+ {
1054
+ 'type' : core_schema .typed_dict_field (core_schema .literal_schema (['b' ])),
1055
+ 'b' : core_schema .typed_dict_field (core_schema .str_schema ()),
1056
+ }
1057
+ ),
1058
+ },
1059
+ discriminator = 'type' ,
1060
+ )
1061
+ )
1062
+
1063
+ assert v .to_python ({'type' : 'a' , 'a' : 1 }, warnings = 'error' ) == {'type' : 'a' , 'a' : 1 }
1064
+ assert v .to_python ({'type' : 'b' , 'b' : 'foo' }, warnings = 'error' ) == {'type' : 'b' , 'b' : 'foo' }
0 commit comments