Skip to content

Commit f2e8719

Browse files
Clarify variable names
1 parent 98d00f0 commit f2e8719

File tree

1 file changed

+52
-50
lines changed

1 file changed

+52
-50
lines changed

betterproto/plugin.py

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
2+
import collections
33
import itertools
44
import os.path
55
import pathlib
@@ -8,6 +8,8 @@
88
import textwrap
99
from typing import List, Union
1010

11+
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
12+
1113
import betterproto
1214
from betterproto.compile.importing import get_type_reference
1315
from betterproto.compile.naming import (
@@ -129,29 +131,27 @@ def generate_code(request, response):
129131
)
130132
template = env.get_template("template.py.j2")
131133

132-
output_map = {}
134+
# Gather output packages
135+
output_package_files = collections.defaultdict()
133136
for proto_file in request.proto_file:
134137
if (
135138
proto_file.package == "google.protobuf"
136139
and "INCLUDE_GOOGLE" not in plugin_options
137140
):
138141
continue
139142

140-
output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
141-
142-
if output_file not in output_map:
143-
output_map[output_file] = {"package": proto_file.package, "files": []}
144-
output_map[output_file]["files"].append(proto_file)
145-
146-
# TODO: Figure out how to handle gRPC request/response messages and add
147-
# processing below for Service.
148-
149-
for filename, options in output_map.items():
150-
package = options["package"]
151-
# print(package, filename, file=sys.stderr)
152-
output = {
153-
"package": package,
154-
"files": [f.name for f in options["files"]],
143+
output_package = proto_file.package
144+
output_package_files.setdefault(
145+
output_package, {"input_package": proto_file.package, "files": []}
146+
)
147+
output_package_files[output_package]["files"].append(proto_file)
148+
149+
output_paths = set()
150+
for output_package_name, output_package_content in output_package_files.items():
151+
input_package_name = output_package_content["input_package"]
152+
template_data = {
153+
"input_package": input_package_name,
154+
"files": [f.name for f in output_package_content["files"]],
155155
"imports": set(),
156156
"datetime_imports": set(),
157157
"typing_imports": set(),
@@ -160,7 +160,7 @@ def generate_code(request, response):
160160
"services": [],
161161
}
162162

163-
for proto_file in options["files"]:
163+
for proto_file in output_package_content["files"]:
164164
item: DescriptorProto
165165
for item, path in traverse(proto_file):
166166
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
@@ -180,7 +180,7 @@ def generate_code(request, response):
180180
)
181181

182182
for i, f in enumerate(item.field):
183-
t = py_type(package, output["imports"], f)
183+
t = py_type(input_package_name, template_data["imports"], f)
184184
zero = get_py_zero(f.type)
185185

186186
repeated = False
@@ -213,13 +213,13 @@ def generate_code(request, response):
213213
if nested.options.map_entry:
214214
# print("Found a map!", file=sys.stderr)
215215
k = py_type(
216-
package,
217-
output["imports"],
216+
input_package_name,
217+
template_data["imports"],
218218
nested.field[0],
219219
)
220220
v = py_type(
221-
package,
222-
output["imports"],
221+
input_package_name,
222+
template_data["imports"],
223223
nested.field[1],
224224
)
225225
t = f"Dict[{k}, {v}]"
@@ -228,14 +228,14 @@ def generate_code(request, response):
228228
f.Type.Name(nested.field[0].type),
229229
f.Type.Name(nested.field[1].type),
230230
)
231-
output["typing_imports"].add("Dict")
231+
template_data["typing_imports"].add("Dict")
232232

233233
if f.label == 3 and field_type != "map":
234234
# Repeated field
235235
repeated = True
236236
t = f"List[{t}]"
237237
zero = "[]"
238-
output["typing_imports"].add("List")
238+
template_data["typing_imports"].add("List")
239239

240240
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
241241
packed = True
@@ -245,12 +245,12 @@ def generate_code(request, response):
245245
one_of = item.oneof_decl[f.oneof_index].name
246246

247247
if "Optional[" in t:
248-
output["typing_imports"].add("Optional")
248+
template_data["typing_imports"].add("Optional")
249249

