Skip to content

Commit d035288

Browse files
Remove to_pydict (#62)
* Remove to_pydict * Add kwargs variable * Fix bug
1 parent eeb60d3 commit d035288

File tree

2 files changed

+96
-135
lines changed

2 files changed

+96
-135
lines changed

src/betterproto2/__init__.py

Lines changed: 82 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
timedelta,
2222
timezone,
2323
)
24+
from enum import IntEnum
2425
from io import BytesIO
2526
from itertools import count
2627
from typing import (
@@ -560,6 +561,15 @@ def _get_cls_by_field(cls: Type["Message"], fields: Iterable[dataclasses.Field])
560561
return field_cls
561562

562563

564+
class OutputFormat(IntEnum):
565+
"""
566+
Chosen output format for the `Message.to_dict` method.
567+
"""
568+
569+
PYTHON = 1
570+
PROTO_JSON = 2
571+
572+
563573
class Message(ABC):
564574
"""
565575
The base class for protobuf messages, all generated messages will inherit from
@@ -606,10 +616,6 @@ def __repr__(self) -> str:
606616
]
607617
return f"{self.__class__.__name__}({', '.join(parts)})"
608618

609-
# def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
610-
# for field_name in self._betterproto.sorted_field_names:
611-
# yield field_name, self.__getattribute__(field_name), PLACEHOLDER
612-
613619
def __bool__(self) -> bool:
614620
"""True if the message has any fields with non-default values."""
615621
return any(
@@ -946,9 +952,15 @@ def FromString(cls: Type[T], data: bytes) -> T:
946952
"""
947953
return cls().parse(data)
948954

949-
def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool = False) -> Dict[str, Any]:
955+
def to_dict(
956+
self,
957+
*,
958+
output_format: OutputFormat = OutputFormat.PROTO_JSON,
959+
casing: Casing = Casing.CAMEL,
960+
include_default_values: bool = False,
961+
) -> Dict[str, Any]:
950962
"""
951-
Returns a JSON serializable dict representation of this object.
963+
Return a dict representation of the message.
952964
953965
Parameters
954966
-----------
@@ -965,6 +977,12 @@ def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool =
965977
Dict[:class:`str`, Any]
966978
The JSON serializable dict representation of this object.
967979
"""
980+
kwargs = { # For recursive calls
981+
"output_format": output_format,
982+
"casing": casing,
983+
"include_default_values": include_default_values,
984+
}
985+
968986
output: Dict[str, Any] = {}
969987
field_types = self._type_hints()
970988
for field_name, meta in self._betterproto.meta_by_field_name.items():
@@ -973,74 +991,87 @@ def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool =
973991
cased_name = casing(field_name).rstrip("_") # type: ignore
974992
if meta.proto_type == TYPE_MESSAGE:
975993
if isinstance(value, datetime):
976-
output[cased_name] = _Timestamp.timestamp_to_json(value)
994+
if output_format == OutputFormat.PROTO_JSON:
995+
output[cased_name] = _Timestamp.timestamp_to_json(value)
996+
else:
997+
output[cased_name] = value
977998
elif isinstance(value, timedelta):
978-
output[cased_name] = _Duration.delta_to_json(value)
999+
if output_format == OutputFormat.PROTO_JSON:
1000+
output[cased_name] = _Duration.delta_to_json(value)
1001+
else:
1002+
output[cased_name] = value
1003+
9791004
elif meta.wraps:
9801005
if value is not None or include_default_values:
9811006
output[cased_name] = value
9821007
elif field_is_repeated:
9831008
# Convert each item.
984-
cls = self._betterproto.cls_by_field[field_name]
985-
if cls == datetime:
986-
value = [_Timestamp.timestamp_to_json(i) for i in value]
987-
elif cls == timedelta:
988-
value = [_Duration.delta_to_json(i) for i in value]
1009+
if output_format == OutputFormat.PYTHON:
1010+
value = [i.to_dict(**kwargs) for i in value]
9891011
else:
990-
value = [i.to_dict(casing, include_default_values) for i in value]
1012+
cls = self._betterproto.cls_by_field[field_name]
1013+
if cls == datetime:
1014+
value = [_Timestamp.timestamp_to_json(i) for i in value]
1015+
elif cls == timedelta:
1016+
value = [_Duration.delta_to_json(i) for i in value]
1017+
else:
1018+
value = [i.to_dict(**kwargs) for i in value]
9911019
if value or include_default_values:
9921020
output[cased_name] = value
9931021
elif value is None:
9941022
if include_default_values:
995-
output[cased_name] = value
1023+
output[cased_name] = None
9961024
else:
997-
output[cased_name] = value.to_dict(casing, include_default_values)
1025+
output[cased_name] = value.to_dict(**kwargs)
9981026
elif meta.proto_type == TYPE_MAP:
9991027
output_map = {**value}
10001028
for k in value:
10011029
if hasattr(value[k], "to_dict"):
1002-
output_map[k] = value[k].to_dict(casing, include_default_values)
1030+
output_map[k] = value[k].to_dict(**kwargs)
10031031

