Skip to content

Commit d43d5af

Browse files
committed
Better JSON casing support, renaming messages/fields
1 parent ef0a1bf commit d43d5af

File tree

11 files changed

+165
-83
lines changed

11 files changed

+165
-83
lines changed

Pipfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ rope = "*"
1414
protobuf = "*"
1515
jinja2 = "*"
1616
grpclib = "*"
17+
stringcase = "*"
1718

1819
[requires]
1920
python_version = "3.7"

Pipfile.lock

Lines changed: 34 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ $ pipenv run tests
301301
- [x] Unary-unary
302302
- [x] Server streaming response
303303
- [ ] Client streaming request
304-
- [ ] Renaming messages and fields to conform to Python name standards
304+
- [x] Renaming messages and fields to conform to Python name standards
305305
- [ ] Renaming clashes with language keywords and standard library top-level packages
306306
- [x] Python package
307307
- [x] Automate running tests

betterproto/__init__.py

Lines changed: 64 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import grpclib.client
2626
import grpclib.const
27+
import stringcase
2728

2829
# Proto 3 data types
2930
TYPE_ENUM = "enum"
@@ -101,6 +102,13 @@
101102
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
102103

103104

105+
class Casing(enum.Enum):
106+
"""Casing constants for serialization."""
107+
108+
CAMEL = stringcase.camelcase
109+
SNAKE = stringcase.snakecase
110+
111+
104112
class _PLACEHOLDER:
105113
pass
106114

@@ -624,48 +632,50 @@ def parse(self: T, data: bytes) -> T:
624632
def FromString(cls: Type[T], data: bytes) -> T:
625633
return cls().parse(data)
626634

627-
def to_dict(self) -> dict:
635+
def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
628636
"""
629637
Returns a dict representation of this message instance which can be
630-
used to serialize to e.g. JSON.
638+
used to serialize to e.g. JSON. Defaults to camel casing for
639+
compatibility but can be set to other modes.
631640
"""
632641
output: Dict[str, Any] = {}
633642
for field in dataclasses.fields(self):
634643
meta = FieldMetadata.get(field)
635644
v = getattr(self, field.name)
645+
cased_name = casing(field.name)
636646
if meta.proto_type == "message":
637647
if isinstance(v, list):
638648
# Convert each item.
639649
v = [i.to_dict() for i in v]
640-
output[field.name] = v
650+
output[cased_name] = v
641651
elif v._serialized_on_wire:
642-
output[field.name] = v.to_dict()
652+
output[cased_name] = v.to_dict()
643653
elif meta.proto_type == "map":
644654
for k in v:
645655
if hasattr(v[k], "to_dict"):
646656
v[k] = v[k].to_dict()
647657

648658
if v:
649-
output[field.name] = v
659+
output[cased_name] = v
650660
elif v != get_default(meta.proto_type):
651661
if meta.proto_type in INT_64_TYPES:
652662
if isinstance(v, list):
653-
output[field.name] = [str(n) for n in v]
663+
output[cased_name] = [str(n) for n in v]
654664
else:
655-
output[field.name] = str(v)
665+
output[cased_name] = str(v)
656666
elif meta.proto_type == TYPE_BYTES:
657667
if isinstance(v, list):
658-
output[field.name] = [b64encode(b).decode("utf8") for b in v]
668+
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
659669
else:
660-
output[field.name] = b64encode(v).decode("utf8")
670+
output[cased_name] = b64encode(v).decode("utf8")
661671
elif meta.proto_type == TYPE_ENUM:
662672
enum_values = list(self._cls_for(field))
663673
if isinstance(v, list):
664-
output[field.name] = [enum_values[e].name for e in v]
674+
output[cased_name] = [enum_values[e].name for e in v]
665675
else:
666-
output[field.name] = enum_values[v].name
676+
output[cased_name] = enum_values[v].name
667677
else:
668-
output[field.name] = v
678+
output[cased_name] = v
669679
return output
670680

