Skip to content

Commit 0ae2e78

Browse files
Remote from_pydict (#90)
* Remote from_pydict * Fix type checking
1 parent ec4d445 commit 0ae2e78

File tree

2 files changed

+22
-91
lines changed

2 files changed

+22
-91
lines changed

src/betterproto2/__init__.py

Lines changed: 12 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
450450

451451
class ProtoClassMetadata:
452452
__slots__ = (
453-
"oneof_group_by_field",
454453
"oneof_field_by_group",
455454
"default_gen",
456455
"cls_by_field",
@@ -459,7 +458,6 @@ class ProtoClassMetadata:
459458
"sorted_field_names",
460459
)
461460

462-
oneof_group_by_field: dict[str, str] # TODO delete (still used in the rust codec for now)
463461
oneof_field_by_group: dict[str, set[dataclasses.Field]]
464462
field_name_by_number: dict[int, str]
465463
meta_by_field_name: dict[str, FieldMetadata]
@@ -468,7 +466,6 @@ class ProtoClassMetadata:
468466
cls_by_field: dict[str, type]
469467

470468
def __init__(self, cls: type[Message]):
471-
by_field = {}
472469
by_group: dict[str, set] = {}
473470
by_field_name = {}
474471
by_field_number = {}
@@ -479,15 +476,11 @@ def __init__(self, cls: type[Message]):
479476
meta = FieldMetadata.get(field)
480477

481478
if meta.group:
482-
# This is part of a one-of group.
483-
by_field[field.name] = meta.group
484-
485479
by_group.setdefault(meta.group, set()).add(field)
486480

487481
by_field_name[field.name] = meta
488482
by_field_number[meta.number] = field.name
489483

490-
self.oneof_group_by_field = by_field
491484
self.oneof_field_by_group = by_group
492485
self.field_name_by_number = by_field_number
493486
self.meta_by_field_name = by_field_name
@@ -588,15 +581,15 @@ def _value_to_dict(
588581
return value, not bool(value)
589582

590583

591-
def _value_from_dict(value: Any, proto_type: str, field_type: type, unwrap: Callable[[], type] | None = None) -> Any:
584+
def _value_from_dict(value: Any, meta: FieldMetadata, field_type: type) -> Any:
592585
# TODO directly pass `meta` when available for maps
593586

594-
if proto_type == TYPE_MESSAGE:
595-
msg_cls = unwrap() if unwrap else field_type
587+
if meta.proto_type == TYPE_MESSAGE:
588+
msg_cls = meta.unwrap() if meta.unwrap else field_type
596589

597590
msg = msg_cls.from_dict(value)
598591

599-
if unwrap:
592+
if meta.unwrap:
600593
return msg.to_wrapped()
601594
return msg
602595

@@ -1056,7 +1049,7 @@ def to_dict(
10561049
return output
10571050

10581051
@classmethod
1059-
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
1052+
def _from_dict_init(cls, mapping: Mapping[str, Any] | Any) -> Mapping[str, Any]:
10601053
# TODO restructure using other function
10611054
init_kwargs: dict[str, Any] = {}
10621055
for key, value in mapping.items():
@@ -1070,14 +1063,9 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
10701063

10711064
if meta.proto_type == TYPE_MESSAGE:
10721065
if meta.repeated:
1073-
value = [
1074-
_value_from_dict(item, meta.proto_type, cls._betterproto.cls_by_field[field_name], meta.unwrap)
1075-
for item in value
1076-
]
1066+
value = [_value_from_dict(item, meta, cls._betterproto.cls_by_field[field_name]) for item in value]
10771067
else:
1078-
value = _value_from_dict(
1079-
value, meta.proto_type, cls._betterproto.cls_by_field[field_name], meta.unwrap
1080-
)
1068+
value = _value_from_dict(value, meta, cls._betterproto.cls_by_field[field_name])
10811069

10821070
elif meta.map_meta and meta.map_meta[1].proto_type == TYPE_MESSAGE:
10831071
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
@@ -1100,7 +1088,7 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
11001088
return init_kwargs
11011089

11021090
@hybridmethod
1103-
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
1091+
def from_dict(cls: type[Self], value: Mapping[str, Any] | Any) -> Self: # type: ignore
11041092
"""
11051093
Parse the key/value pairs into the a new message instance.
11061094
@@ -1114,10 +1102,13 @@ def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignor
11141102
:class:`Message`
11151103
The initialized message.
11161104
"""
1105+
if not isinstance(value, Mapping) and hasattr(cls, "from_wrapped"): # type: ignore
1106+
return cls.from_wrapped(value) # type: ignore
1107+
11171108
return cls(**cls._from_dict_init(value))
11181109

11191110
@from_dict.instancemethod
1120-
def from_dict(self, value: Mapping[str, Any]) -> Self:
1111+
def from_dict(self, value: Mapping[str, Any] | Any) -> Self:
11211112
"""
11221113
Parse the key/value pairs into the current message instance. This returns the
11231114
instance itself and is therefore assignable and chainable.
@@ -1194,49 +1185,6 @@ def from_json(self: T, value: str | bytes) -> T:
11941185
"""
11951186
return self.from_dict(json.loads(value))
11961187

1197-
def from_pydict(self: T, value: Mapping[str, Any]) -> T:
1198-
"""
1199-
Parse the key/value pairs into the current message instance. This returns the
1200-
instance itself and is therefore assignable and chainable.
1201-
1202-
Parameters
1203-
-----------
1204-
value: Dict[:class:`str`, Any]
1205-
The dictionary to parse from.
1206-
1207-
Returns
1208-
--------
1209-
:class:`Message`
1210-
The initialized message.
1211-
"""
1212-
for key in value:
1213-
field_name = safe_snake_case(key)
1214-
meta = self._betterproto.meta_by_field_name.get(field_name)
1215-
if not meta:
1216-
continue
1217-
1218-
if value[key] is not None:
1219-
if meta.proto_type == TYPE_MESSAGE:
1220-
v = getattr(self, field_name)
1221-
cls = self._betterproto.cls_by_field[field_name]
1222-
if issubclass(cls, list):
1223-
raise NotImplementedError # TODO look at this
1224-
elif meta.unwrap:
1225-
v = value[key]
1226-
else:
1227-
v = cls().from_pydict(value[key])
1228-
elif meta.map_meta and meta.map_meta[1].proto_type == TYPE_MESSAGE:
1229-
v = getattr(self, field_name)
1230-
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
1231-
for k in value[key]:
1232-
v[k] = cls().from_pydict(value[key][k])
1233-
else:
1234-
v = value[key]
1235-
1236-
if v is not None:
1237-
setattr(self, field_name, v)
1238-
return self
1239-
12401188
def is_set(self, name: str) -> bool:
12411189
"""
12421190
Check if field with the given name has been set.

tests/test_features.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -246,42 +246,25 @@ def test_to_dict_default_values():
246246
"someDefaultBool": False,
247247
}
248248

249-
test = MsgB().from_pydict(
250-
{
251-
"someInt": 2,
252-
"someDouble": 1.2,
253-
"someStr": "hello",
254-
"someBool": True,
255-
"someDefaultInt": 0,
256-
"someDefaultDouble": 0.0,
257-
"someDefaultStr": "",
258-
"someDefaultBool": False,
259-
}
260-
)
261-
262-
assert test.to_dict(include_default_values=True, output_format=OutputFormat.PYTHON) == {
263-
"someInt": 2,
264-
"someDouble": 1.2,
265-
"someStr": "hello",
266-
"someBool": True,
267-
"someDefaultInt": 0,
268-
"someDefaultDouble": 0.0,
269-
"someDefaultStr": "",
270-
"someDefaultBool": False,
271-
}
272-
273249

274250
def test_to_dict_datetime_values():
275251
from tests.output_betterproto.features import TimeMsg
276252

277-
test = TimeMsg().from_dict({"timestamp": "2020-01-01T00:00:00Z", "duration": "86400s"})
253+
test = TimeMsg.from_dict({"timestamp": "2020-01-01T00:00:00Z", "duration": "86400s"})
278254
assert test.to_dict() == {"timestamp": "2020-01-01T00:00:00Z", "duration": "86400s"}
279255

280-
test = TimeMsg().from_pydict({"timestamp": datetime(year=2020, month=1, day=1), "duration": timedelta(days=1)})
256+
test = TimeMsg.from_dict(
257+
{"timestamp": datetime(year=2020, month=1, day=1, tzinfo=timezone.utc), "duration": timedelta(days=1)}
258+
)
281259
assert test.to_dict(output_format=OutputFormat.PYTHON) == {
282-
"timestamp": datetime(year=2020, month=1, day=1),
260+
"timestamp": datetime(year=2020, month=1, day=1, tzinfo=timezone.utc),
283261
"duration": timedelta(days=1),
284262
}
263+
assert test.to_dict(output_format=OutputFormat.PROTO_JSON) == {
264+
"timestamp": "2020-01-01T00:00:00Z",
265+
"duration": "86400s",
266+
}
267+
bytes(test)
285268

286269

287270
def test_oneof_default_value_set_causes_writes_wire():

0 commit comments

Comments
 (0)