10041032
if value or include_default_values:
10051033
output[cased_name] = output_map
10061034
elif value != self._get_field_default(field_name) or include_default_values:
1007-
if meta.proto_type in INT_64_TYPES:
1008-
if field_is_repeated:
1009-
output[cased_name] = [str(n) for n in value]
1010-
elif value is None:
1011-
if include_default_values:
1012-
output[cased_name] = value
1013-
else:
1014-
output[cased_name] = str(value)
1015-
elif meta.proto_type == TYPE_BYTES:
1016-
if field_is_repeated:
1017-
output[cased_name] = [b64encode(b).decode("utf8") for b in value]
1018-
elif value is None and include_default_values:
1019-
output[cased_name] = value
1020-
else:
1021-
output[cased_name] = b64encode(value).decode("utf8")
1022-
elif meta.proto_type == TYPE_ENUM:
1023-
if field_is_repeated:
1024-
enum_class = field_types[field_name].__args__[0]
1025-
if isinstance(value, typing.Iterable) and not isinstance(value, str):
1026-
output[cased_name] = [enum_class(el).name for el in value]
1035+
if output_format == OutputFormat.PROTO_JSON:
1036+
if meta.proto_type in INT_64_TYPES:
1037+
if field_is_repeated:
1038+
output[cased_name] = [str(n) for n in value]
1039+
elif value is None:
1040+
if include_default_values:
1041+
output[cased_name] = value
10271042
else:
1028-
# transparently upgrade single value to repeated
1029-
output[cased_name] = [enum_class(value).name]
1030-
elif value is None:
1031-
if include_default_values:
1043+
output[cased_name] = str(value)
1044+
elif meta.proto_type == TYPE_BYTES:
1045+
if field_is_repeated:
1046+
output[cased_name] = [b64encode(b).decode("utf8") for b in value]
1047+
elif value is None and include_default_values:
10321048
output[cased_name] = value
1033-
elif meta.optional:
1034-
enum_class = field_types[field_name].__args__[0]
1035-
output[cased_name] = enum_class(value).name
1036-
else:
1037-
enum_class = field_types[field_name] # noqa
1038-
output[cased_name] = enum_class(value).name
1039-
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1040-
if field_is_repeated:
1041-
output[cased_name] = [_dump_float(n) for n in value]
1049+
else:
1050+
output[cased_name] = b64encode(value).decode("utf8")
1051+
elif meta.proto_type == TYPE_ENUM:
1052+
if field_is_repeated:
1053+
enum_class = field_types[field_name].__args__[0]
1054+
if isinstance(value, typing.Iterable) and not isinstance(value, str):
1055+
output[cased_name] = [enum_class(el).name for el in value]
1056+
else:
1057+
# transparently upgrade single value to repeated
1058+
output[cased_name] = [enum_class(value).name]
1059+
elif value is None:
1060+
if include_default_values:
1061+
output[cased_name] = value
1062+
elif meta.optional:
1063+
enum_class = field_types[field_name].__args__[0]
1064+
output[cased_name] = enum_class(value).name
1065+
else:
1066+
enum_class = field_types[field_name] # noqa
1067+
output[cased_name] = enum_class(value).name
1068+
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1069+
if field_is_repeated:
1070+
output[cased_name] = [_dump_float(n) for n in value]
1071+
else:
1072+
output[cased_name] = _dump_float(value)
10421073
else:
1043-
output[cased_name] = _dump_float(value)
1074+
output[cased_name] = value
10441075
else:
10451076
output[cased_name] = value
10461077
return output
@@ -1188,69 +1219,6 @@ def from_json(self: T, value: Union[str, bytes]) -> T:
11881219
"""
11891220
return self.from_dict(json.loads(value))
11901221

1191-
def to_pydict(self, casing: Casing = Casing.CAMEL, include_default_values: bool = False) -> Dict[str, Any]:
1192-
"""
1193-
Returns a python dict representation of this object.
1194-
1195-
Parameters
1196-
-----------
1197-
casing: :class:`Casing`
1198-
The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1199-
compatibility purposes.
1200-
include_default_values: :class:`bool`
1201-
If ``True`` will include the default values of fields. Default is ``False``.
1202-
E.g. an ``int32`` field will be included with a value of ``0`` if this is
1203-
set to ``True``, otherwise this would be ignored.
1204-
1205-
Returns
1206-
--------
1207-
Dict[:class:`str`, Any]
1208-
The python dict representation of this object.
1209-
"""
1210-
output: Dict[str, Any] = {}
1211-
for field_name, meta in self._betterproto.meta_by_field_name.items():
1212-
field_is_repeated = meta.repeated
1213-
value = getattr(self, field_name)
1214-
cased_name = casing(field_name).rstrip("_") # type: ignore
1215-
if meta.proto_type == TYPE_MESSAGE:
1216-
if isinstance(value, datetime):
1217-
if (
1218-
value != DATETIME_ZERO
1219-
or include_default_values
1220-
or self._include_default_value_for_oneof(field_name=field_name, meta=meta)
1221-
):
1222-
output[cased_name] = value
1223-
elif isinstance(value, timedelta):
1224-
if (
1225-
value != timedelta(0)
1226-
or include_default_values
1227-
or self._include_default_value_for_oneof(field_name=field_name, meta=meta)
1228-
):
1229-
output[cased_name] = value
1230-
elif meta.wraps:
1231-
if value is not None or include_default_values:
1232-
output[cased_name] = value
1233-
elif field_is_repeated:
1234-
# Convert each item.
1235-
value = [i.to_pydict(casing, include_default_values) for i in value]
1236-
if value or include_default_values:
1237-
output[cased_name] = value
1238-
elif value is None:
1239-
if include_default_values:
1240-
output[cased_name] = None
1241-
else:
1242-
output[cased_name] = value.to_pydict(casing, include_default_values)
1243-
elif meta.proto_type == TYPE_MAP:
1244-
for k in value:
1245-
if hasattr(value[k], "to_pydict"):
1246-
value[k] = value[k].to_pydict(casing, include_default_values)
1247-
1248-
if value or include_default_values:
1249-
output[cased_name] = value
1250-
elif value != self._get_field_default(field_name) or include_default_values:
1251-
output[cased_name] = value
1252-
return output
1253-
12541222
def from_pydict(self: T, value: Mapping[str, Any]) -> T:
12551223
"""
12561224
Parse the key/value pairs into the current message instance. This returns the

