diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index cbafc838d..fe997a2ae 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1714,7 +1714,10 @@ def to_pydict( defaults = self._betterproto.default_gen for field_name, meta in self._betterproto.meta_by_field_name.items(): field_is_repeated = defaults[field_name] is list - value = getattr(self, field_name) + try: + value = getattr(self, field_name) + except AttributeError: + value = self._get_field_default(field_name) cased_name = casing(field_name).rstrip("_") # type: ignore if meta.proto_type == TYPE_MESSAGE: if isinstance(value, datetime): @@ -1795,7 +1798,11 @@ def from_pydict(self: T, value: Mapping[str, Any]) -> T: if value[key] is not None: if meta.proto_type == TYPE_MESSAGE: - v = getattr(self, field_name) + try: + v = getattr(self, field_name) + except AttributeError: + v = self._get_field_default(field_name) + setattr(self, field_name, v) if isinstance(v, list): cls = self._betterproto.cls_by_field[field_name] for item in value[key]: @@ -1811,7 +1818,11 @@ def from_pydict(self: T, value: Mapping[str, Any]) -> T: # assignment here is necessary. v.from_pydict(value[key]) elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: - v = getattr(self, field_name) + try: + v = getattr(self, field_name) + except AttributeError: + v = self._get_field_default(field_name) + setattr(self, field_name, v) cls = self._betterproto.cls_by_field[f"{field_name}.value"] for k in value[key]: v[k] = cls().from_pydict(value[key][k]) diff --git a/tests/test_features.py b/tests/test_features.py index 193e6b250..6274c51e9 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -181,6 +181,49 @@ class Foo(betterproto.Message): assert betterproto.which_one_of(foo2, "group2")[0] == "" +def test_oneof_dict_pydict(): + + class Sub(betterproto.Enum): + aaa = 1 + bbb = 2 + + + @dataclass + class Foo(betterproto.Message): + bar: int = betterproto.int32_field(1, group="group1") + baz: str = betterproto.string_field(2, group="group1") + sub: Sub = betterproto.enum_field(3, group="group1") + + foo1 = Foo(bar=1) + + assert foo1.to_dict() == {"bar": 1} + assert foo1.to_pydict() == {"bar": 1} + + foo2 = Foo(baz="baz") + + assert foo2.to_dict() == {"baz": "baz"} + assert foo2.to_pydict() == {"baz": "baz"} + + foo3 = Foo(sub=Sub.bbb) + + # Enum fields should serialize as strings in to_dict, serialize as Enum values in to_pydict + assert foo3.to_dict() == {"sub": "bbb"} + assert foo3.to_pydict() == {"sub": Sub.bbb} + + foo4 = Foo().from_dict({"bar": 1}) + assert foo4.bar == 1 + + foo5 = Foo().from_pydict({"bar": 1}) + assert foo5.bar == 1 + + foo6 = Foo().from_dict({"sub": "bbb"}) + assert foo6.sub == Sub.bbb + + foo7 = Foo().from_pydict({"sub": 2}) + assert foo7.sub == Sub.bbb + + + @pytest.mark.skipif( sys.version_info < (3, 10), reason="pattern matching is only supported in python3.10+",