Skip to content

Commit 55be5ee

Browse files
Merge pull request #2 from danielgtaylor/map-message
Add support for map value message types
2 parents 32bc8d5 + 7dbaee0 commit 55be5ee

File tree

4 files changed

+61
-9
lines changed

4 files changed

+61
-9
lines changed

README.md

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
1-
# TODO
1+
# Better Protobuf / gRPC Support for Python
2+
3+
This project aims to provide an improved experience when using Protobuf / gRPC in a modern Python environment by making use of modern language features and generating readable, understandable code. It will not support legacy features or environments. The following are supported:
4+
5+
- Protobuf 3 & gRPC code generation
6+
- Both binary & JSON serialization is built-in
7+
- Python 3.7+
8+
- Enums
9+
- Dataclasses
10+
- `async`/`await`
11+
- Relative imports
12+
- Mypy type checking
13+
14+
This project is heavily inspired by, and borrows functionality from:
15+
16+
- https://github.com/eigenein/protobuf/
17+
- https://github.com/vmagamedov/grpclib
18+
19+
## TODO
220

321
- [x] Fixed length fields
422
- [x] Packed fixed-length
@@ -7,11 +25,12 @@
725
- [x] Enums
826
- [x] Repeated message fields
927
- [x] Maps
10-
- [ ] Maps of message fields
28+
- [x] Maps of message fields
1129
- [ ] Support passthrough of unknown fields
1230
- [ ] Refs to nested types
1331
- [ ] Imports in proto files
1432
- [ ] Well-known Google types
1533
- [ ] JSON that isn't completely naive.
1634
- [ ] Async service stubs
35+
- [ ] Python package
1736
- [ ] Cleanup!

betterproto/__init__.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -385,14 +385,13 @@ def __bytes__(self) -> bytes:
385385

386386
return output
387387

388-
def _cls_for(self, field: dataclasses.Field) -> Type:
388+
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
389389
"""Get the message class for a field from the type hints."""
390390
module = inspect.getmodule(self)
391391
type_hints = get_type_hints(self, vars(module))
392392
cls = type_hints[field.name]
393393
if hasattr(cls, "__args__"):
394-
print(type_hints[field.name].__args__[0])
395-
cls = type_hints[field.name].__args__[0]
394+
cls = type_hints[field.name].__args__[index]
396395
return cls
397396

398397
def _postprocess_single(
@@ -420,11 +419,13 @@ def _postprocess_single(
420419
elif meta.proto_type in [TYPE_MAP]:
421420
# TODO: This is slow, use a cache to make it faster since each
422421
# key/value pair will recreate the class.
422+
kt = self._cls_for(field, index=0)
423+
vt = self._cls_for(field, index=1)
423424
Entry = dataclasses.make_dataclass(
424425
"Entry",
425426
[
426-
("key", Any, dataclass_field(1, meta.map_types[0], None)),
427-
("value", Any, dataclass_field(2, meta.map_types[1], None)),
427+
("key", kt, dataclass_field(1, meta.map_types[0], None)),
428+
("value", vt, dataclass_field(2, meta.map_types[1], None)),
428429
],
429430
bases=(Message,),
430431
)
@@ -500,10 +501,18 @@ def to_dict(self) -> dict:
500501
v = [i for i in v if i]
501502
else:
502503
v = v.to_dict()
504+
505+
if v:
506+
output[field.name] = v
507+
elif meta.proto_type == "map":
508+
for k in v:
509+
if hasattr(v[k], "to_dict"):
510+
v[k] = v[k].to_dict()
511+
503512
if v:
504513
output[field.name] = v
505514
elif v != field.default:
506-
output[field.name] = getattr(self, field.name)
515+
output[field.name] = v
507516
return output
508517

509518
def from_dict(self, value: dict) -> T:
@@ -516,13 +525,18 @@ def from_dict(self, value: dict) -> T:
516525
if field.name in value:
517526
if meta.proto_type == "message":
518527
v = getattr(self, field.name)
519-
print(v, value[field.name])
528+
# print(v, value[field.name])
520529
if isinstance(v, list):
521530
cls = self._cls_for(field)
522531
for i in range(len(value[field.name])):
523532
v.append(cls().from_dict(value[field.name][i]))
524533
else:
525534
v.from_dict(value[field.name])
535+
elif meta.proto_type == "map" and meta.map_types[1] == TYPE_MESSAGE:
536+
v = getattr(self, field.name)
537+
cls = self._cls_for(field, index=1)
538+
for k in value[field.name]:
539+
v[k] = cls().from_dict(value[field.name][k])
526540
else:
527541
setattr(self, field.name, value[field.name])
528542
return self

betterproto/tests/mapmessage.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"items": {
3+
"foo": {
4+
"count": 1
5+
},
6+
"bar": {
7+
"count": 2
8+
}
9+
}
10+
}

betterproto/tests/mapmessage.proto

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
syntax = "proto3";
2+
3+
message Test {
4+
map<string, Nested> items = 1;
5+
}
6+
7+
message Nested {
8+
int32 count = 1;
9+
}

0 commit comments

Comments
 (0)