tests/test_features.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from unittest.mock import ANY
1111

1212
import betterproto2
13+
from betterproto2 import OutputFormat
1314

1415

1516
def test_class_init():
@@ -18,7 +19,7 @@ def test_class_init():
1819
foo = Foo(name="foo", child=Bar(name="bar"))
1920

2021
assert foo.to_dict() == {"name": "foo", "child": {"name": "bar"}}
21-
assert foo.to_pydict() == {"name": "foo", "child": {"name": "bar"}}
22+
assert foo.to_dict(output_format=OutputFormat.PYTHON) == {"name": "foo", "child": {"name": "bar"}}
2223

2324

2425
def test_enum_as_int_json():
@@ -35,7 +36,7 @@ def test_enum_as_int_json():
3536
# Similar expectations for pydict
3637
enum_msg = EnumMsg().from_dict({"enum": 1})
3738
assert enum_msg.enum == Enum.ONE
38-
assert enum_msg.to_pydict() == {"enum": Enum.ONE}
39+
assert enum_msg.to_dict(output_format=OutputFormat.PYTHON) == {"enum": Enum.ONE}
3940

4041

4142
def test_unknown_fields():
@@ -127,25 +128,13 @@ def test_dict_casing():
127128
"snakeCase": 3,
128129
"kabobCase": 4,
129130
}
130-
assert msg.to_pydict() == {
131-
"pascalCase": 1,
132-
"camelCase": 2,
133-
"snakeCase": 3,
134-
"kabobCase": 4,
135-
}
136131

