13
13
Type ,
14
14
Iterable ,
15
15
TypeVar ,
16
+ Optional ,
16
17
)
17
18
import dataclasses
18
19
36
37
TYPE_STRING = "string"
37
38
TYPE_BYTES = "bytes"
38
39
TYPE_MESSAGE = "message"
40
+ TYPE_MAP = "map"
39
41
40
42
41
43
# Fields that use a fixed amount of space (4 or 8 bytes)
87
89
88
90
WIRE_FIXED_32_TYPES = [TYPE_FLOAT , TYPE_FIXED32 , TYPE_SFIXED32 ]
89
91
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE , TYPE_FIXED64 , TYPE_SFIXED64 ]
90
- WIRE_LEN_DELIM_TYPES = [TYPE_STRING , TYPE_BYTES , TYPE_MESSAGE ]
92
+ WIRE_LEN_DELIM_TYPES = [TYPE_STRING , TYPE_BYTES , TYPE_MESSAGE , TYPE_MAP ]
91
93
92
94
93
95
@dataclasses .dataclass (frozen = True )
@@ -98,6 +100,8 @@ class FieldMetadata:
98
100
number : int
99
101
# Protobuf type name
100
102
proto_type : str
103
+ # Map information if the proto_type is a map
104
+ map_types : Optional [Tuple [str , str ]]
101
105
# Default value if given
102
106
default : Any
103
107
@@ -107,10 +111,14 @@ def get(field: dataclasses.Field) -> "FieldMetadata":
107
111
return field .metadata ["betterproto" ]
108
112
109
113
110
- def field (number : int , proto_type : str , default : Any ) -> dataclasses .Field :
114
+ def dataclass_field (
115
+ number : int ,
116
+ proto_type : str ,
117
+ default : Any ,
118
+ map_types : Optional [Tuple [str , str ]] = None ,
119
+ ** kwargs : dict ,
120
+ ) -> dataclasses .Field :
111
121
"""Creates a dataclass field with attached protobuf metadata."""
112
- kwargs = {}
113
-
114
122
if callable (default ):
115
123
kwargs ["default_factory" ] = default
116
124
elif isinstance (default , dict ) or isinstance (default , list ):
@@ -119,7 +127,8 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
119
127
kwargs ["default" ] = default
120
128
121
129
return dataclasses .field (
122
- ** kwargs , metadata = {"betterproto" : FieldMetadata (number , proto_type , default )}
130
+ ** kwargs ,
131
+ metadata = {"betterproto" : FieldMetadata (number , proto_type , map_types , default )},
123
132
)
124
133
125
134
@@ -129,63 +138,69 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
129
138
130
139
131
140
def enum_field (number : int , default : Union [int , Type [Iterable ]] = 0 ) -> Any :
132
- return field (number , TYPE_ENUM , default = default )
141
+ return dataclass_field (number , TYPE_ENUM , default = default )
133
142
134
143
135
144
def int32_field (number : int , default : Union [int , Type [Iterable ]] = 0 ) -> Any :
136
- return field (number , TYPE_INT32 , default = default )
145
+ return dataclass_field (number , TYPE_INT32 , default = default )
137
146
138
147
139
148
def int64_field (number : int , default : int = 0 ) -> Any :
140
- return field (number , TYPE_INT64 , default = default )
149
+ return dataclass_field (number , TYPE_INT64 , default = default )
141
150
142
151
143
152
def uint32_field (number : int , default : int = 0 ) -> Any :
144
- return field (number , TYPE_UINT32 , default = default )
153
+ return dataclass_field (number , TYPE_UINT32 , default = default )
145
154
146
155
147
156
def uint64_field (number : int , default : int = 0 ) -> Any :
148
- return field (number , TYPE_UINT64 , default = default )
157
+ return dataclass_field (number , TYPE_UINT64 , default = default )
149
158
150
159
151
160
def sint32_field (number : int , default : int = 0 ) -> Any :
152
- return field (number , TYPE_SINT32 , default = default )
161
+ return dataclass_field (number , TYPE_SINT32 , default = default )
153
162
154
163
155
164
def sint64_field (number : int , default : int = 0 ) -> Any :
156
- return field (number , TYPE_SINT64 , default = default )
165
+ return dataclass_field (number , TYPE_SINT64 , default = default )
157
166
158
167
159
168
def float_field (number : int , default : float = 0.0 ) -> Any :
160
- return field (number , TYPE_FLOAT , default = default )
169
+ return dataclass_field (number , TYPE_FLOAT , default = default )
161
170
162
171
163
172
def double_field (number : int , default : float = 0.0 ) -> Any :
164
- return field (number , TYPE_DOUBLE , default = default )
173
+ return dataclass_field (number , TYPE_DOUBLE , default = default )
165
174
166
175
167
176
def fixed32_field (number : int , default : float = 0.0 ) -> Any :
168
- return field (number , TYPE_FIXED32 , default = default )
177
+ return dataclass_field (number , TYPE_FIXED32 , default = default )
169
178
170
179
171
180
def fixed64_field (number : int , default : float = 0.0 ) -> Any :
172
- return field (number , TYPE_FIXED64 , default = default )
181
+ return dataclass_field (number , TYPE_FIXED64 , default = default )
173
182
174
183
175
184
def sfixed32_field (number : int , default : float = 0.0 ) -> Any :
176
- return field (number , TYPE_SFIXED32 , default = default )
185
+ return dataclass_field (number , TYPE_SFIXED32 , default = default )
177
186
178
187
179
188
def sfixed64_field (number : int , default : float = 0.0 ) -> Any :
180
- return field (number , TYPE_SFIXED64 , default = default )
189
+ return dataclass_field (number , TYPE_SFIXED64 , default = default )
181
190
182
191
183
192
def string_field (number : int , default : str = "" ) -> Any :
184
- return field (number , TYPE_STRING , default = default )
193
+ return dataclass_field (number , TYPE_STRING , default = default )
185
194
186
195
187
196
def message_field (number : int , default : Type ["Message" ]) -> Any :
188
- return field (number , TYPE_MESSAGE , default = default )
197
+ return dataclass_field (number , TYPE_MESSAGE , default = default )
198
+
199
+
200
+ def map_field (number : int , key_type : str , value_type : str ) -> Any :
201
+ return dataclass_field (
202
+ number , TYPE_MAP , default = dict , map_types = (key_type , value_type )
203
+ )
189
204
190
205
191
206
def _pack_fmt (proto_type : str ) -> str :
@@ -354,6 +369,14 @@ def __bytes__(self) -> bytes:
354
369
else :
355
370
for item in value :
356
371
output += _serialize_single (meta .number , meta .proto_type , item )
372
+ elif isinstance (value , dict ):
373
+ if not len (value ):
374
+ continue
375
+
376
+ for k , v in value .items ():
377
+ sk = _serialize_single (1 , meta .map_types [0 ], k )
378
+ sv = _serialize_single (2 , meta .map_types [1 ], v )
379
+ output += _serialize_single (meta .number , meta .proto_type , sk + sv )
357
380
else :
358
381
if value == field .default :
359
382
continue
@@ -377,23 +400,35 @@ def _postprocess_single(
377
400
) -> Any :
378
401
"""Adjusts values after parsing."""
379
402
if wire_type == WIRE_VARINT :
380
- if meta .proto_type in ["int32" , "int64" ]:
403
+ if meta .proto_type in [TYPE_INT32 , TYPE_INT64 ]:
381
404
bits = int (meta .proto_type [3 :])
382
405
value = value & ((1 << bits ) - 1 )
383
406
signbit = 1 << (bits - 1 )
384
407
value = int ((value ^ signbit ) - signbit )
385
- elif meta .proto_type in ["sint32" , "sint64" ]:
408
+ elif meta .proto_type in [TYPE_SINT32 , TYPE_SINT64 ]:
386
409
# Undo zig-zag encoding
387
410
value = (value >> 1 ) ^ (- (value & 1 ))
388
411
elif wire_type in [WIRE_FIXED_32 , WIRE_FIXED_64 ]:
389
412
fmt = _pack_fmt (meta .proto_type )
390
413
value = struct .unpack (fmt , value )[0 ]
391
414
elif wire_type == WIRE_LEN_DELIM :
392
- if meta .proto_type in ["string" ]:
415
+ if meta .proto_type in [TYPE_STRING ]:
393
416
value = value .decode ("utf-8" )
394
- elif meta .proto_type in ["message" ]:
417
+ elif meta .proto_type in [TYPE_MESSAGE ]:
395
418
cls = self ._cls_for (field )
396
419
value = cls ().parse (value )
420
+ elif meta .proto_type in [TYPE_MAP ]:
421
+ # TODO: This is slow, use a cache to make it faster since each
422
+ # key/value pair will recreate the class.
423
+ Entry = dataclasses .make_dataclass (
424
+ "Entry" ,
425
+ [
426
+ ("key" , Any , dataclass_field (1 , meta .map_types [0 ], None )),
427
+ ("value" , Any , dataclass_field (2 , meta .map_types [1 ], None )),
428
+ ],
429
+ bases = (Message ,),
430
+ )
431
+ value = Entry ().parse (value )
397
432
398
433
return value
399
434
@@ -434,10 +469,12 @@ def parse(self, data: bytes) -> T:
434
469
parsed .wire_type , meta , field , parsed .value
435
470
)
436
471
437
- if isinstance (getattr (self , field .name ), list ) and not isinstance (
438
- value , list
439
- ):
440
- getattr (self , field .name ).append (value )
472
+ current = getattr (self , field .name )
473
+ if meta .proto_type == TYPE_MAP :
474
+ # Value represents a single key/value pair entry in the map.
475
+ current [value .key ] = value .value
476
+ elif isinstance (current , list ) and not isinstance (value , list ):
477
+ current .append (value )
441
478
else :
442
479
setattr (self , field .name , value )
443
480
else :
0 commit comments