Skip to content

Commit 32bc8d5

Browse files
Merge pull request #1 from danielgtaylor/maps
Add basic support for maps
2 parents ad7162a + e0d1611 commit 32bc8d5

File tree

7 files changed

+116
-36
lines changed

7 files changed

+116
-36
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,8 @@
66
- [x] Don't encode zero values for nested types
77
- [x] Enums
88
- [x] Repeated message fields
9-
- [ ] Maps
9+
- [x] Maps
10+
- [ ] Maps of message fields
1011
- [ ] Support passthrough of unknown fields
1112
- [ ] Refs to nested types
1213
- [ ] Imports in proto files

betterproto/__init__.py

Lines changed: 65 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
Type,
1414
Iterable,
1515
TypeVar,
16+
Optional,
1617
)
1718
import dataclasses
1819

@@ -36,6 +37,7 @@
3637
TYPE_STRING = "string"
3738
TYPE_BYTES = "bytes"
3839
TYPE_MESSAGE = "message"
40+
TYPE_MAP = "map"
3941

4042

4143
# Fields that use a fixed amount of space (4 or 8 bytes)
@@ -87,7 +89,7 @@
8789

8890
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
8991
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
90-
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE]
92+
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
9193

9294

9395
@dataclasses.dataclass(frozen=True)
@@ -98,6 +100,8 @@ class FieldMetadata:
98100
number: int
99101
# Protobuf type name
100102
proto_type: str
103+
# Map information if the proto_type is a map
104+
map_types: Optional[Tuple[str, str]]
101105
# Default value if given
102106
default: Any
103107

@@ -107,10 +111,14 @@ def get(field: dataclasses.Field) -> "FieldMetadata":
107111
return field.metadata["betterproto"]
108112

109113

110-
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
114+
def dataclass_field(
115+
number: int,
116+
proto_type: str,
117+
default: Any,
118+
map_types: Optional[Tuple[str, str]] = None,
119+
**kwargs: dict,
120+
) -> dataclasses.Field:
111121
"""Creates a dataclass field with attached protobuf metadata."""
112-
kwargs = {}
113-
114122
if callable(default):
115123
kwargs["default_factory"] = default
116124
elif isinstance(default, dict) or isinstance(default, list):
@@ -119,7 +127,8 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
119127
kwargs["default"] = default
120128

121129
return dataclasses.field(
122-
**kwargs, metadata={"betterproto": FieldMetadata(number, proto_type, default)}
130+
**kwargs,
131+
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)},
123132
)
124133

125134

@@ -129,63 +138,69 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
129138

130139

131140
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
132-
return field(number, TYPE_ENUM, default=default)
141+
return dataclass_field(number, TYPE_ENUM, default=default)
133142

134143

135144
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
136-
return field(number, TYPE_INT32, default=default)
145+
return dataclass_field(number, TYPE_INT32, default=default)
137146

138147

139148
def int64_field(number: int, default: int = 0) -> Any:
140-
return field(number, TYPE_INT64, default=default)
149+
return dataclass_field(number, TYPE_INT64, default=default)
141150

142151

143152
def uint32_field(number: int, default: int = 0) -> Any:
144-
return field(number, TYPE_UINT32, default=default)
153+
return dataclass_field(number, TYPE_UINT32, default=default)
145154

146155

147156
def uint64_field(number: int, default: int = 0) -> Any:
148-
return field(number, TYPE_UINT64, default=default)
157+
return dataclass_field(number, TYPE_UINT64, default=default)
149158

150159

151160
def sint32_field(number: int, default: int = 0) -> Any:
152-
return field(number, TYPE_SINT32, default=default)
161+
return dataclass_field(number, TYPE_SINT32, default=default)
153162

154163

155164
def sint64_field(number: int, default: int = 0) -> Any:
156-
return field(number, TYPE_SINT64, default=default)
165+
return dataclass_field(number, TYPE_SINT64, default=default)
157166

158167

159168
def float_field(number: int, default: float = 0.0) -> Any:
160-
return field(number, TYPE_FLOAT, default=default)
169+
return dataclass_field(number, TYPE_FLOAT, default=default)
161170