250250
if "timedelta" in t:
251-
output["datetime_imports"].add("timedelta")
251+
template_data["datetime_imports"].add("timedelta")
252252
elif "datetime" in t:
253-
output["datetime_imports"].add("datetime")
253+
template_data["datetime_imports"].add("datetime")
254254

255255
data["properties"].append(
256256
{
@@ -271,7 +271,7 @@ def generate_code(request, response):
271271
)
272272
# print(f, file=sys.stderr)
273273

274-
output["messages"].append(data)
274+
template_data["messages"].append(data)
275275
elif isinstance(item, EnumDescriptorProto):
276276
# print(item.name, path, file=sys.stderr)
277277
data.update(
@@ -289,7 +289,7 @@ def generate_code(request, response):
289289
}
290290
)
291291

292-
output["enums"].append(data)
292+
template_data["enums"].append(data)
293293

294294
for i, service in enumerate(proto_file.service):
295295
# print(service, file=sys.stderr)
@@ -304,29 +304,29 @@ def generate_code(request, response):
304304
for j, method in enumerate(service.method):
305305
input_message = None
306306
input_type = get_type_reference(
307-
package, output["imports"], method.input_type
307+
input_package_name, template_data["imports"], method.input_type
308308
).strip('"')
309-
for msg in output["messages"]:
309+
for msg in template_data["messages"]:
310310
if msg["name"] == input_type:
311311
input_message = msg
312312
for field in msg["properties"]:
313313
if field["zero"] == "None":
314-
output["typing_imports"].add("Optional")
314+
template_data["typing_imports"].add("Optional")
315315
break
316316

317317
data["methods"].append(
318318
{
319319
"name": method.name,
320320
"py_name": pythonize_method_name(method.name),
321321
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
322-
"route": f"/{package}.{service.name}/{method.name}",
322+
"route": f"/{input_package_name}.{service.name}/{method.name}",
323323
"input": get_type_reference(
324-
package, output["imports"], method.input_type
324+
input_package_name, template_data["imports"], method.input_type
325325
).strip('"'),
326326
"input_message": input_message,
327327
"output": get_type_reference(
328-
package,
329-
output["imports"],
328+
input_package_name,
329+
template_data["imports"],
330330
method.output_type,
331331
unwrap=False,
332332
),
@@ -336,30 +336,32 @@ def generate_code(request, response):
336336
)
337337

338338
if method.client_streaming:
339-
output["typing_imports"].add("AsyncIterable")
340-
output["typing_imports"].add("Iterable")
341-
output["typing_imports"].add("Union")
339+
template_data["typing_imports"].add("AsyncIterable")
340+
template_data["typing_imports"].add("Iterable")
341+
template_data["typing_imports"].add("Union")
342342
if method.server_streaming:
343-
output["typing_imports"].add("AsyncIterator")
343+
template_data["typing_imports"].add("AsyncIterator")
344344

345-
output["services"].append(data)
345+
template_data["services"].append(data)
346346

347-
output["imports"] = sorted(output["imports"])
348-
output["datetime_imports"] = sorted(output["datetime_imports"])
349-
output["typing_imports"] = sorted(output["typing_imports"])
347+
template_data["imports"] = sorted(template_data["imports"])
348+
template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
349+
template_data["typing_imports"] = sorted(template_data["typing_imports"])
350350

351351
# Fill response
352+
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
353+
output_paths.add(output_path)
354+
352355
f = response.file.add()
353-
f.name = filename
356+
f.name = str(output_path)
354357

355358
# Render and then format the output file.
356359
f.content = black.format_str(
357-
template.render(description=output),
360+
template.render(description=template_data),
358361
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
359362
)
360363

361364
# Make each output directory a package with __init__ file
362-
output_paths = set(pathlib.Path(path) for path in output_map.keys())
363365
init_files = (
364366
set(
365367
directory.joinpath("__init__.py")
@@ -373,8 +375,8 @@ def generate_code(request, response):
373375
init = response.file.add()
374376
init.name = str(init_file)
375377

376-
for filename in sorted(output_paths.union(init_files)):
377-
print(f"Writing {filename}", file=sys.stderr)
378+
for output_package_name in sorted(output_paths.union(init_files)):
379+
print(f"Writing {output_package_name}", file=sys.stderr)
378380

379381

380382
def main():

0 commit comments

Comments
 (0)