@@ -790,51 +790,92 @@ def __init__(self, type_: Literal['cat']) -> None:
790790 self .type_ = 'cat'
791791
792792
793- def test_union_of_unions_of_models () -> None :
793+ @pytest .fixture
794+ def model_a_b_union_schema () -> core_schema .UnionSchema :
795+ return core_schema .union_schema (
796+ [
797+ core_schema .model_schema (
798+ cls = ModelA ,
799+ schema = core_schema .model_fields_schema (
800+ fields = {
801+ 'a' : core_schema .model_field (core_schema .str_schema ()),
802+ 'b' : core_schema .model_field (core_schema .str_schema ()),
803+ },
804+ ),
805+ ),
806+ core_schema .model_schema (
807+ cls = ModelB ,
808+ schema = core_schema .model_fields_schema (
809+ fields = {
810+ 'c' : core_schema .model_field (core_schema .str_schema ()),
811+ 'd' : core_schema .model_field (core_schema .str_schema ()),
812+ },
813+ ),
814+ ),
815+ ]
816+ )
817+
818+
819+ def test_union_of_unions_of_models (model_a_b_union_schema : core_schema .UnionSchema ) -> None :
794820 s = SchemaSerializer (
795821 core_schema .union_schema (
796822 [
823+ model_a_b_union_schema ,
797824 core_schema .union_schema (
798825 [
799826 core_schema .model_schema (
800- cls = ModelA ,
827+ cls = ModelCat ,
801828 schema = core_schema .model_fields_schema (
802829 fields = {
803- 'a' : core_schema .model_field (core_schema .str_schema ()),
804- 'b' : core_schema .model_field (core_schema .str_schema ()),
830+ 'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
805831 },
806832 ),
807833 ),
808834 core_schema .model_schema (
809- cls = ModelB ,
835+ cls = ModelDog ,
810836 schema = core_schema .model_fields_schema (
811837 fields = {
812- 'c' : core_schema .model_field (core_schema .str_schema ()),
813- 'd' : core_schema .model_field (core_schema .str_schema ()),
838+ 'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
814839 },
815840 ),
816841 ),
817842 ]
818843 ),
819- core_schema .union_schema (
820- [
821- core_schema .model_schema (
844+ ]
845+ )
846+ )
847+
848+ assert s .to_python (ModelA (a = 'a' , b = 'b' ), warnings = 'error' ) == {'a' : 'a' , 'b' : 'b' }
849+ assert s .to_python (ModelB (c = 'c' , d = 'd' ), warnings = 'error' ) == {'c' : 'c' , 'd' : 'd' }
850+ assert s .to_python (ModelCat (type_ = 'cat' ), warnings = 'error' ) == {'type_' : 'cat' }
851+ assert s .to_python (ModelDog (type_ = 'dog' ), warnings = 'error' ) == {'type_' : 'dog' }
852+
853+
854+ def test_union_of_unions_of_models_with_tagged_union (model_a_b_union_schema : core_schema .UnionSchema ) -> None :
855+ s = SchemaSerializer (
856+ core_schema .union_schema (
857+ [
858+ model_a_b_union_schema ,
859+ core_schema .tagged_union_schema (
860+ discriminator = 'type_' ,
861+ choices = {
862+ 'cat' : core_schema .model_schema (
822863 cls = ModelCat ,
823864 schema = core_schema .model_fields_schema (
824865 fields = {
825866 'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
826867 },
827868 ),
828869 ),
829- core_schema .model_schema (
870+ 'dog' : core_schema .model_schema (
830871 cls = ModelDog ,
831872 schema = core_schema .model_fields_schema (
832873 fields = {
833874 'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
834875 },
835876 ),
836877 ),
837- ]
878+ },
838879 ),
839880 ]
840881 )
0 commit comments