Skip to content

Commit 654423e

Browse files
Refactor to dict (#71)
* Remove wrong test * Refactor to_dict * Fix typing
1 parent 56f419d commit 654423e

File tree

2 files changed

+86
-89
lines changed

2 files changed

+86
-89
lines changed

src/betterproto2/__init__.py

Lines changed: 86 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import math
99
import struct
1010
import sys
11-
import typing
1211
import warnings
1312
from abc import ABC
1413
from base64 import (
@@ -550,6 +549,58 @@ class OutputFormat(IntEnum):
550549
PROTO_JSON = 2
551550

552551

552+
def _value_to_dict(
553+
value: Any,
554+
proto_type: str,
555+
field_type: type,
556+
output_format: OutputFormat,
557+
casing: Casing,
558+
include_default_values: bool,
559+
) -> tuple[Any, bool]:
560+
"""
561+
Convert a single item to a Python dictionnary. This function is called on each element of a
562+
list, set, etc by `Message.to_dict`.
563+
564+
Returns:
565+
A tuple (dict, is_default_value)
566+
"""
567+
kwargs = { # For recursive calls
568+
"output_format": output_format,
569+
"casing": casing,
570+
"include_default_values": include_default_values,
571+
}
572+
573+
if proto_type == TYPE_MESSAGE:
574+
if isinstance(value, datetime):
575+
if output_format == OutputFormat.PROTO_JSON:
576+
return _Timestamp.timestamp_to_json(value), False
577+
return value, False
578+
579+
if isinstance(value, timedelta):
580+
if output_format == OutputFormat.PROTO_JSON:
581+
return _Duration.delta_to_json(value), False
582+
return value, False
583+
584+
if not isinstance(value, Message): # For wrapped types
585+
return value, False
586+
587+
return value.to_dict(**kwargs), False
588+
589+
if output_format == OutputFormat.PYTHON:
590+
return value, not bool(value)
591+
592+
# PROTO_JSON
593+
if proto_type in INT_64_TYPES:
594+
return str(value), not bool(value)
595+
if proto_type == TYPE_BYTES:
596+
return b64encode(value).decode("utf8"), not (bool(value))
597+
if proto_type == TYPE_ENUM:
598+
return field_type(value).name, not bool(value)
599+
if proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
600+
return _dump_float(value), not bool(value)
601+
return value, not bool(value)
602+
603+
553604
class Message(ABC):
554605
"""
555606
The base class for protobuf messages, all generated messages will inherit from
@@ -965,95 +1016,46 @@ def to_dict(
9651016

9661017
output: dict[str, Any] = {}
9671018
field_types = self._type_hints()
1019+
9681020
for field_name, meta in self._betterproto.meta_by_field_name.items():
969-
field_is_repeated = meta.repeated
9701021
value = getattr(self, field_name)
9711022
cased_name = casing(field_name).rstrip("_") # type: ignore
972-
if meta.proto_type == TYPE_MESSAGE:
973-
if isinstance(value, datetime):
974-
if output_format == OutputFormat.PROTO_JSON:
975-
output[cased_name] = _Timestamp.timestamp_to_json(value)
976-
else:
977-
output[cased_name] = value
978-
elif isinstance(value, timedelta):
979-
if output_format == OutputFormat.PROTO_JSON:
980-
output[cased_name] = _Duration.delta_to_json(value)
981-
else:
982-
output[cased_name] = value
9831023

984-
elif meta.wraps:
985-
if value is not None or include_default_values:
986-
output[cased_name] = value
987-
elif field_is_repeated:
988-
# Convert each item.
989-
if output_format == OutputFormat.PYTHON:
990-
value = [i.to_dict(**kwargs) for i in value]
991-
else:
992-
cls = self._betterproto.cls_by_field[field_name]
993-
if cls == datetime:
994-
value = [_Timestamp.timestamp_to_json(i) for i in value]
995-
elif cls == timedelta:
996-
value = [_Duration.delta_to_json(i) for i in value]
997-
else:
998-
value = [i.to_dict(**kwargs) for i in value]
999-
if value or include_default_values:
1000-
output[cased_name] = value
1001-
elif value is None:
1002-
if include_default_values:
1003-
output[cased_name] = None
1004-
else:
1005-
output[cased_name] = value.to_dict(**kwargs)
1006-
elif meta.proto_type == TYPE_MAP:
1007-
output_map = {**value}
1008-
for k in value:
1009-
if hasattr(value[k], "to_dict"):
1010-
output_map[k] = value[k].to_dict(**kwargs)
1024+
if meta.repeated or meta.optional:
1025+
field_type = field_types[field_name].__args__[0]
1026+
else:
1027+
field_type = field_types[field_name]
1028+
1029+
if meta.repeated:
1030+
output_value = [_value_to_dict(v, meta.proto_type, field_type, **kwargs)[0] for v in value]
1031+
if output_value or include_default_values:
1032+
output[cased_name] = output_value
10111033

1012-
if value or include_default_values:
1034+
elif meta.proto_type == TYPE_MAP:
1035+
assert meta.map_types is not None
1036+
field_type_k = field_types[field_name].__args__[0]
1037+
field_type_v = field_types[field_name].__args__[1]
1038+
output_map = {
1039+
_value_to_dict(k, meta.map_types[0], field_type_k, **kwargs)[0]: _value_to_dict(
1040+
v, meta.map_types[1], field_type_v, **kwargs
1041+
)[0]
1042+
for k, v in value.items()
1043+
}
1044+
1045+
if output_map or include_default_values:
10131046
output[cased_name] = output_map
1014-
elif value != self._get_field_default(field_name) or include_default_values:
1015-
if output_format == OutputFormat.PROTO_JSON:
1016-
if meta.proto_type in INT_64_TYPES:
1017-
if field_is_repeated:
1018-
output[cased_name] = [str(n) for n in value]
1019-
elif value is None:
1020-
if include_default_values:
1021-
output[cased_name] = value
1022-
else:
1023-
output[cased_name] = str(value)
1024-
elif meta.proto_type == TYPE_BYTES:
1025-
if field_is_repeated:
1026-
output[cased_name] = [b64encode(b).decode("utf8") for b in value]
1027-
elif value is None and include_default_values:
1028-
output[cased_name] = value
1029-
else:
1030-
output[cased_name] = b64encode(value).decode("utf8")
1031-
elif meta.proto_type == TYPE_ENUM:
1032-
if field_is_repeated:
1033-
enum_class = field_types[field_name].__args__[0]
1034-
if isinstance(value, typing.Iterable) and not isinstance(value, str):
1035-
output[cased_name] = [enum_class(el).name for el in value]
1036-
else:
1037-
# transparently upgrade single value to repeated
1038-
output[cased_name] = [enum_class(value).name]
1039-
elif value is None:
1040-
if include_default_values:
1041-
output[cased_name] = value
1042-
elif meta.optional:
1043-
enum_class = field_types[field_name].__args__[0]
1044-
output[cased_name] = enum_class(value).name
1045-
else:
1046-
enum_class = field_types[field_name] # noqa
1047-
output[cased_name] = enum_class(value).name
1048-
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1049-
if field_is_repeated:
1050-
output[cased_name] = [_dump_float(n) for n in value]
1051-
else:
1052-
output[cased_name] = _dump_float(value)
1053-
else:
1054-
output[cased_name] = value
1047+
1048+
else:
1049+
if value is None:
1050+
output_value, is_default = None, True
10551051
else:
1056-
output[cased_name] = value
1052+
output_value, is_default = _value_to_dict(value, meta.proto_type, field_type, **kwargs)
1053+
if meta.optional:
1054+
is_default = False
1055+
1056+
if include_default_values or not is_default:
1057+
output[cased_name] = output_value
1058+
10571059
return output
10581060

10591061
@classmethod
@@ -1297,7 +1299,7 @@ def _validate_field_groups(cls, values):
12971299
pass
12981300
else:
12991301

1300-
def parse_patched(self, data: bytes) -> T:
1302+
def parse_patched(self, data: bytes) -> Message:
13011303
betterproto2_rust_codec.deserialize(self, data)
13021304
return self
13031305

tests/inputs/enum/test_enum.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,6 @@ def test_repeated_enum_to_dict():
5858
assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"]
5959

6060

61-
def test_repeated_enum_with_single_value_to_dict():
62-
assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"]
63-
assert Test(choices=1).to_dict()["choices"] == ["ONE"]
64-
65-
6661
def test_repeated_enum_with_non_list_iterables_to_dict():
6762
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
6863
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]

0 commit comments

Comments
 (0)