@@ -60,53 +60,36 @@ def __init__(self, c, d):
6060@pytest .fixture (scope = 'module' )
6161def model_serializer () -> SchemaSerializer :
6262 return SchemaSerializer (
63- {
64- 'type' : 'union' ,
65- 'choices' : [
66- {
67- 'type' : 'model' ,
68- 'cls' : ModelA ,
69- 'schema' : {
70- 'type' : 'model-fields' ,
71- 'fields' : {
72- 'a' : {'type' : 'model-field' , 'schema' : {'type' : 'bytes' }},
73- 'b' : {
74- 'type' : 'model-field' ,
75- 'schema' : {
76- 'type' : 'float' ,
77- 'serialization' : {
78- 'type' : 'format' ,
79- 'formatting_string' : '0.1f' ,
80- 'when_used' : 'unless-none' ,
81- },
82- },
83- },
84- },
85- },
86- },
87- {
88- 'type' : 'model' ,
89- 'cls' : ModelB ,
90- 'schema' : {
91- 'type' : 'model-fields' ,
92- 'fields' : {
93- 'c' : {'type' : 'model-field' , 'schema' : {'type' : 'bytes' }},
94- 'd' : {
95- 'type' : 'model-field' ,
96- 'schema' : {
97- 'type' : 'float' ,
98- 'serialization' : {
99- 'type' : 'format' ,
100- 'formatting_string' : '0.2f' ,
101- 'when_used' : 'unless-none' ,
102- },
103- },
104- },
105- },
106- },
107- },
63+ core_schema .union_schema (
64+ [
65+ core_schema .model_schema (
66+ ModelA ,
67+ core_schema .model_fields_schema (
68+ {
69+ 'a' : core_schema .model_field (core_schema .bytes_schema ()),
70+ 'b' : core_schema .model_field (
71+ core_schema .float_schema (
72+ serialization = core_schema .format_ser_schema ('0.1f' , when_used = 'unless-none' )
73+ )
74+ ),
75+ }
76+ ),
77+ ),
78+ core_schema .model_schema (
79+ ModelB ,
80+ core_schema .model_fields_schema (
81+ {
82+ 'c' : core_schema .model_field (core_schema .bytes_schema ()),
83+ 'd' : core_schema .model_field (
84+ core_schema .float_schema (
85+ serialization = core_schema .format_ser_schema ('0.2f' , when_used = 'unless-none' )
86+ )
87+ ),
88+ }
89+ ),
90+ ),
10891 ],
109- }
92+ )
11093 )
11194
11295
@@ -778,3 +761,67 @@ class ModelB:
778761 model_b = ModelB (field = 1 )
779762 assert s .to_python (model_a ) == {'field' : 1 , 'TAG' : 'a' }
780763 assert s .to_python (model_b ) == {'field' : 1 , 'TAG' : 'b' }
764+
765+
766+ def test_union_model_wrap_serializer ():
767+ def wrap_serializer (value , handler ):
768+ return handler (value )
769+
770+ class Data :
771+ pass
772+
773+ class ModelA :
774+ a : Data
775+
776+ class ModelB :
777+ a : Data
778+
779+ model_serializer = SchemaSerializer (
780+ core_schema .union_schema (
781+ [
782+ core_schema .model_schema (
783+ ModelA ,
784+ core_schema .model_fields_schema (
785+ {
786+ 'a' : core_schema .model_field (
787+ core_schema .model_schema (
788+ Data ,
789+ core_schema .model_fields_schema ({}),
790+ )
791+ ),
792+ },
793+ ),
794+ serialization = core_schema .wrap_serializer_function_ser_schema (wrap_serializer ),
795+ ),
796+ core_schema .model_schema (
797+ ModelB ,
798+ core_schema .model_fields_schema (
799+ {
800+ 'a' : core_schema .model_field (
801+ core_schema .model_schema (
802+ Data ,
803+ core_schema .model_fields_schema ({}),
804+ )
805+ ),
806+ },
807+ ),
808+ serialization = core_schema .wrap_serializer_function_ser_schema (wrap_serializer ),
809+ ),
810+ ],
811+ )
812+ )
813+
814+ input_value = ModelA ()
815+ input_value .a = Data ()
816+
817+ assert model_serializer .to_python (input_value ) == {'a' : {}}
818+ assert model_serializer .to_python (input_value , mode = 'json' ) == {'a' : {}}
819+ assert model_serializer .to_json (input_value ) == b'{"a":{}}'
820+
821+ # add some additional attribute, should be ignored & not break serialization
822+
823+ input_value .a ._a = 'foo'
824+
825+ assert model_serializer .to_python (input_value ) == {'a' : {}}
826+ assert model_serializer .to_python (input_value , mode = 'json' ) == {'a' : {}}
827+ assert model_serializer .to_json (input_value ) == b'{"a":{}}'
0 commit comments