@@ -778,3 +778,228 @@ class ModelB:
778778 model_b = ModelB (field = 1 )
779779 assert s .to_python (model_a ) == {'field' : 1 , 'TAG' : 'a' }
780780 assert s .to_python (model_b ) == {'field' : 1 , 'TAG' : 'b' }
781+
782+
783+ class ModelDog :
784+ def __init__ (self , type_ : Literal ['dog' ]) -> None :
785+ self .type_ = 'dog'
786+
787+
788+ class ModelCat :
789+ def __init__ (self , type_ : Literal ['cat' ]) -> None :
790+ self .type_ = 'cat'
791+
792+
793+ class ModelAlien :
794+ def __init__ (self , type_ : Literal ['alien' ]) -> None :
795+ self .type_ = 'alien'
796+
797+
798+ @pytest .fixture
799+ def model_a_b_union_schema () -> core_schema .UnionSchema :
800+ return core_schema .union_schema (
801+ [
802+ core_schema .model_schema (
803+ cls = ModelA ,
804+ schema = core_schema .model_fields_schema (
805+ fields = {
806+ 'a' : core_schema .model_field (core_schema .str_schema ()),
807+ 'b' : core_schema .model_field (core_schema .str_schema ()),
808+ },
809+ ),
810+ ),
811+ core_schema .model_schema (
812+ cls = ModelB ,
813+ schema = core_schema .model_fields_schema (
814+ fields = {
815+ 'c' : core_schema .model_field (core_schema .str_schema ()),
816+ 'd' : core_schema .model_field (core_schema .str_schema ()),
817+ },
818+ ),
819+ ),
820+ ]
821+ )
822+
823+
824+ @pytest .fixture
825+ def union_of_unions_schema (model_a_b_union_schema : core_schema .UnionSchema ) -> core_schema .UnionSchema :
826+ return core_schema .union_schema (
827+ [
828+ model_a_b_union_schema ,
829+ core_schema .union_schema (
830+ [
831+ core_schema .model_schema (
832+ cls = ModelCat ,
833+ schema = core_schema .model_fields_schema (
834+ fields = {
835+ 'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
836+ },
837+ ),
838+ ),
839+ core_schema .model_schema (
840+ cls = ModelDog ,
841+ schema = core_schema .model_fields_schema (
842+ fields = {
843+ 'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
844+ },
845+ ),
846+ ),
847+ ]
848+ ),
849+ ]
850+ )
851+
852+
853+ @pytest .mark .parametrize (
854+ 'input,expected' ,
855+ [
856+ (ModelA (a = 'a' , b = 'b' ), {'a' : 'a' , 'b' : 'b' }),
857+ (ModelB (c = 'c' , d = 'd' ), {'c' : 'c' , 'd' : 'd' }),
858+ (ModelCat (type_ = 'cat' ), {'type_' : 'cat' }),
859+ (ModelDog (type_ = 'dog' ), {'type_' : 'dog' }),
860+ ],
861+ )
862+ def test_union_of_unions_of_models (union_of_unions_schema : core_schema .UnionSchema , input : Any , expected : Any ) -> None :
863+ s = SchemaSerializer (union_of_unions_schema )
864+ assert s .to_python (input , warnings = 'error' ) == expected
865+
866+
867+ def test_union_of_unions_of_models_invalid_variant (union_of_unions_schema : core_schema .UnionSchema ) -> None :
868+ s = SchemaSerializer (union_of_unions_schema )
869+ # All warnings should be available
870+ messages = [
871+ 'Expected `ModelA` but got `ModelAlien`' ,
872+ 'Expected `ModelB` but got `ModelAlien`' ,
873+ 'Expected `ModelCat` but got `ModelAlien`' ,
874+ 'Expected `ModelDog` but got `ModelAlien`' ,
875+ ]
876+
877+ with warnings .catch_warnings (record = True ) as w :
878+ warnings .simplefilter ('always' )
879+ s .to_python (ModelAlien (type_ = 'alien' ))
880+ for m in messages :
881+ assert m in str (w [0 ].message )
882+
883+
884+ @pytest .fixture
885+ def tagged_union_of_unions_schema (model_a_b_union_schema : core_schema .UnionSchema ) -> core_schema .UnionSchema :
886+ return core_schema .union_schema (
887+ [
888+ model_a_b_union_schema ,
889+ core_schema .tagged_union_schema (
890+ discriminator = 'type_' ,
891+ choices = {
892+ 'cat' : core_schema .model_schema (
893+ cls = ModelCat ,
894+ schema = core_schema .model_fields_schema (
895+ fields = {
896+ 'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
897+ },
898+ ),
899+ ),
900+ 'dog' : core_schema .model_schema (
901+ cls = ModelDog ,
902+ schema = core_schema .model_fields_schema (
903+ fields = {
904+ 'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
905+ },
906+ ),
907+ ),
908+ },
909+ ),
910+ ]
911+ )
912+
913+
914+ @pytest .mark .parametrize (
915+ 'input,expected' ,
916+ [
917+ (ModelA (a = 'a' , b = 'b' ), {'a' : 'a' , 'b' : 'b' }),
918+ (ModelB (c = 'c' , d = 'd' ), {'c' : 'c' , 'd' : 'd' }),
919+ (ModelCat (type_ = 'cat' ), {'type_' : 'cat' }),
920+ (ModelDog (type_ = 'dog' ), {'type_' : 'dog' }),
921+ ],
922+ )
923+ def test_union_of_unions_of_models_with_tagged_union (
924+ tagged_union_of_unions_schema : core_schema .UnionSchema , input : Any , expected : Any
925+ ) -> None :
926+ s = SchemaSerializer (tagged_union_of_unions_schema )
927+ assert s .to_python (input , warnings = 'error' ) == expected
928+
929+
930+ def test_union_of_unions_of_models_with_tagged_union_invalid_variant (
931+ tagged_union_of_unions_schema : core_schema .UnionSchema ,
932+ ) -> None :
933+ s = SchemaSerializer (tagged_union_of_unions_schema )
934+ # All warnings should be available
935+ messages = [
936+ 'Expected `ModelA` but got `ModelAlien`' ,
937+ 'Expected `ModelB` but got `ModelAlien`' ,
938+ 'Expected `ModelCat` but got `ModelAlien`' ,
939+ 'Expected `ModelDog` but got `ModelAlien`' ,
940+ ]
941+
942+ with warnings .catch_warnings (record = True ) as w :
943+ warnings .simplefilter ('always' )
944+ s .to_python (ModelAlien (type_ = 'alien' ))
945+ for m in messages :
946+ assert m in str (w [0 ].message )
947+
948+
949+ @dataclasses .dataclass (frozen = True )
950+ class DataClassA :
951+ a : str
952+
953+
954+ @dataclasses .dataclass (frozen = True )
955+ class DataClassB :
956+ b : str
957+
958+
959+ @pytest .mark .parametrize (
960+ 'input,expected' ,
961+ [
962+ ({True : '1' }, b'{"true":"1"}' ),
963+ ({1 : '1' }, b'{"1":"1"}' ),
964+ ({2.3 : '1' }, b'{"2.3":"1"}' ),
965+ ({'a' : 'b' }, b'{"a":"b"}' ),
966+ ],
967+ )
968+ def test_union_of_unions_of_models_with_tagged_union_json_key_serialization (
969+ input : bool | int | float | str , expected : bytes
970+ ) -> None :
971+ s = SchemaSerializer (
972+ core_schema .dict_schema (
973+ keys_schema = core_schema .union_schema (
974+ [
975+ core_schema .union_schema ([core_schema .bool_schema (), core_schema .int_schema ()]),
976+ core_schema .union_schema ([core_schema .float_schema (), core_schema .str_schema ()]),
977+ ]
978+ ),
979+ values_schema = core_schema .str_schema (),
980+ )
981+ )
982+
983+ assert s .to_json (input , warnings = 'error' ) == expected
984+
985+
986+ def test_union_of_unions_of_models_with_tagged_union_json_serialization_invalid_variant (
987+ tagged_union_of_unions_schema : core_schema .UnionSchema ,
988+ ) -> None :
989+ s = SchemaSerializer (
990+ core_schema .dict_schema (keys_schema = tagged_union_of_unions_schema , values_schema = core_schema .str_schema ())
991+ )
992+
993+ # All warnings should be available
994+ messages = [
995+ 'Expected `ModelA` but got `ModelAlien`' ,
996+ 'Expected `ModelB` but got `ModelAlien`' ,
997+ 'Expected `ModelCat` but got `ModelAlien`' ,
998+ 'Expected `ModelDog` but got `ModelAlien`' ,
999+ ]
1000+
1001+ with warnings .catch_warnings (record = True ) as w :
1002+ warnings .simplefilter ('always' )
1003+ s .to_python ({ModelAlien (type_ = 'alien' ): 'coucou' })
1004+ for m in messages :
1005+ assert m in str (w [0 ].message )
0 commit comments