21
21
timedelta ,
22
22
timezone ,
23
23
)
24
+ from enum import IntEnum
24
25
from io import BytesIO
25
26
from itertools import count
26
27
from typing import (
@@ -560,6 +561,15 @@ def _get_cls_by_field(cls: Type["Message"], fields: Iterable[dataclasses.Field])
560
561
return field_cls
561
562
562
563
564
+ class OutputFormat (IntEnum ):
565
+ """
566
+ Chosen output format for the `Message.to_dict` method.
567
+ """
568
+
569
+ PYTHON = 1
570
+ PROTO_JSON = 2
571
+
572
+
563
573
class Message (ABC ):
564
574
"""
565
575
The base class for protobuf messages, all generated messages will inherit from
@@ -606,10 +616,6 @@ def __repr__(self) -> str:
606
616
]
607
617
return f"{ self .__class__ .__name__ } ({ ', ' .join (parts )} )"
608
618
609
- # def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
610
- # for field_name in self._betterproto.sorted_field_names:
611
- # yield field_name, self.__getattribute__(field_name), PLACEHOLDER
612
-
613
619
def __bool__ (self ) -> bool :
614
620
"""True if the message has any fields with non-default values."""
615
621
return any (
@@ -946,9 +952,15 @@ def FromString(cls: Type[T], data: bytes) -> T:
946
952
"""
947
953
return cls ().parse (data )
948
954
949
- def to_dict (self , casing : Casing = Casing .CAMEL , include_default_values : bool = False ) -> Dict [str , Any ]:
955
+ def to_dict (
956
+ self ,
957
+ * ,
958
+ output_format : OutputFormat = OutputFormat .PROTO_JSON ,
959
+ casing : Casing = Casing .CAMEL ,
960
+ include_default_values : bool = False ,
961
+ ) -> Dict [str , Any ]:
950
962
"""
951
- Returns a JSON serializable dict representation of this object .
963
+ Return a dict representation of the message .
952
964
953
965
Parameters
954
966
-----------
@@ -965,6 +977,12 @@ def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool =
965
977
Dict[:class:`str`, Any]
966
978
The JSON serializable dict representation of this object.
967
979
"""
980
+ kwargs = { # For recursive calls
981
+ "output_format" : output_format ,
982
+ "casing" : casing ,
983
+ "include_default_values" : include_default_values ,
984
+ }
985
+
968
986
output : Dict [str , Any ] = {}
969
987
field_types = self ._type_hints ()
970
988
for field_name , meta in self ._betterproto .meta_by_field_name .items ():
@@ -973,74 +991,87 @@ def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool =
973
991
cased_name = casing (field_name ).rstrip ("_" ) # type: ignore
974
992
if meta .proto_type == TYPE_MESSAGE :
975
993
if isinstance (value , datetime ):
976
- output [cased_name ] = _Timestamp .timestamp_to_json (value )
994
+ if output_format == OutputFormat .PROTO_JSON :
995
+ output [cased_name ] = _Timestamp .timestamp_to_json (value )
996
+ else :
997
+ output [cased_name ] = value
977
998
elif isinstance (value , timedelta ):
978
- output [cased_name ] = _Duration .delta_to_json (value )
999
+ if output_format == OutputFormat .PROTO_JSON :
1000
+ output [cased_name ] = _Duration .delta_to_json (value )
1001
+ else :
1002
+ output [cased_name ] = value
1003
+
979
1004
elif meta .wraps :
980
1005
if value is not None or include_default_values :
981
1006
output [cased_name ] = value
982
1007
elif field_is_repeated :
983
1008
# Convert each item.
984
- cls = self ._betterproto .cls_by_field [field_name ]
985
- if cls == datetime :
986
- value = [_Timestamp .timestamp_to_json (i ) for i in value ]
987
- elif cls == timedelta :
988
- value = [_Duration .delta_to_json (i ) for i in value ]
1009
+ if output_format == OutputFormat .PYTHON :
1010
+ value = [i .to_dict (** kwargs ) for i in value ]
989
1011
else :
990
- value = [i .to_dict (casing , include_default_values ) for i in value ]
1012
+ cls = self ._betterproto .cls_by_field [field_name ]
1013
+ if cls == datetime :
1014
+ value = [_Timestamp .timestamp_to_json (i ) for i in value ]
1015
+ elif cls == timedelta :
1016
+ value = [_Duration .delta_to_json (i ) for i in value ]
1017
+ else :
1018
+ value = [i .to_dict (** kwargs ) for i in value ]
991
1019
if value or include_default_values :
992
1020
output [cased_name ] = value
993
1021
elif value is None :
994
1022
if include_default_values :
995
- output [cased_name ] = value
1023
+ output [cased_name ] = None
996
1024
else :
997
- output [cased_name ] = value .to_dict (casing , include_default_values )
1025
+ output [cased_name ] = value .to_dict (** kwargs )
998
1026
elif meta .proto_type == TYPE_MAP :
999
1027
output_map = {** value }
1000
1028
for k in value :
1001
1029
if hasattr (value [k ], "to_dict" ):
1002
- output_map [k ] = value [k ].to_dict (casing , include_default_values )
1030
+ output_map [k ] = value [k ].to_dict (** kwargs )
1003
1031
1004
1032
if value or include_default_values :
1005
1033
output [cased_name ] = output_map
1006
1034
elif value != self ._get_field_default (field_name ) or include_default_values :
1007
- if meta .proto_type in INT_64_TYPES :
1008
- if field_is_repeated :
1009
- output [cased_name ] = [str (n ) for n in value ]
1010
- elif value is None :
1011
- if include_default_values :
1012
- output [cased_name ] = value
1013
- else :
1014
- output [cased_name ] = str (value )
1015
- elif meta .proto_type == TYPE_BYTES :
1016
- if field_is_repeated :
1017
- output [cased_name ] = [b64encode (b ).decode ("utf8" ) for b in value ]
1018
- elif value is None and include_default_values :
1019
- output [cased_name ] = value
1020
- else :
1021
- output [cased_name ] = b64encode (value ).decode ("utf8" )
1022
- elif meta .proto_type == TYPE_ENUM :
1023
- if field_is_repeated :
1024
- enum_class = field_types [field_name ].__args__ [0 ]
1025
- if isinstance (value , typing .Iterable ) and not isinstance (value , str ):
1026
- output [cased_name ] = [enum_class (el ).name for el in value ]
1035
+ if output_format == OutputFormat .PROTO_JSON :
1036
+ if meta .proto_type in INT_64_TYPES :
1037
+ if field_is_repeated :
1038
+ output [cased_name ] = [str (n ) for n in value ]
1039
+ elif value is None :
1040
+ if include_default_values :
1041
+ output [cased_name ] = value
1027
1042
else :
1028
- # transparently upgrade single value to repeated
1029
- output [cased_name ] = [enum_class (value ).name ]
1030
- elif value is None :
1031
- if include_default_values :
1043
+ output [cased_name ] = str (value )
1044
+ elif meta .proto_type == TYPE_BYTES :
1045
+ if field_is_repeated :
1046
+ output [cased_name ] = [b64encode (b ).decode ("utf8" ) for b in value ]
1047
+ elif value is None and include_default_values :
1032
1048
output [cased_name ] = value
1033
- elif meta .optional :
1034
- enum_class = field_types [field_name ].__args__ [0 ]
1035
- output [cased_name ] = enum_class (value ).name
1036
- else :
1037
- enum_class = field_types [field_name ] # noqa
1038
- output [cased_name ] = enum_class (value ).name
1039
- elif meta .proto_type in (TYPE_FLOAT , TYPE_DOUBLE ):
1040
- if field_is_repeated :
1041
- output [cased_name ] = [_dump_float (n ) for n in value ]
1049
+ else :
1050
+ output [cased_name ] = b64encode (value ).decode ("utf8" )
1051
+ elif meta .proto_type == TYPE_ENUM :
1052
+ if field_is_repeated :
1053
+ enum_class = field_types [field_name ].__args__ [0 ]
1054
+ if isinstance (value , typing .Iterable ) and not isinstance (value , str ):
1055
+ output [cased_name ] = [enum_class (el ).name for el in value ]
1056
+ else :
1057
+ # transparently upgrade single value to repeated
1058
+ output [cased_name ] = [enum_class (value ).name ]
1059
+ elif value is None :
1060
+ if include_default_values :
1061
+ output [cased_name ] = value
1062
+ elif meta .optional :
1063
+ enum_class = field_types [field_name ].__args__ [0 ]
1064
+ output [cased_name ] = enum_class (value ).name
1065
+ else :
1066
+ enum_class = field_types [field_name ] # noqa
1067
+ output [cased_name ] = enum_class (value ).name
1068
+ elif meta .proto_type in (TYPE_FLOAT , TYPE_DOUBLE ):
1069
+ if field_is_repeated :
1070
+ output [cased_name ] = [_dump_float (n ) for n in value ]
1071
+ else :
1072
+ output [cased_name ] = _dump_float (value )
1042
1073
else :
1043
- output [cased_name ] = _dump_float ( value )
1074
+ output [cased_name ] = value
1044
1075
else :
1045
1076
output [cased_name ] = value
1046
1077
return output
@@ -1188,69 +1219,6 @@ def from_json(self: T, value: Union[str, bytes]) -> T:
1188
1219
"""
1189
1220
return self .from_dict (json .loads (value ))
1190
1221
1191
- def to_pydict (self , casing : Casing = Casing .CAMEL , include_default_values : bool = False ) -> Dict [str , Any ]:
1192
- """
1193
- Returns a python dict representation of this object.
1194
-
1195
- Parameters
1196
- -----------
1197
- casing: :class:`Casing`
1198
- The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1199
- compatibility purposes.
1200
- include_default_values: :class:`bool`
1201
- If ``True`` will include the default values of fields. Default is ``False``.
1202
- E.g. an ``int32`` field will be included with a value of ``0`` if this is
1203
- set to ``True``, otherwise this would be ignored.
1204
-
1205
- Returns
1206
- --------
1207
- Dict[:class:`str`, Any]
1208
- The python dict representation of this object.
1209
- """
1210
- output : Dict [str , Any ] = {}
1211
- for field_name , meta in self ._betterproto .meta_by_field_name .items ():
1212
- field_is_repeated = meta .repeated
1213
- value = getattr (self , field_name )
1214
- cased_name = casing (field_name ).rstrip ("_" ) # type: ignore
1215
- if meta .proto_type == TYPE_MESSAGE :
1216
- if isinstance (value , datetime ):
1217
- if (
1218
- value != DATETIME_ZERO
1219
- or include_default_values
1220
- or self ._include_default_value_for_oneof (field_name = field_name , meta = meta )
1221
- ):
1222
- output [cased_name ] = value
1223
- elif isinstance (value , timedelta ):
1224
- if (
1225
- value != timedelta (0 )
1226
- or include_default_values
1227
- or self ._include_default_value_for_oneof (field_name = field_name , meta = meta )
1228
- ):
1229
- output [cased_name ] = value
1230
- elif meta .wraps :
1231
- if value is not None or include_default_values :
1232
- output [cased_name ] = value
1233
- elif field_is_repeated :
1234
- # Convert each item.
1235
- value = [i .to_pydict (casing , include_default_values ) for i in value ]
1236
- if value or include_default_values :
1237
- output [cased_name ] = value
1238
- elif value is None :
1239
- if include_default_values :
1240
- output [cased_name ] = None
1241
- else :
1242
- output [cased_name ] = value .to_pydict (casing , include_default_values )
1243
- elif meta .proto_type == TYPE_MAP :
1244
- for k in value :
1245
- if hasattr (value [k ], "to_pydict" ):
1246
- value [k ] = value [k ].to_pydict (casing , include_default_values )
1247
-
1248
- if value or include_default_values :
1249
- output [cased_name ] = value
1250
- elif value != self ._get_field_default (field_name ) or include_default_values :
1251
- output [cased_name ] = value
1252
- return output
1253
-
1254
1222
def from_pydict (self : T , value : Mapping [str , Any ]) -> T :
1255
1223
"""
1256
1224
Parse the key/value pairs into the current message instance. This returns the
0 commit comments