Skip to content

Commit 9084534

Browse files
Fix unwrap in maps (#87)
* Fix unwrap in maps * Add test
1 parent 4853b11 commit 9084534

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

src/betterproto2/__init__.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,10 @@ class FieldMetadata:
141141
number: int
142142
# Protobuf type name
143143
proto_type: str
144+
144145
# Map information if the proto_type is a map
145-
map_types: tuple[str, str] | None = None
146+
map_meta: tuple[FieldMetadata, FieldMetadata] | None = None
147+
146148
# Groups several "one-of" fields together
147149
group: str | None = None
148150

@@ -160,12 +162,24 @@ def get(field: dataclasses.Field) -> FieldMetadata:
160162
return field.metadata["betterproto"]
161163

162164

165+
def map_meta(
166+
proto_type_1: str,
167+
proto_type_2: str,
168+
*,
169+
unwrap_2: Callable[[], type] | None = None,
170+
) -> tuple[FieldMetadata, FieldMetadata]:
171+
key_meta = FieldMetadata(1, proto_type_1)
172+
value_meta = FieldMetadata(2, proto_type_2, unwrap=unwrap_2)
173+
174+
return key_meta, value_meta
175+
176+
163177
def field(
164178
number: int,
165179
proto_type: str,
166180
*,
167181
default_factory: Callable[[], Any] | None = None,
168-
map_types: tuple[str, str] | None = None,
182+
map_meta: tuple[FieldMetadata, FieldMetadata] | None = None,
169183
group: str | None = None,
170184
unwrap: Callable[[], type] | None = None,
171185
optional: bool = False,
@@ -202,7 +216,7 @@ def field(
202216

203217
return dataclasses.field(
204218
default_factory=default_factory,
205-
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group, unwrap, optional, repeated)},
219+
metadata={"betterproto": FieldMetadata(number, proto_type, map_meta, group, unwrap, optional, repeated)},
206220
)
207221

208222

@@ -485,7 +499,7 @@ def _get_cls_by_field(cls: type[Message], fields: Iterable[dataclasses.Field]) -
485499
for field_ in fields:
486500
meta = FieldMetadata.get(field_)
487501
if meta.proto_type == TYPE_MAP:
488-
assert meta.map_types
502+
assert meta.map_meta
489503
kt = cls._cls_for(field_, index=0)
490504
vt = cls._cls_for(field_, index=1)
491505
field_cls[field_.name] = dataclasses.make_dataclass(
@@ -494,12 +508,12 @@ def _get_cls_by_field(cls: type[Message], fields: Iterable[dataclasses.Field]) -
494508
(
495509
"key",
496510
kt,
497-
field(1, meta.map_types[0], default_factory=kt),
511+
field(1, meta.map_meta[0].proto_type, default_factory=kt),
498512
),
499513
(
500514
"value",
501515
vt,
502-
field(2, meta.map_types[1], default_factory=vt),
516+
field(2, meta.map_meta[1].proto_type, default_factory=vt),
503517
),
504518
],
505519
bases=(Message,),
@@ -720,9 +734,9 @@ def __bytes__(self) -> bytes:
720734

721735
elif isinstance(value, dict):
722736
for k, v in value.items():
723-
assert meta.map_types
724-
sk = _serialize_single(1, meta.map_types[0], k)
725-
sv = _serialize_single(2, meta.map_types[1], v)
737+
assert meta.map_meta
738+
sk = _serialize_single(1, meta.map_meta[0].proto_type, k)
739+
sv = _serialize_single(2, meta.map_meta[1].proto_type, v, unwrap=meta.map_meta[1].unwrap)
726740
stream.write(_serialize_single(meta.number, meta.proto_type, sk + sv))
727741
else:
728742
stream.write(
@@ -1007,13 +1021,12 @@ def to_dict(
10071021
output[cased_name] = output_value
10081022

10091023
elif meta.proto_type == TYPE_MAP:
1010-
assert meta.map_types is not None
1024+
assert meta.map_meta is not None
10111025
field_type_k = field_types[field_name].__args__[0]
10121026
field_type_v = field_types[field_name].__args__[1]
1013-
# TODO wrapped types don't work in maps
10141027
output_map = {
1015-
_value_to_dict(k, meta.map_types[0], field_type_k, None, **kwargs)[0]: _value_to_dict(
1016-
v, meta.map_types[1], field_type_v, None, **kwargs
1028+
_value_to_dict(k, meta.map_meta[0].proto_type, field_type_k, None, **kwargs)[0]: _value_to_dict(
1029+
v, meta.map_meta[1].proto_type, field_type_v, meta.map_meta[1].unwrap, **kwargs
10171030
)[0]
10181031
for k, v in value.items()
10191032
}
@@ -1058,7 +1071,7 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
10581071
value, meta.proto_type, cls._betterproto.cls_by_field[field_name], meta.unwrap
10591072
)
10601073

1061-
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1074+
elif meta.map_meta and meta.map_meta[1].proto_type == TYPE_MESSAGE:
10621075
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
10631076
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
10641077
else:
@@ -1209,7 +1222,7 @@ def from_pydict(self: T, value: Mapping[str, Any]) -> T:
12091222
v = value[key]
12101223
else:
12111224
v = cls().from_pydict(value[key])
1212-
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1225+
elif meta.map_meta and meta.map_meta[1].proto_type == TYPE_MESSAGE:
12131226
v = getattr(self, field_name)
12141227
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
12151228
for k in value[key]:

tests/test_message_wraping.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import datetime
2+
3+
4+
def test_message_wrapping_map():
5+
from tests.output_betterproto.message_wrapping import MapMessage
6+
7+
msg = MapMessage(map1={"key": 12.0}, map2={"key": datetime.timedelta(seconds=1)})
8+
9+
bytes(msg)
10+
11+
assert msg.to_dict() == {"map1": {"key": 12.0}, "map2": {"key": "1s"}}

0 commit comments

Comments
 (0)