@@ -141,8 +141,10 @@ class FieldMetadata:
141
141
number : int
142
142
# Protobuf type name
143
143
proto_type : str
144
+
144
145
# Map information if the proto_type is a map
145
- map_types : tuple [str , str ] | None = None
146
+ map_meta : tuple [FieldMetadata , FieldMetadata ] | None = None
147
+
146
148
# Groups several "one-of" fields together
147
149
group : str | None = None
148
150
@@ -160,12 +162,24 @@ def get(field: dataclasses.Field) -> FieldMetadata:
160
162
return field .metadata ["betterproto" ]
161
163
162
164
165
+ def map_meta (
166
+ proto_type_1 : str ,
167
+ proto_type_2 : str ,
168
+ * ,
169
+ unwrap_2 : Callable [[], type ] | None = None ,
170
+ ) -> tuple [FieldMetadata , FieldMetadata ]:
171
+ key_meta = FieldMetadata (1 , proto_type_1 )
172
+ value_meta = FieldMetadata (2 , proto_type_2 , unwrap = unwrap_2 )
173
+
174
+ return key_meta , value_meta
175
+
176
+
163
177
def field (
164
178
number : int ,
165
179
proto_type : str ,
166
180
* ,
167
181
default_factory : Callable [[], Any ] | None = None ,
168
- map_types : tuple [str , str ] | None = None ,
182
+ map_meta : tuple [FieldMetadata , FieldMetadata ] | None = None ,
169
183
group : str | None = None ,
170
184
unwrap : Callable [[], type ] | None = None ,
171
185
optional : bool = False ,
@@ -202,7 +216,7 @@ def field(
202
216
203
217
return dataclasses .field (
204
218
default_factory = default_factory ,
205
- metadata = {"betterproto" : FieldMetadata (number , proto_type , map_types , group , unwrap , optional , repeated )},
219
+ metadata = {"betterproto" : FieldMetadata (number , proto_type , map_meta , group , unwrap , optional , repeated )},
206
220
)
207
221
208
222
@@ -485,7 +499,7 @@ def _get_cls_by_field(cls: type[Message], fields: Iterable[dataclasses.Field]) -
485
499
for field_ in fields :
486
500
meta = FieldMetadata .get (field_ )
487
501
if meta .proto_type == TYPE_MAP :
488
- assert meta .map_types
502
+ assert meta .map_meta
489
503
kt = cls ._cls_for (field_ , index = 0 )
490
504
vt = cls ._cls_for (field_ , index = 1 )
491
505
field_cls [field_ .name ] = dataclasses .make_dataclass (
@@ -494,12 +508,12 @@ def _get_cls_by_field(cls: type[Message], fields: Iterable[dataclasses.Field]) -
494
508
(
495
509
"key" ,
496
510
kt ,
497
- field (1 , meta .map_types [0 ], default_factory = kt ),
511
+ field (1 , meta .map_meta [0 ]. proto_type , default_factory = kt ),
498
512
),
499
513
(
500
514
"value" ,
501
515
vt ,
502
- field (2 , meta .map_types [1 ], default_factory = vt ),
516
+ field (2 , meta .map_meta [1 ]. proto_type , default_factory = vt ),
503
517
),
504
518
],
505
519
bases = (Message ,),
@@ -720,9 +734,9 @@ def __bytes__(self) -> bytes:
720
734
721
735
elif isinstance (value , dict ):
722
736
for k , v in value .items ():
723
- assert meta .map_types
724
- sk = _serialize_single (1 , meta .map_types [0 ], k )
725
- sv = _serialize_single (2 , meta .map_types [1 ], v )
737
+ assert meta .map_meta
738
+ sk = _serialize_single (1 , meta .map_meta [0 ]. proto_type , k )
739
+ sv = _serialize_single (2 , meta .map_meta [1 ]. proto_type , v , unwrap = meta . map_meta [ 1 ]. unwrap )
726
740
stream .write (_serialize_single (meta .number , meta .proto_type , sk + sv ))
727
741
else :
728
742
stream .write (
@@ -1007,13 +1021,12 @@ def to_dict(
1007
1021
output [cased_name ] = output_value
1008
1022
1009
1023
elif meta .proto_type == TYPE_MAP :
1010
- assert meta .map_types is not None
1024
+ assert meta .map_meta is not None
1011
1025
field_type_k = field_types [field_name ].__args__ [0 ]
1012
1026
field_type_v = field_types [field_name ].__args__ [1 ]
1013
- # TODO wrapped types don't work in maps
1014
1027
output_map = {
1015
- _value_to_dict (k , meta .map_types [0 ], field_type_k , None , ** kwargs )[0 ]: _value_to_dict (
1016
- v , meta .map_types [1 ], field_type_v , None , ** kwargs
1028
+ _value_to_dict (k , meta .map_meta [0 ]. proto_type , field_type_k , None , ** kwargs )[0 ]: _value_to_dict (
1029
+ v , meta .map_meta [1 ]. proto_type , field_type_v , meta . map_meta [ 1 ]. unwrap , ** kwargs
1017
1030
)[0 ]
1018
1031
for k , v in value .items ()
1019
1032
}
@@ -1058,7 +1071,7 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
1058
1071
value , meta .proto_type , cls ._betterproto .cls_by_field [field_name ], meta .unwrap
1059
1072
)
1060
1073
1061
- elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
1074
+ elif meta .map_meta and meta .map_meta [1 ]. proto_type == TYPE_MESSAGE :
1062
1075
sub_cls = cls ._betterproto .cls_by_field [f"{ field_name } .value" ]
1063
1076
value = {k : sub_cls .from_dict (v ) for k , v in value .items ()}
1064
1077
else :
@@ -1209,7 +1222,7 @@ def from_pydict(self: T, value: Mapping[str, Any]) -> T:
1209
1222
v = value [key ]
1210
1223
else :
1211
1224
v = cls ().from_pydict (value [key ])
1212
- elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
1225
+ elif meta .map_meta and meta .map_meta [1 ]. proto_type == TYPE_MESSAGE :
1213
1226
v = getattr (self , field_name )
1214
1227
cls = self ._betterproto .cls_by_field [f"{ field_name } .value" ]
1215
1228
for k in value [key ]:
0 commit comments