Skip to content

Commit dedead0

Browse files
Read proto objects before services
1 parent 87b3a4b commit dedead0

File tree

1 file changed

+28
-18
lines changed

1 file changed

+28
-18
lines changed

betterproto/plugin.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def get_py_zero(type_num: int) -> Union[str, float]:
7676
return zero
7777

7878

79+
# Todo: Keep information about nested hierarchy
7980
def traverse(proto_file):
8081
def _traverse(path, items, prefix=""):
8182
for i, item in enumerate(items):
@@ -146,11 +147,10 @@ def generate_code(request, response):
146147
)
147148
output_package_files[output_package]["files"].append(proto_file)
148149

149-
output_paths = set()
150+
# Initialize Template data for each package
150151
for output_package_name, output_package_content in output_package_files.items():
151-
input_package_name = output_package_content["input_package"]
152152
template_data = {
153-
"input_package": input_package_name,
153+
"input_package": output_package_content["input_package"],
154154
"files": [f.name for f in output_package_content["files"]],
155155
"imports": set(),
156156
"datetime_imports": set(),
@@ -159,15 +159,26 @@ def generate_code(request, response):
159159
"enums": [],
160160
"services": [],
161161
}
162+
output_package_content["template_data"] = template_data
162163

164+
# Read Messages and Enums
165+
for output_package_name, output_package_content in output_package_files.items():
163166
for proto_file in output_package_content["files"]:
164-
item: DescriptorProto
165167
for item, path in traverse(proto_file):
166-
read_protobuf_type(input_package_name, item, path, proto_file, template_data)
168+
read_protobuf_object(item, path, proto_file, output_package_content)
167169

168-
for i, service in enumerate(proto_file.service):
169-
read_protobuf_service(i, input_package_name, proto_file, service, template_data)
170+
# Read Services
171+
for output_package_name, output_package_content in output_package_files.items():
172+
for proto_file in output_package_content["files"]:
173+
for index, service in enumerate(proto_file.service):
174+
read_protobuf_service(
175+
service, index, proto_file, output_package_content
176+
)
170177

178+
# Render files
179+
output_paths = set()
180+
for output_package_name, output_package_content in output_package_files.items():
181+
template_data = output_package_content["template_data"]
171182
template_data["imports"] = sorted(template_data["imports"])
172183
template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
173184
template_data["typing_imports"] = sorted(template_data["typing_imports"])
@@ -203,12 +214,14 @@ def generate_code(request, response):
203214
print(f"Writing {output_package_name}", file=sys.stderr)
204215

205216

206-
def read_protobuf_service(i, input_package_name, proto_file, service, template_data):
217+
def read_protobuf_service(service: DescriptorProto, index, proto_file, content):
218+
input_package_name = content["input_package"]
219+
template_data = content["template_data"]
207220
# print(service, file=sys.stderr)
208221
data = {
209222
"name": service.name,
210223
"py_name": pythonize_class_name(service.name),
211-
"comment": get_comment(proto_file, [6, i]),
224+
"comment": get_comment(proto_file, [6, index]),
212225
"methods": [],
213226
}
214227
for j, method in enumerate(service.method):
@@ -228,7 +241,7 @@ def read_protobuf_service(i, input_package_name, proto_file, service, template_d
228241
{
229242
"name": method.name,
230243
"py_name": pythonize_method_name(method.name),
231-
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
244+
"comment": get_comment(proto_file, [6, index, 2, j], indent=8),
232245
"route": f"/{input_package_name}.{service.name}/{method.name}",
233246
"input": get_type_reference(
234247
input_package_name, template_data["imports"], method.input_type
@@ -254,7 +267,9 @@ def read_protobuf_service(i, input_package_name, proto_file, service, template_d
254267
template_data["services"].append(data)
255268

256269

257-
def read_protobuf_type(input_package_name, item, path, proto_file, template_data):
270+
def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, content):
271+
input_package_name = content["input_package"]
272+
template_data = content["template_data"]
258273
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
259274
if isinstance(item, DescriptorProto):
260275
# print(item, file=sys.stderr)
@@ -280,9 +295,7 @@ def read_protobuf_type(input_package_name, item, path, proto_file, template_data
280295
field_type = f.Type.Name(f.type).lower()[5:]
281296

282297
field_wraps = ""
283-
match_wrapper = re.match(
284-
r"\.google\.protobuf\.(.+)Value", f.type_name
285-
)
298+
match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name)
286299
if match_wrapper:
287300
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
288301
if hasattr(betterproto, wrapped_type):
@@ -297,10 +310,7 @@ def read_protobuf_type(input_package_name, item, path, proto_file, template_data
297310

298311
if message_type == map_entry:
299312
for nested in item.nested_type:
300-
if (
301-
nested.name.replace("_", "").lower()
302-
== map_entry
303-
):
313+
if nested.name.replace("_", "").lower() == map_entry:
304314
if nested.options.map_entry:
305315
# print("Found a map!", file=sys.stderr)
306316
k = py_type(

0 commit comments

Comments
 (0)