137132
assert msg.to_dict(casing=betterproto2.Casing.SNAKE) == {
138133
"pascal_case": 1,
139134
"camel_case": 2,
140135
"snake_case": 3,
141136
"kabob_case": 4,
142137
}
143-
assert msg.to_pydict(casing=betterproto2.Casing.SNAKE) == {
144-
"pascal_case": 1,
145-
"camel_case": 2,
146-
"snake_case": 3,
147-
"kabob_case": 4,
148-
}
149138

150139

151140
def test_optional_flag():
@@ -173,12 +162,16 @@ def test_optional_datetime_to_dict():
173162
}
174163

175164
# Check pydict serialization
176-
assert OptionalDatetimeMsg().to_pydict() == {}
177-
assert OptionalDatetimeMsg().to_pydict(include_default_values=True) == {"field": None}
178-
assert OptionalDatetimeMsg(field=datetime(2020, 1, 1)).to_pydict() == {"field": datetime(2020, 1, 1)}
179-
assert OptionalDatetimeMsg(field=datetime(2020, 1, 1)).to_pydict(include_default_values=True) == {
165+
assert OptionalDatetimeMsg().to_dict(output_format=OutputFormat.PYTHON) == {}
166+
assert OptionalDatetimeMsg().to_dict(include_default_values=True, output_format=OutputFormat.PYTHON) == {
167+
"field": None
168+
}
169+
assert OptionalDatetimeMsg(field=datetime(2020, 1, 1)).to_dict(output_format=OutputFormat.PYTHON) == {
180170
"field": datetime(2020, 1, 1)
181171
}
172+
assert OptionalDatetimeMsg(field=datetime(2020, 1, 1)).to_dict(
173+
include_default_values=True, output_format=OutputFormat.PYTHON
174+
) == {"field": datetime(2020, 1, 1)}
182175

183176

184177
def test_to_json_default_values():
@@ -218,7 +211,7 @@ def test_to_dict_default_values():
218211
"someBool": False,
219212
}
220213

221-
assert test.to_pydict(include_default_values=True) == {
214+
assert test.to_dict(include_default_values=True, output_format=OutputFormat.PYTHON) == {
222215
"someInt": 0,
223216
"someDouble": 0.0,
224217
"someStr": "",
@@ -263,7 +256,7 @@ def test_to_dict_default_values():
263256
}
264257
)
265258

266-
assert test.to_pydict(include_default_values=True) == {
259+
assert test.to_dict(include_default_values=True, output_format=OutputFormat.PYTHON) == {
267260
"someInt": 2,
268261
"someDouble": 1.2,
269262
"someStr": "hello",
@@ -284,7 +277,7 @@ def test_to_dict_datetime_values():
284277

285278
test = TimeMsg().from_pydict({"timestamp": datetime(year=2020, month=1, day=1), "duration": timedelta(days=1)})
286279

287-
assert test.to_pydict() == {
280+
assert test.to_dict(output_format=OutputFormat.PYTHON) == {
288281
"timestamp": datetime(year=2020, month=1, day=1),
289282
"duration": timedelta(days=1),
290283
}

0 commit comments

Comments
 (0)