Skip to content

Commit c78851b

Browse files
Merge pull request #12 from ulasozguler/master
Added `include_default_values` parameter to `to_dict` function
2 parents 559b883 + c0170f4 commit c78851b

File tree

2 files changed

+108
-10
lines changed

2 files changed

+108
-10
lines changed

betterproto/__init__.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -704,11 +704,16 @@ def parse(self: T, data: bytes) -> T:
704704
def FromString(cls: Type[T], data: bytes) -> T:
705705
return cls().parse(data)
706706

707-
def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
707+
def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool = False) -> dict:
708708
"""
709709
Returns a dict representation of this message instance which can be
710710
used to serialize to e.g. JSON. Defaults to camel casing for
711711
compatibility but can be set to other modes.
712+
713+
`include_default_values` can be set to `True` to include default
714+
values of fields. E.g. an `int32` type field with `0` value will
715+
not be in returned dict if `include_default_values` is set to
716+
`False`.
712717
"""
713718
output: Dict[str, Any] = {}
714719
for field in dataclasses.fields(self):
@@ -717,28 +722,29 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
717722
cased_name = casing(field.name).rstrip("_") # type: ignore
718723
if meta.proto_type == "message":
719724
if isinstance(v, datetime):
720-
if v != DATETIME_ZERO:
725+
if v != DATETIME_ZERO or include_default_values:
721726
output[cased_name] = _Timestamp.timestamp_to_json(v)
722727
elif isinstance(v, timedelta):
723-
if v != timedelta(0):
728+
if v != timedelta(0) or include_default_values:
724729
output[cased_name] = _Duration.delta_to_json(v)
725730
elif meta.wraps:
726-
if v is not None:
731+
if v is not None or include_default_values:
727732
output[cased_name] = v
728733
elif isinstance(v, list):
729734
# Convert each item.
730-
v = [i.to_dict(casing) for i in v]
735+
v = [i.to_dict(casing, include_default_values) for i in v]
731736
output[cased_name] = v
732-
elif v._serialized_on_wire:
733-
output[cased_name] = v.to_dict(casing)
737+
else:
738+
if v._serialized_on_wire or include_default_values:
739+
output[cased_name] = v.to_dict(casing, include_default_values)
734740
elif meta.proto_type == "map":
735741
for k in v:
736742
if hasattr(v[k], "to_dict"):
737-
v[k] = v[k].to_dict(casing)
743+
v[k] = v[k].to_dict(casing, include_default_values)
738744

739-
if v:
745+
if v or include_default_values:
740746
output[cased_name] = v
741-
elif v != self._get_field_default(field, meta):
747+
elif v != self._get_field_default(field, meta) or include_default_values:
742748
if meta.proto_type in INT_64_TYPES:
743749
if isinstance(v, list):
744750
output[cased_name] = [str(n) for n in v]

betterproto/tests/test_features.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,95 @@ class Request(betterproto.Message):
177177
# Differentiate between not passed and the zero-value.
178178
assert Request().parse(b"").flag == None
179179
assert Request().parse(b"\n\x00").flag == False
180+
181+
182+
def test_to_dict_default_values():
183+
@dataclass
184+
class TestMessage(betterproto.Message):
185+
some_int: int = betterproto.int32_field(1)
186+
some_double: float = betterproto.double_field(2)
187+
some_str: str = betterproto.string_field(3)
188+
some_bool: bool = betterproto.bool_field(4)
189+
190+
# Empty dict
191+
test = TestMessage().from_dict({})
192+
193+
assert test.to_dict(include_default_values=True) == {
194+
'someInt': 0,
195+
'someDouble': 0.0,
196+
'someStr': '',
197+
'someBool': False
198+
}
199+
200+
# All default values
201+
test = TestMessage().from_dict({
202+
'someInt': 0,
203+
'someDouble': 0.0,
204+
'someStr': '',
205+
'someBool': False
206+
})
207+
208+
assert test.to_dict(include_default_values=True) == {
209+
'someInt': 0,
210+
'someDouble': 0.0,
211+
'someStr': '',
212+
'someBool': False
213+
}
214+
215+
# Some default and some other values
216+
@dataclass
217+
class TestMessage2(betterproto.Message):
218+
some_int: int = betterproto.int32_field(1)
219+
some_double: float = betterproto.double_field(2)
220+
some_str: str = betterproto.string_field(3)
221+
some_bool: bool = betterproto.bool_field(4)
222+
some_default_int: int = betterproto.int32_field(5)
223+
some_default_double: float = betterproto.double_field(6)
224+
some_default_str: str = betterproto.string_field(7)
225+
some_default_bool: bool = betterproto.bool_field(8)
226+
227+
test = TestMessage2().from_dict({
228+
'someInt': 2,
229+
'someDouble': 1.2,
230+
'someStr': 'hello',
231+
'someBool': True,
232+
'someDefaultInt': 0,
233+
'someDefaultDouble': 0.0,
234+
'someDefaultStr': '',
235+
'someDefaultBool': False
236+
})
237+
238+
assert test.to_dict(include_default_values=True) == {
239+
'someInt': 2,
240+
'someDouble': 1.2,
241+
'someStr': 'hello',
242+
'someBool': True,
243+
'someDefaultInt': 0,
244+
'someDefaultDouble': 0.0,
245+
'someDefaultStr': '',
246+
'someDefaultBool': False
247+
}
248+
249+
# Nested messages
250+
@dataclass
251+
class TestChildMessage(betterproto.Message):
252+
some_other_int: int = betterproto.int32_field(1)
253+
254+
@dataclass
255+
class TestParentMessage(betterproto.Message):
256+
some_int: int = betterproto.int32_field(1)
257+
some_double: float = betterproto.double_field(2)
258+
some_message: TestChildMessage = betterproto.message_field(3)
259+
260+
test = TestParentMessage().from_dict({
261+
'someInt': 0,
262+
'someDouble': 1.2,
263+
})
264+
265+
assert test.to_dict(include_default_values=True) == {
266+
'someInt': 0,
267+
'someDouble': 1.2,
268+
'someMessage': {
269+
'someOtherInt': 0
270+
}
271+
}

0 commit comments

Comments
 (0)