162171

163172
def double_field(number: int, default: float = 0.0) -> Any:
164-
return field(number, TYPE_DOUBLE, default=default)
173+
return dataclass_field(number, TYPE_DOUBLE, default=default)
165174

166175

167176
def fixed32_field(number: int, default: float = 0.0) -> Any:
168-
return field(number, TYPE_FIXED32, default=default)
177+
return dataclass_field(number, TYPE_FIXED32, default=default)
169178

170179

171180
def fixed64_field(number: int, default: float = 0.0) -> Any:
172-
return field(number, TYPE_FIXED64, default=default)
181+
return dataclass_field(number, TYPE_FIXED64, default=default)
173182

174183

175184
def sfixed32_field(number: int, default: float = 0.0) -> Any:
176-
return field(number, TYPE_SFIXED32, default=default)
185+
return dataclass_field(number, TYPE_SFIXED32, default=default)
177186

178187

179188
def sfixed64_field(number: int, default: float = 0.0) -> Any:
180-
return field(number, TYPE_SFIXED64, default=default)
189+
return dataclass_field(number, TYPE_SFIXED64, default=default)
181190

182191

183192
def string_field(number: int, default: str = "") -> Any:
184-
return field(number, TYPE_STRING, default=default)
193+
return dataclass_field(number, TYPE_STRING, default=default)
185194

186195

187196
def message_field(number: int, default: Type["Message"]) -> Any:
188-
return field(number, TYPE_MESSAGE, default=default)
197+
return dataclass_field(number, TYPE_MESSAGE, default=default)
198+
199+
200+
def map_field(number: int, key_type: str, value_type: str) -> Any:
201+
return dataclass_field(
202+
number, TYPE_MAP, default=dict, map_types=(key_type, value_type)
203+
)
189204

190205

191206
def _pack_fmt(proto_type: str) -> str:
@@ -354,6 +369,14 @@ def __bytes__(self) -> bytes:
354369
else:
355370
for item in value:
356371
output += _serialize_single(meta.number, meta.proto_type, item)
372+
elif isinstance(value, dict):
373+
if not len(value):
374+
continue
375+
376+
for k, v in value.items():
377+
sk = _serialize_single(1, meta.map_types[0], k)
378+
sv = _serialize_single(2, meta.map_types[1], v)
379+
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
357380
else:
358381
if value == field.default:
359382
continue
@@ -377,23 +400,35 @@ def _postprocess_single(
377400
) -> Any:
378401
"""Adjusts values after parsing."""
379402
if wire_type == WIRE_VARINT:
380-
if meta.proto_type in ["int32", "int64"]:
403+
if meta.proto_type in [TYPE_INT32, TYPE_INT64]:
381404
bits = int(meta.proto_type[3:])
382405
value = value & ((1 << bits) - 1)
383406
signbit = 1 << (bits - 1)
384407
value = int((value ^ signbit) - signbit)
385-
elif meta.proto_type in ["sint32", "sint64"]:
408+
elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
386409
# Undo zig-zag encoding
387410
value = (value >> 1) ^ (-(value & 1))
388411
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
389412
fmt = _pack_fmt(meta.proto_type)
390413
value = struct.unpack(fmt, value)[0]
391414
elif wire_type == WIRE_LEN_DELIM:
392-
if meta.proto_type in ["string"]:
415+
if meta.proto_type in [TYPE_STRING]:
393416
value = value.decode("utf-8")
394-
elif meta.proto_type in ["message"]:
417+
elif meta.proto_type in [TYPE_MESSAGE]:
395418
cls = self._cls_for(field)
396419
value = cls().parse(value)
420+
elif meta.proto_type in [TYPE_MAP]:
421+
# TODO: This is slow, use a cache to make it faster since each
422+
# key/value pair will recreate the class.
423+
Entry = dataclasses.make_dataclass(
424+
"Entry",
425+
[
426+
("key", Any, dataclass_field(1, meta.map_types[0], None)),
427+
("value", Any, dataclass_field(2, meta.map_types[1], None)),
428+
],
429+
bases=(Message,),
430+
)
431+
value = Entry().parse(value)
397432

