Skip to content

Commit 5daf61f

Browse files
committed
Refactor default value code
1 parent 4679c57 commit 5daf61f

File tree

2 files changed

+51
-47
lines changed

2 files changed

+51
-47
lines changed

betterproto/__init__.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -118,18 +118,6 @@ class _PLACEHOLDER:
118118
PLACEHOLDER: Any = _PLACEHOLDER()
119119

120120

121-
def get_default(proto_type: str) -> Any:
122-
"""Get the default (zero value) for a given type."""
123-
return {
124-
TYPE_BOOL: False,
125-
TYPE_FLOAT: 0.0,
126-
TYPE_DOUBLE: 0.0,
127-
TYPE_STRING: "",
128-
TYPE_BYTES: b"",
129-
TYPE_MAP: {},
130-
}.get(proto_type, 0)
131-
132-
133121
@dataclasses.dataclass(frozen=True)
134122
class FieldMetadata:
135123
"""Stores internal metadata used for parsing & serialization."""
@@ -467,11 +455,22 @@ def __bytes__(self) -> bytes:
467455
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
468456
selected_in_group = True
469457

470-
if isinstance(value, list):
471-
if not len(value) and not selected_in_group:
472-
# Empty values are not serialized
473-
continue
458+
serialize_empty = False
459+
if isinstance(value, Message) and value._serialized_on_wire:
460+
# Empty messages can still be sent on the wire if they were
461+
# set (or received empty).
462+
serialize_empty = True
463+
464+
if value == self._get_field_default(field, meta) and not (
465+
selected_in_group or serialize_empty
466+
):
467+
# Default (zero) values are not serialized. Two exceptions are
468+
# if this is the selected oneof item or if we know we have to
469+
# serialize an empty message (i.e. zero value was explicitly
470+
# set by the user).
471+
continue
474472

473+
if isinstance(value, list):
475474
if meta.proto_type in PACKED_TYPES:
476475
# Packed lists look like a length-delimited field. First,
477476
# preprocess/encode each value into a buffer and then
@@ -484,23 +483,12 @@ def __bytes__(self) -> bytes:
484483
for item in value:
485484
output += _serialize_single(meta.number, meta.proto_type, item)
486485
elif isinstance(value, dict):
487-
if not len(value) and not selected_in_group:
488-
# Empty values are not serialized
489-
continue
490-
491486
for k, v in value.items():
492487
assert meta.map_types
493488
sk = _serialize_single(1, meta.map_types[0], k)
494489
sv = _serialize_single(2, meta.map_types[1], v)
495490
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
496491
else:
497-
if value == get_default(meta.proto_type) and not selected_in_group:
498-
# Default (zero) values are not serialized
499-
continue
500-
501-
serialize_empty = False
502-
if isinstance(value, Message) and value._serialized_on_wire:
503-
serialize_empty = True
504492
output += _serialize_single(
505493
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
506494
)
@@ -510,30 +498,42 @@ def __bytes__(self) -> bytes:
510498
# For compatibility with other libraries
511499
SerializeToString = __bytes__
512500

513-
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
514-
"""Get the message class for a field from the type hints."""
501+
def _type_hint(self, field_name: str) -> Type:
515502
module = inspect.getmodule(self.__class__)
516503
type_hints = get_type_hints(self.__class__, vars(module))
517-
cls = type_hints[field.name]
504+
return type_hints[field_name]
505+
506+
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
507+
"""Get the message class for a field from the type hints."""
508+
cls = self._type_hint(field.name)
518509
if hasattr(cls, "__args__") and index >= 0:
519-
cls = type_hints[field.name].__args__[index]
510+
cls = cls.__args__[index]
520511
return cls
521512

522513
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
523-
t = self._cls_for(field, index=-1)
514+
t = self._type_hint(field.name)
524515

525516
value: Any = 0
526-
if meta.proto_type == TYPE_MAP:
527-
# Maps cannot be repeated, so we check these first.
528-
value = {}
529-
elif hasattr(t, "__args__") and len(t.__args__) == 1:
530-
# Anything else with type args is a list.
531-
value = []
532-
elif meta.proto_type == TYPE_MESSAGE:
533-
# Message means creating an instance of the right type.
534-
value = t()
517+
if hasattr(t, "__origin__"):
518+
if t.__origin__ == dict:
519+
# This is some kind of map (dict in Python).
520+
value = {}
521+
elif t.__origin__ == list:
522+
# This is some kind of list (repeated) field.
523+
value = []
524+
elif t.__origin__ == Union and t.__args__[1] == type(None):
525+
# This is an optional (wrapped) field. For setting the default we
526+
# really don't care what kind of field it is.
527+
value = None
528+
else:
529+
value = t()
530+
elif issubclass(t, Enum):
531+
# Enums always default to zero.
532+
value = 0
535533
else:
536-
value = get_default(meta.proto_type)
534+
# This is either a primitive scalar or another message type. Calling
535+
# it should result in its zero value.
536+
value = t()
537537

538538
return value
539539

@@ -659,7 +659,7 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
659659

660660
if v:
661661
output[cased_name] = v
662-
elif v != get_default(meta.proto_type):
662+
elif v != self._get_field_default(field, meta):
663663
if meta.proto_type in INT_64_TYPES:
664664
if isinstance(v, list):
665665
output[cased_name] = [str(n) for n in v]

betterproto/plugin.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,17 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
3535
Return a Python type name for a proto type reference. Adds the import if
3636
necessary.
3737
"""
38+
# If the package name is a blank string, then this should still work
39+
# because by convention packages are lowercase and message/enum types are
40+
# pascal-cased. May require refactoring in the future.
3841
type_name = type_name.lstrip(".")
3942
if type_name.startswith(package):
40-
# This is the current package, which has nested types flattened.
41-
# foo.bar_thing => FooBarThing
4243
parts = type_name.lstrip(package).lstrip(".").split(".")
43-
cased = [stringcase.pascalcase(part) for part in parts]
44-
type_name = f'"{"".join(cased)}"'
44+
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
45+
# This is the current package, which has nested types flattened.
46+
# foo.bar_thing => FooBarThing
47+
cased = [stringcase.pascalcase(part) for part in parts]
48+
type_name = f'"{"".join(cased)}"'
4549

4650
if "." in type_name:
4751
# This is imported from another package. No need

0 commit comments

Comments
 (0)