671681
def from_dict(self: T, value: dict) -> T:
@@ -674,44 +684,49 @@ def from_dict(self: T, value: dict) -> T:
674684
returns the instance itself and is therefore assignable and chainable.
675685
"""
676686
self._serialized_on_wire = True
677-
for field in dataclasses.fields(self):
678-
meta = FieldMetadata.get(field)
679-
if field.name in value and value[field.name] is not None:
680-
if meta.proto_type == "message":
681-
v = getattr(self, field.name)
682-
# print(v, value[field.name])
683-
if isinstance(v, list):
684-
cls = self._cls_for(field)
685-
for i in range(len(value[field.name])):
686-
v.append(cls().from_dict(value[field.name][i]))
687-
else:
688-
v.from_dict(value[field.name])
689-
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
690-
v = getattr(self, field.name)
691-
cls = self._cls_for(field, index=1)
692-
for k in value[field.name]:
693-
v[k] = cls().from_dict(value[field.name][k])
694-
else:
695-
v = value[field.name]
696-
if meta.proto_type in INT_64_TYPES:
697-
if isinstance(value[field.name], list):
698-
v = [int(n) for n in value[field.name]]
699-
else:
700-
v = int(value[field.name])
701-
elif meta.proto_type == TYPE_BYTES:
702-
if isinstance(value[field.name], list):
703-
v = [b64decode(n) for n in value[field.name]]
704-
else:
705-
v = b64decode(value[field.name])
706-
elif meta.proto_type == TYPE_ENUM:
707-
enum_cls = self._cls_for(field)
708-
if isinstance(v, list):
709-
v = [enum_cls.from_string(e) for e in v]
710-
elif isinstance(v, str):
711-
v = enum_cls.from_string(v)
687+
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
688+
for key in value:
689+
snake_cased = stringcase.snakecase(key)
690+
if snake_cased in fields_by_name:
691+
field = fields_by_name[snake_cased]
692+
meta = FieldMetadata.get(field)
712693

713-
if v is not None:
714-
setattr(self, field.name, v)
694+
if value[key] is not None:
695+
if meta.proto_type == "message":
696+
v = getattr(self, field.name)
697+
# print(v, value[key])
698+
if isinstance(v, list):
699+
cls = self._cls_for(field)
700+
for i in range(len(value[key])):
701+
v.append(cls().from_dict(value[key][i]))
702+
else:
703+
v.from_dict(value[key])
704+
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
705+
v = getattr(self, field.name)
706+
cls = self._cls_for(field, index=1)
707+
for k in value[key]:
708+
v[k] = cls().from_dict(value[key][k])
709+
else:
710+
v = value[key]
711+
if meta.proto_type in INT_64_TYPES:
712+
if isinstance(value[key], list):
713+
v = [int(n) for n in value[key]]
714+
else:
715+
v = int(value[key])
716+
elif meta.proto_type == TYPE_BYTES:
717+
if isinstance(value[key], list):
718+
v = [b64decode(n) for n in value[key]]
719+
else:
720+
v = b64decode(value[key])
721+
elif meta.proto_type == TYPE_ENUM:
722+
enum_cls = self._cls_for(field)
723+
if isinstance(v, list):
724+
v = [enum_cls.from_string(e) for e in v]
725+
elif isinstance(v, str):
726+
v = enum_cls.from_string(v)
727+
728+
if v is not None:
729+
setattr(self, field.name, v)
715730
return self
716731

717732
def to_json(self, indent: Union[None, int, str] = None) -> str:

betterproto/plugin.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
)
1717
raise SystemExit(1)
1818

19+
import stringcase
20+
1921
from google.protobuf.compiler import plugin_pb2 as plugin
2022
from google.protobuf.descriptor_pb2 import (
2123
DescriptorProto,
@@ -26,12 +28,6 @@
2628
)
2729

2830

29-
def snake_case(value: str) -> str:
30-
return (
31-
re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_")
32-
)
33-
34-
3531
def get_ref_type(package: str, imports: set, type_name: str) -> str:
3632
"""
3733
Return a Python type name for a proto type reference. Adds the import if
@@ -40,12 +36,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
4036
type_name = type_name.lstrip(".")
4137
if type_name.startswith(package):
4238
# This is the current package, which has nested types flattened.
43-
type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"'
39+
# foo.bar_thing => FooBarThing
40+
parts = type_name.lstrip(package).lstrip(".").split(".")
41+
cased = [stringcase.pascalcase(part) for part in parts]
42+
type_name = f'"{"".join(cased)}"'
4443

4544
if "." in type_name:
4645
# This is imported from another package. No need
4746
# to use a forward ref and we need to add the import.
4847
parts = type_name.split(".")
48+
parts[-1] = stringcase.pascalcase(parts[-1])
4949
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
5050
type_name = f"{parts[-2]}.{parts[-1]}"
5151

@@ -179,7 +179,7 @@ def generate_code(request, response):
179179
for item, path in traverse(proto_file):
180180
# print(item, file=sys.stderr)
181181
# print(path, file=sys.stderr)
182-
data = {"name": item.name}
182+
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
183183

184184
if isinstance(item, DescriptorProto):
185185
# print(item, file=sys.stderr)
@@ -255,6 +255,7 @@ def generate_code(request, response):
255255
data["properties"].append(
256256
{
257257
"name": f.name,
258+
"py_name": stringcase.snakecase(f.name),
258259
"number": f.number,
259260
"comment": get_comment(proto_file, path + [2, i]),
260261
"proto_type": int(f.type),
@@ -294,6 +295,7 @@ def generate_code(request, response):
294295

295296
data = {
296297
"name": service.name,
298+
"py_name": stringcase.pascalcase(service.name),
297299
"comment": get_comment(proto_file, [6, i]),
298300
"methods": [],
299301
}
@@ -317,7 +319,7 @@ def generate_code(request, response):
317319
data["methods"].append(
318320
{
319321
"name": method.name,
320-
"py_name": snake_case(method.name),
322+
"py_name": stringcase.snakecase(method.name),
321323
"comment": get_comment(proto_file, [6, i, 2, j]),
322324
"route": f"/{package}.{service.name}/{method.name}",
323325
"input": get_ref_type(

0 commit comments

Comments
 (0)