|
8 | 8 | import math
|
9 | 9 | import struct
|
10 | 10 | import sys
|
11 |
| -import typing |
12 | 11 | import warnings
|
13 | 12 | from abc import ABC
|
14 | 13 | from base64 import (
|
@@ -550,6 +549,58 @@ class OutputFormat(IntEnum):
|
550 | 549 | PROTO_JSON = 2
|
551 | 550 |
|
552 | 551 |
|
| 552 | +def _value_to_dict( |
| 553 | + value: Any, |
| 554 | + proto_type: str, |
| 555 | + field_type: type, |
| 556 | + output_format: OutputFormat, |
| 557 | + casing: Casing, |
| 558 | + include_default_values: bool, |
| 559 | +) -> tuple[Any, bool]: |
| 560 | + """ |
| 561 | + Convert a single item to a Python dictionnary. This function is called on each element of a |
| 562 | + list, set, etc by `Message.to_dict`. |
| 563 | +
|
| 564 | + Returns: |
| 565 | + A tuple (dict, is_default_value) |
| 566 | + """ |
| 567 | + kwargs = { # For recursive calls |
| 568 | + "output_format": output_format, |
| 569 | + "casing": casing, |
| 570 | + "include_default_values": include_default_values, |
| 571 | + } |
| 572 | + |
| 573 | + if proto_type == TYPE_MESSAGE: |
| 574 | + if isinstance(value, datetime): |
| 575 | + if output_format == OutputFormat.PROTO_JSON: |
| 576 | + return _Timestamp.timestamp_to_json(value), False |
| 577 | + return value, False |
| 578 | + |
| 579 | + if isinstance(value, timedelta): |
| 580 | + if output_format == OutputFormat.PROTO_JSON: |
| 581 | + return _Duration.delta_to_json(value), False |
| 582 | + return value, False |
| 583 | + |
| 584 | + if not isinstance(value, Message): # For wrapped types |
| 585 | + return value, False |
| 586 | + |
| 587 | + return value.to_dict(**kwargs), False |
| 588 | + |
| 589 | + if output_format == OutputFormat.PYTHON: |
| 590 | + return value, not bool(value) |
| 591 | + |
| 592 | + # PROTO_JSON |
| 593 | + if proto_type in INT_64_TYPES: |
| 594 | + return str(value), not bool(value) |
| 595 | + if proto_type == TYPE_BYTES: |
| 596 | + return b64encode(value).decode("utf8"), not (bool(value)) |
| 597 | + if proto_type == TYPE_ENUM: |
| 598 | + return field_type(value).name, not bool(value) |
| 599 | + if proto_type in (TYPE_FLOAT, TYPE_DOUBLE): |
| 600 | + return _dump_float(value), not bool(value) |
| 601 | + return value, not bool(value) |
| 602 | + |
| 603 | + |
553 | 604 | class Message(ABC):
|
554 | 605 | """
|
555 | 606 | The base class for protobuf messages, all generated messages will inherit from
|
@@ -965,95 +1016,46 @@ def to_dict(
|
965 | 1016 |
|
966 | 1017 | output: dict[str, Any] = {}
|
967 | 1018 | field_types = self._type_hints()
|
| 1019 | + |
968 | 1020 | for field_name, meta in self._betterproto.meta_by_field_name.items():
|
969 |
| - field_is_repeated = meta.repeated |
970 | 1021 | value = getattr(self, field_name)
|
971 | 1022 | cased_name = casing(field_name).rstrip("_") # type: ignore
|
972 |
| - if meta.proto_type == TYPE_MESSAGE: |
973 |
| - if isinstance(value, datetime): |
974 |
| - if output_format == OutputFormat.PROTO_JSON: |
975 |
| - output[cased_name] = _Timestamp.timestamp_to_json(value) |
976 |
| - else: |
977 |
| - output[cased_name] = value |
978 |
| - elif isinstance(value, timedelta): |
979 |
| - if output_format == OutputFormat.PROTO_JSON: |
980 |
| - output[cased_name] = _Duration.delta_to_json(value) |
981 |
| - else: |
982 |
| - output[cased_name] = value |
983 | 1023 |
|
984 |
| - elif meta.wraps: |
985 |
| - if value is not None or include_default_values: |
986 |
| - output[cased_name] = value |
987 |
| - elif field_is_repeated: |
988 |
| - # Convert each item. |
989 |
| - if output_format == OutputFormat.PYTHON: |
990 |
| - value = [i.to_dict(**kwargs) for i in value] |
991 |
| - else: |
992 |
| - cls = self._betterproto.cls_by_field[field_name] |
993 |
| - if cls == datetime: |
994 |
| - value = [_Timestamp.timestamp_to_json(i) for i in value] |
995 |
| - elif cls == timedelta: |
996 |
| - value = [_Duration.delta_to_json(i) for i in value] |
997 |
| - else: |
998 |
| - value = [i.to_dict(**kwargs) for i in value] |
999 |
| - if value or include_default_values: |
1000 |
| - output[cased_name] = value |
1001 |
| - elif value is None: |
1002 |
| - if include_default_values: |
1003 |
| - output[cased_name] = None |
1004 |
| - else: |
1005 |
| - output[cased_name] = value.to_dict(**kwargs) |
1006 |
| - elif meta.proto_type == TYPE_MAP: |
1007 |
| - output_map = {**value} |
1008 |
| - for k in value: |
1009 |
| - if hasattr(value[k], "to_dict"): |
1010 |
| - output_map[k] = value[k].to_dict(**kwargs) |
| 1024 | + if meta.repeated or meta.optional: |
| 1025 | + field_type = field_types[field_name].__args__[0] |
| 1026 | + else: |
| 1027 | + field_type = field_types[field_name] |
| 1028 | + |
| 1029 | + if meta.repeated: |
| 1030 | + output_value = [_value_to_dict(v, meta.proto_type, field_type, **kwargs)[0] for v in value] |
| 1031 | + if output_value or include_default_values: |
| 1032 | + output[cased_name] = output_value |
1011 | 1033 |
|
1012 |
| - if value or include_default_values: |
| 1034 | + elif meta.proto_type == TYPE_MAP: |
| 1035 | + assert meta.map_types is not None |
| 1036 | + field_type_k = field_types[field_name].__args__[0] |
| 1037 | + field_type_v = field_types[field_name].__args__[1] |
| 1038 | + output_map = { |
| 1039 | + _value_to_dict(k, meta.map_types[0], field_type_k, **kwargs)[0]: _value_to_dict( |
| 1040 | + v, meta.map_types[1], field_type_v, **kwargs |
| 1041 | + )[0] |
| 1042 | + for k, v in value.items() |
| 1043 | + } |
| 1044 | + |
| 1045 | + if output_map or include_default_values: |
1013 | 1046 | output[cased_name] = output_map
|
1014 |
| - elif value != self._get_field_default(field_name) or include_default_values: |
1015 |
| - if output_format == OutputFormat.PROTO_JSON: |
1016 |
| - if meta.proto_type in INT_64_TYPES: |
1017 |
| - if field_is_repeated: |
1018 |
| - output[cased_name] = [str(n) for n in value] |
1019 |
| - elif value is None: |
1020 |
| - if include_default_values: |
1021 |
| - output[cased_name] = value |
1022 |
| - else: |
1023 |
| - output[cased_name] = str(value) |
1024 |
| - elif meta.proto_type == TYPE_BYTES: |
1025 |
| - if field_is_repeated: |
1026 |
| - output[cased_name] = [b64encode(b).decode("utf8") for b in value] |
1027 |
| - elif value is None and include_default_values: |
1028 |
| - output[cased_name] = value |
1029 |
| - else: |
1030 |
| - output[cased_name] = b64encode(value).decode("utf8") |
1031 |
| - elif meta.proto_type == TYPE_ENUM: |
1032 |
| - if field_is_repeated: |
1033 |
| - enum_class = field_types[field_name].__args__[0] |
1034 |
| - if isinstance(value, typing.Iterable) and not isinstance(value, str): |
1035 |
| - output[cased_name] = [enum_class(el).name for el in value] |
1036 |
| - else: |
1037 |
| - # transparently upgrade single value to repeated |
1038 |
| - output[cased_name] = [enum_class(value).name] |
1039 |
| - elif value is None: |
1040 |
| - if include_default_values: |
1041 |
| - output[cased_name] = value |
1042 |
| - elif meta.optional: |
1043 |
| - enum_class = field_types[field_name].__args__[0] |
1044 |
| - output[cased_name] = enum_class(value).name |
1045 |
| - else: |
1046 |
| - enum_class = field_types[field_name] # noqa |
1047 |
| - output[cased_name] = enum_class(value).name |
1048 |
| - elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): |
1049 |
| - if field_is_repeated: |
1050 |
| - output[cased_name] = [_dump_float(n) for n in value] |
1051 |
| - else: |
1052 |
| - output[cased_name] = _dump_float(value) |
1053 |
| - else: |
1054 |
| - output[cased_name] = value |
| 1047 | + |
| 1048 | + else: |
| 1049 | + if value is None: |
| 1050 | + output_value, is_default = None, True |
1055 | 1051 | else:
|
1056 |
| - output[cased_name] = value |
| 1052 | + output_value, is_default = _value_to_dict(value, meta.proto_type, field_type, **kwargs) |
| 1053 | + if meta.optional: |
| 1054 | + is_default = False |
| 1055 | + |
| 1056 | + if include_default_values or not is_default: |
| 1057 | + output[cased_name] = output_value |
| 1058 | + |
1057 | 1059 | return output
|
1058 | 1060 |
|
1059 | 1061 | @classmethod
|
@@ -1297,7 +1299,7 @@ def _validate_field_groups(cls, values):
|
1297 | 1299 | pass
|
1298 | 1300 | else:
|
1299 | 1301 |
|
1300 |
| - def parse_patched(self, data: bytes) -> T: |
| 1302 | + def parse_patched(self, data: bytes) -> Message: |
1301 | 1303 | betterproto2_rust_codec.deserialize(self, data)
|
1302 | 1304 | return self
|
1303 | 1305 |
|
|
0 commit comments