| 
 | 1 | +from __future__ import annotations  | 
 | 2 | + | 
1 | 3 | import dataclasses  | 
2 | 4 | import json  | 
3 | 5 | import uuid  | 
@@ -778,3 +780,196 @@ class ModelB:  | 
778 | 780 |     model_b = ModelB(field=1)  | 
779 | 781 |     assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}  | 
780 | 782 |     assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}  | 
 | 783 | + | 
 | 784 | + | 
 | 785 | +class ModelDog:  | 
 | 786 | +    def __init__(self, type_: Literal['dog']) -> None:  | 
 | 787 | +        self.type_ = 'dog'  | 
 | 788 | + | 
 | 789 | + | 
 | 790 | +class ModelCat:  | 
 | 791 | +    def __init__(self, type_: Literal['cat']) -> None:  | 
 | 792 | +        self.type_ = 'cat'  | 
 | 793 | + | 
 | 794 | + | 
 | 795 | +class ModelAlien:  | 
 | 796 | +    def __init__(self, type_: Literal['alien']) -> None:  | 
 | 797 | +        self.type_ = 'alien'  | 
 | 798 | + | 
 | 799 | + | 
 | 800 | +@pytest.fixture  | 
 | 801 | +def model_a_b_union_schema() -> core_schema.UnionSchema:  | 
 | 802 | +    return core_schema.union_schema(  | 
 | 803 | +        [  | 
 | 804 | +            core_schema.model_schema(  | 
 | 805 | +                cls=ModelA,  | 
 | 806 | +                schema=core_schema.model_fields_schema(  | 
 | 807 | +                    fields={  | 
 | 808 | +                        'a': core_schema.model_field(core_schema.str_schema()),  | 
 | 809 | +                        'b': core_schema.model_field(core_schema.str_schema()),  | 
 | 810 | +                    },  | 
 | 811 | +                ),  | 
 | 812 | +            ),  | 
 | 813 | +            core_schema.model_schema(  | 
 | 814 | +                cls=ModelB,  | 
 | 815 | +                schema=core_schema.model_fields_schema(  | 
 | 816 | +                    fields={  | 
 | 817 | +                        'c': core_schema.model_field(core_schema.str_schema()),  | 
 | 818 | +                        'd': core_schema.model_field(core_schema.str_schema()),  | 
 | 819 | +                    },  | 
 | 820 | +                ),  | 
 | 821 | +            ),  | 
 | 822 | +        ]  | 
 | 823 | +    )  | 
 | 824 | + | 
 | 825 | + | 
 | 826 | +@pytest.fixture  | 
 | 827 | +def union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:  | 
 | 828 | +    return core_schema.union_schema(  | 
 | 829 | +        [  | 
 | 830 | +            model_a_b_union_schema,  | 
 | 831 | +            core_schema.union_schema(  | 
 | 832 | +                [  | 
 | 833 | +                    core_schema.model_schema(  | 
 | 834 | +                        cls=ModelCat,  | 
 | 835 | +                        schema=core_schema.model_fields_schema(  | 
 | 836 | +                            fields={  | 
 | 837 | +                                'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),  | 
 | 838 | +                            },  | 
 | 839 | +                        ),  | 
 | 840 | +                    ),  | 
 | 841 | +                    core_schema.model_schema(  | 
 | 842 | +                        cls=ModelDog,  | 
 | 843 | +                        schema=core_schema.model_fields_schema(  | 
 | 844 | +                            fields={  | 
 | 845 | +                                'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),  | 
 | 846 | +                            },  | 
 | 847 | +                        ),  | 
 | 848 | +                    ),  | 
 | 849 | +                ]  | 
 | 850 | +            ),  | 
 | 851 | +        ]  | 
 | 852 | +    )  | 
 | 853 | + | 
 | 854 | + | 
 | 855 | +@pytest.mark.parametrize(  | 
 | 856 | +    'input,expected',  | 
 | 857 | +    [  | 
 | 858 | +        (ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),  | 
 | 859 | +        (ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),  | 
 | 860 | +        (ModelCat(type_='cat'), {'type_': 'cat'}),  | 
 | 861 | +        (ModelDog(type_='dog'), {'type_': 'dog'}),  | 
 | 862 | +    ],  | 
 | 863 | +)  | 
 | 864 | +def test_union_of_unions_of_models(union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any) -> None:  | 
 | 865 | +    s = SchemaSerializer(union_of_unions_schema)  | 
 | 866 | +    assert s.to_python(input, warnings='error') == expected  | 
 | 867 | + | 
 | 868 | + | 
 | 869 | +def test_union_of_unions_of_models_invalid_variant(union_of_unions_schema: core_schema.UnionSchema) -> None:  | 
 | 870 | +    s = SchemaSerializer(union_of_unions_schema)  | 
 | 871 | +    # All warnings should be available  | 
 | 872 | +    messages = [  | 
 | 873 | +        'Expected `ModelA` but got `ModelAlien`',  | 
 | 874 | +        'Expected `ModelB` but got `ModelAlien`',  | 
 | 875 | +        'Expected `ModelCat` but got `ModelAlien`',  | 
 | 876 | +        'Expected `ModelDog` but got `ModelAlien`',  | 
 | 877 | +    ]  | 
 | 878 | + | 
 | 879 | +    with warnings.catch_warnings(record=True) as w:  | 
 | 880 | +        warnings.simplefilter('always')  | 
 | 881 | +        s.to_python(ModelAlien(type_='alien'))  | 
 | 882 | +        for m in messages:  | 
 | 883 | +            assert m in str(w[0].message)  | 
 | 884 | + | 
 | 885 | + | 
 | 886 | +@pytest.fixture  | 
 | 887 | +def tagged_union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:  | 
 | 888 | +    return core_schema.union_schema(  | 
 | 889 | +        [  | 
 | 890 | +            model_a_b_union_schema,  | 
 | 891 | +            core_schema.tagged_union_schema(  | 
 | 892 | +                discriminator='type_',  | 
 | 893 | +                choices={  | 
 | 894 | +                    'cat': core_schema.model_schema(  | 
 | 895 | +                        cls=ModelCat,  | 
 | 896 | +                        schema=core_schema.model_fields_schema(  | 
 | 897 | +                            fields={  | 
 | 898 | +                                'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),  | 
 | 899 | +                            },  | 
 | 900 | +                        ),  | 
 | 901 | +                    ),  | 
 | 902 | +                    'dog': core_schema.model_schema(  | 
 | 903 | +                        cls=ModelDog,  | 
 | 904 | +                        schema=core_schema.model_fields_schema(  | 
 | 905 | +                            fields={  | 
 | 906 | +                                'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),  | 
 | 907 | +                            },  | 
 | 908 | +                        ),  | 
 | 909 | +                    ),  | 
 | 910 | +                },  | 
 | 911 | +            ),  | 
 | 912 | +        ]  | 
 | 913 | +    )  | 
 | 914 | + | 
 | 915 | + | 
 | 916 | +@pytest.mark.parametrize(  | 
 | 917 | +    'input,expected',  | 
 | 918 | +    [  | 
 | 919 | +        (ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),  | 
 | 920 | +        (ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),  | 
 | 921 | +        (ModelCat(type_='cat'), {'type_': 'cat'}),  | 
 | 922 | +        (ModelDog(type_='dog'), {'type_': 'dog'}),  | 
 | 923 | +    ],  | 
 | 924 | +)  | 
 | 925 | +def test_union_of_unions_of_models_with_tagged_union(  | 
 | 926 | +    tagged_union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any  | 
 | 927 | +) -> None:  | 
 | 928 | +    s = SchemaSerializer(tagged_union_of_unions_schema)  | 
 | 929 | +    assert s.to_python(input, warnings='error') == expected  | 
 | 930 | + | 
 | 931 | + | 
 | 932 | +def test_union_of_unions_of_models_with_tagged_union_invalid_variant(  | 
 | 933 | +    tagged_union_of_unions_schema: core_schema.UnionSchema,  | 
 | 934 | +) -> None:  | 
 | 935 | +    s = SchemaSerializer(tagged_union_of_unions_schema)  | 
 | 936 | +    # All warnings should be available  | 
 | 937 | +    messages = [  | 
 | 938 | +        'Expected `ModelA` but got `ModelAlien`',  | 
 | 939 | +        'Expected `ModelB` but got `ModelAlien`',  | 
 | 940 | +        'Expected `ModelCat` but got `ModelAlien`',  | 
 | 941 | +        'Expected `ModelDog` but got `ModelAlien`',  | 
 | 942 | +    ]  | 
 | 943 | + | 
 | 944 | +    with warnings.catch_warnings(record=True) as w:  | 
 | 945 | +        warnings.simplefilter('always')  | 
 | 946 | +        s.to_python(ModelAlien(type_='alien'))  | 
 | 947 | +        for m in messages:  | 
 | 948 | +            assert m in str(w[0].message)  | 
 | 949 | + | 
 | 950 | + | 
 | 951 | +@pytest.mark.parametrize(  | 
 | 952 | +    'input,expected',  | 
 | 953 | +    [  | 
 | 954 | +        ({True: '1'}, b'{"true":"1"}'),  | 
 | 955 | +        ({1: '1'}, b'{"1":"1"}'),  | 
 | 956 | +        ({2.3: '1'}, b'{"2.3":"1"}'),  | 
 | 957 | +        ({'a': 'b'}, b'{"a":"b"}'),  | 
 | 958 | +    ],  | 
 | 959 | +)  | 
 | 960 | +def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(  | 
 | 961 | +    input: bool | int | float | str, expected: bytes  | 
 | 962 | +) -> None:  | 
 | 963 | +    s = SchemaSerializer(  | 
 | 964 | +        core_schema.dict_schema(  | 
 | 965 | +            keys_schema=core_schema.union_schema(  | 
 | 966 | +                [  | 
 | 967 | +                    core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]),  | 
 | 968 | +                    core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]),  | 
 | 969 | +                ]  | 
 | 970 | +            ),  | 
 | 971 | +            values_schema=core_schema.str_schema(),  | 
 | 972 | +        )  | 
 | 973 | +    )  | 
 | 974 | + | 
 | 975 | +    assert s.to_json(input, warnings='error') == expected  | 
0 commit comments