398433
return value
399434

@@ -434,10 +469,12 @@ def parse(self, data: bytes) -> T:
434469
parsed.wire_type, meta, field, parsed.value
435470
)
436471

437-
if isinstance(getattr(self, field.name), list) and not isinstance(
438-
value, list
439-
):
440-
getattr(self, field.name).append(value)
472+
current = getattr(self, field.name)
473+
if meta.proto_type == TYPE_MAP:
474+
# Value represents a single key/value pair entry in the map.
475+
current[value.key] = value.value
476+
elif isinstance(current, list) and not isinstance(value, list):
477+
current.append(value)
441478
else:
442479
setattr(self, field.name, value)
443480
else:

betterproto/templates/main.py

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

betterproto/tests/generate.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ def ensure_ext(filename: str, ext: str) -> str:
4848
json_files = get_files(".json")
4949

5050
for filename in proto_files:
51-
print(f"Generatinng code for {os.path.basename(filename)}")
51+
print(f"Generating code for {os.path.basename(filename)}")
5252
subprocess.run(
5353
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
5454
)

betterproto/tests/map.json

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
{
2+
"counts": {
3+
"item1": 1,
4+
"item2": 2,
5+
"item3": 3
6+
}
7+
}

betterproto/tests/map.proto

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
syntax = "proto3";
2+
3+
message Test {
4+
map<string, int32> counts = 1;
5+
}

protoc-gen-betterpy.py

Lines changed: 34 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
DescriptorProto,
1313
EnumDescriptorProto,
1414
FileDescriptorProto,
15+
FieldDescriptorProto,
1516
)
1617

1718
from google.protobuf.compiler import plugin_pb2 as plugin
@@ -20,7 +21,9 @@
2021
from jinja2 import Environment, PackageLoader
2122

2223

23-
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
24+
def py_type(
25+
message: DescriptorProto, descriptor: FieldDescriptorProto
26+
) -> Tuple[str, str]:
2427
if descriptor.type in [1, 2, 6, 7, 15, 16]:
2528
return "float", descriptor.default_value
2629
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
@@ -115,6 +118,10 @@ def generate_code(request, response):
115118

116119
if isinstance(item, DescriptorProto):
117120
# print(item, file=sys.stderr)
121+
if item.options.map_entry:
122+
# Skip generated map entry messages since we just use dicts
123+
continue
124+
118125
data.update(
119126
{
120127
"type": "Message",
@@ -124,11 +131,33 @@ def generate_code(request, response):
124131
)
125132

126133
for i, f in enumerate(item.field):
127-
t, zero = py_type(f)
134+
t, zero = py_type(item, f)
128135
repeated = False
129136
packed = False
130137

131-
if f.label == 3:
138+
field_type = f.Type.Name(f.type).lower()[5:]
139+
map_types = None
140+
if f.type == 11:
141+
# This might be a map...
142+
message_type = f.type_name.split(".").pop()
143+
map_entry = f"{f.name.capitalize()}Entry"
144+
145+
if message_type == map_entry:
146+
for nested in item.nested_type:
147+
if nested.name == map_entry:
148+
if nested.options.map_entry:
149+
print("Found a map!", file=sys.stderr)
150+
k, _ = py_type(item, nested.field[0])
151+
v, _ = py_type(item, nested.field[1])
152+
t = f"Dict[{k}, {v}]"
153+
zero = "dict"
154+
field_type = "map"
155+
map_types = (
156+
f.Type.Name(nested.field[0].type),
157+
f.Type.Name(nested.field[1].type),
158+
)
159+
160+
if f.label == 3 and field_type != "map":
132161
# Repeated field
133162
repeated = True
134163
t = f"List[{t}]"
@@ -143,7 +172,8 @@ def generate_code(request, response):
143172
"number": f.number,
144173
"comment": get_comment(proto_file, path + [2, i]),
145174
"proto_type": int(f.type),
146-
"field_type": f.Type.Name(f.type).lower()[5:],
175+
"field_type": field_type,
176+
"map_types": map_types,
147177
"type": t,
148178
"zero": zero,
149179
"repeated": repeated,

0 commit comments

Comments
 (0)