Skip to content

Commit f4e2fa2

Browse files
Fix maps (#97)
1 parent 4cf2e9e commit f4e2fa2

File tree

1 file changed

+46
-31
lines changed

1 file changed

+46
-31
lines changed

src/betterproto2/__init__.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -503,19 +503,17 @@ def _get_cls_by_field(cls: type[Message], fields: Iterable[dataclasses.Field]) -
503503
assert meta.map_meta
504504
kt = cls._cls_for(field_, index=0)
505505
vt = cls._cls_for(field_, index=1)
506+
507+
if meta.map_meta[1].proto_type == TYPE_ENUM:
508+
value_field = field(2, meta.map_meta[1].proto_type, default_factory=lambda: vt(0))
509+
else:
510+
value_field = field(2, meta.map_meta[1].proto_type, unwrap=meta.map_meta[1].unwrap)
511+
506512
field_cls[field_.name] = dataclasses.make_dataclass(
507513
"Entry",
508514
[
509-
(
510-
"key",
511-
kt,
512-
field(1, meta.map_meta[0].proto_type, default_factory=kt),
513-
),
514-
(
515-
"value",
516-
vt,
517-
field(2, meta.map_meta[1].proto_type, default_factory=vt),
518-
),
515+
("key", kt, field(1, meta.map_meta[0].proto_type)),
516+
("value", vt, value_field),
519517
],
520518
bases=(Message,),
521519
)
@@ -582,8 +580,6 @@ def _value_to_dict(
582580

583581

584582
def _value_from_dict(value: Any, meta: FieldMetadata, field_type: type) -> Any:
585-
# TODO directly pass `meta` when available for maps
586-
587583
if meta.proto_type == TYPE_MESSAGE:
588584
msg_cls = meta.unwrap() if meta.unwrap else field_type
589585

@@ -593,6 +589,26 @@ def _value_from_dict(value: Any, meta: FieldMetadata, field_type: type) -> Any:
593589
return msg.to_wrapped()
594590
return msg
595591

592+
if meta.proto_type == TYPE_ENUM:
593+
if isinstance(value, str):
594+
return field_type.from_string(value)
595+
if isinstance(value, int):
596+
return field_type(value)
597+
if isinstance(value, Enum):
598+
return value
599+
raise ValueError("Enum value must be a string or an Enum instance")
600+
601+
if meta.proto_type in INT_64_TYPES: # TODO all integer types
602+
return int(value)
603+
604+
if meta.proto_type == TYPE_BYTES:
605+
return b64decode(value)
606+
607+
if meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
608+
return _parse_float(value)
609+
610+
return value
611+
596612

597613
class Message(ABC):
598614
"""
@@ -671,8 +687,8 @@ def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
671687
try:
672688
return cls._betterproto_meta
673689
except AttributeError:
674-
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
675-
return meta
690+
cls._betterproto_meta = ProtoClassMetadata(cls)
691+
return cls._betterproto_meta
676692

677693
def dump(self, stream: SupportsWrite[bytes], delimit: bool = False) -> None:
678694
"""
@@ -1050,6 +1066,8 @@ def _from_dict_init(cls, mapping: Mapping[str, Any] | Any) -> Mapping[str, Any]:
10501066
init_kwargs: dict[str, Any] = {}
10511067
for key, value in mapping.items():
10521068
field_name = safe_snake_case(key)
1069+
field_cls = cls._betterproto.cls_by_field[field_name]
1070+
10531071
try:
10541072
meta = cls._betterproto.meta_by_field_name[field_name]
10551073
except KeyError:
@@ -1059,26 +1077,23 @@ def _from_dict_init(cls, mapping: Mapping[str, Any] | Any) -> Mapping[str, Any]:
10591077

10601078
if meta.proto_type == TYPE_MESSAGE:
10611079
if meta.repeated:
1062-
value = [_value_from_dict(item, meta, cls._betterproto.cls_by_field[field_name]) for item in value]
1080+
value = [_value_from_dict(item, meta, field_cls) for item in value]
10631081
else:
1064-
value = _value_from_dict(value, meta, cls._betterproto.cls_by_field[field_name])
1082+
value = _value_from_dict(value, meta, field_cls)
1083+
1084+
elif meta.proto_type == TYPE_MAP:
1085+
assert meta.map_meta
1086+
assert isinstance(value, dict)
1087+
1088+
value_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
1089+
1090+
value = {k: _value_from_dict(v, meta.map_meta[1], value_cls) for k, v in value.items()}
1091+
1092+
elif meta.repeated:
1093+
value = [_value_from_dict(item, meta, field_cls) for item in value]
10651094

1066-
elif meta.map_meta and meta.map_meta[1].proto_type == TYPE_MESSAGE:
1067-
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
1068-
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
10691095
else:
1070-
if meta.proto_type in INT_64_TYPES:
1071-
value = [int(n) for n in value] if isinstance(value, list) else int(value)
1072-
elif meta.proto_type == TYPE_BYTES:
1073-
value = [b64decode(n) for n in value] if isinstance(value, list) else b64decode(value)
1074-
elif meta.proto_type == TYPE_ENUM:
1075-
enum_cls = cls._betterproto.cls_by_field[field_name]
1076-
if isinstance(value, list):
1077-
value = [enum_cls.from_string(e) for e in value]
1078-
elif isinstance(value, str):
1079-
value = enum_cls.from_string(value)
1080-
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1081-
value = [_parse_float(n) for n in value] if isinstance(value, list) else _parse_float(value)
1096+
value = _value_from_dict(value, meta, field_cls)
10821097

10831098
init_kwargs[field_name] = value
10841099
return init_kwargs

0 commit comments

Comments
 (0)