@@ -948,6 +948,43 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
948948            assert  m  in  str (w [0 ].message )
949949
950950
951+ def  test_mixed_union_models_and_other_types () ->  None :
952+     s  =  SchemaSerializer (
953+         core_schema .union_schema (
954+             [
955+                 core_schema .tagged_union_schema (
956+                     discriminator = 'type_' ,
957+                     choices = {
958+                         'cat' : core_schema .model_schema (
959+                             cls = ModelCat ,
960+                             schema = core_schema .model_fields_schema (
961+                                 fields = {
962+                                     'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
963+                                 },
964+                             ),
965+                         ),
966+                         'dog' : core_schema .model_schema (
967+                             cls = ModelDog ,
968+                             schema = core_schema .model_fields_schema (
969+                                 fields = {
970+                                     'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
971+                                 },
972+                             ),
973+                         ),
974+                     },
975+                 ),
976+                 core_schema .str_schema (),
977+             ]
978+         )
979+     )
980+ 
981+     assert  s .to_python (ModelCat (type_ = 'cat' ), warnings = 'error' ) ==  {'type_' : 'cat' }
982+     assert  s .to_python (ModelDog (type_ = 'dog' ), warnings = 'error' ) ==  {'type_' : 'dog' }
983+     # note, this fails as ModelCat and ModelDog (discriminator warnings, etc), but the warnings 
984+     # don't bubble up to this level :) 
985+     assert  s .to_python ('a string' , warnings = 'error' ) ==  'a string' 
986+ 
987+ 
951988@pytest .mark .parametrize ( 
952989    'input,expected' , 
953990    [ 
@@ -1000,3 +1037,28 @@ def test_union_of_unions_of_models_with_tagged_union_json_serialization(
10001037    )
10011038
10021039    assert  s .to_json (input , warnings = 'error' ) ==  expected 
1040+ 
1041+ 
1042+ def  test_discriminated_union_ser_with_typed_dict () ->  None :
1043+     v  =  SchemaSerializer (
1044+         core_schema .tagged_union_schema (
1045+             {
1046+                 'a' : core_schema .typed_dict_schema (
1047+                     {
1048+                         'type' : core_schema .typed_dict_field (core_schema .literal_schema (['a' ])),
1049+                         'a' : core_schema .typed_dict_field (core_schema .int_schema ()),
1050+                     }
1051+                 ),
1052+                 'b' : core_schema .typed_dict_schema (
1053+                     {
1054+                         'type' : core_schema .typed_dict_field (core_schema .literal_schema (['b' ])),
1055+                         'b' : core_schema .typed_dict_field (core_schema .str_schema ()),
1056+                     }
1057+                 ),
1058+             },
1059+             discriminator = 'type' ,
1060+         )
1061+     )
1062+ 
1063+     assert  v .to_python ({'type' : 'a' , 'a' : 1 }, warnings = 'error' ) ==  {'type' : 'a' , 'a' : 1 }
1064+     assert  v .to_python ({'type' : 'b' , 'b' : 'foo' }, warnings = 'error' ) ==  {'type' : 'b' , 'b' : 'foo' }
0 commit comments