Skip to content

Commit 87b3a4b

Browse files
Move parsing of protobuf data types and services into separate methods
1 parent f2e8719 commit 87b3a4b

File tree

1 file changed

+182
-178
lines changed

1 file changed

+182
-178
lines changed

betterproto/plugin.py

Lines changed: 182 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -163,186 +163,10 @@ def generate_code(request, response):
163163
for proto_file in output_package_content["files"]:
164164
item: DescriptorProto
165165
for item, path in traverse(proto_file):
166-
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
167-
168-
if isinstance(item, DescriptorProto):
169-
# print(item, file=sys.stderr)
170-
if item.options.map_entry:
171-
# Skip generated map entry messages since we just use dicts
172-
continue
173-
174-
data.update(
175-
{
176-
"type": "Message",
177-
"comment": get_comment(proto_file, path),
178-
"properties": [],
179-
}
180-
)
181-
182-
for i, f in enumerate(item.field):
183-
t = py_type(input_package_name, template_data["imports"], f)
184-
zero = get_py_zero(f.type)
185-
186-
repeated = False
187-
packed = False
188-
189-
field_type = f.Type.Name(f.type).lower()[5:]
190-
191-
field_wraps = ""
192-
match_wrapper = re.match(
193-
r"\.google\.protobuf\.(.+)Value", f.type_name
194-
)
195-
if match_wrapper:
196-
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
197-
if hasattr(betterproto, wrapped_type):
198-
field_wraps = f"betterproto.{wrapped_type}"
199-
200-
map_types = None
201-
if f.type == 11:
202-
# This might be a map...
203-
message_type = f.type_name.split(".").pop().lower()
204-
# message_type = py_type(package)
205-
map_entry = f"{f.name.replace('_', '').lower()}entry"
206-
207-
if message_type == map_entry:
208-
for nested in item.nested_type:
209-
if (
210-
nested.name.replace("_", "").lower()
211-
== map_entry
212-
):
213-
if nested.options.map_entry:
214-
# print("Found a map!", file=sys.stderr)
215-
k = py_type(
216-
input_package_name,
217-
template_data["imports"],
218-
nested.field[0],
219-
)
220-
v = py_type(
221-
input_package_name,
222-
template_data["imports"],
223-
nested.field[1],
224-
)
225-
t = f"Dict[{k}, {v}]"
226-
field_type = "map"
227-
map_types = (
228-
f.Type.Name(nested.field[0].type),
229-
f.Type.Name(nested.field[1].type),
230-
)
231-
template_data["typing_imports"].add("Dict")
232-
233-
if f.label == 3 and field_type != "map":
234-
# Repeated field
235-
repeated = True
236-
t = f"List[{t}]"
237-
zero = "[]"
238-
template_data["typing_imports"].add("List")
239-
240-
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
241-
packed = True
242-
243-
one_of = ""
244-
if f.HasField("oneof_index"):
245-
one_of = item.oneof_decl[f.oneof_index].name
246-
247-
if "Optional[" in t:
248-
template_data["typing_imports"].add("Optional")
249-
250-
if "timedelta" in t:
251-
template_data["datetime_imports"].add("timedelta")
252-
elif "datetime" in t:
253-
template_data["datetime_imports"].add("datetime")
254-
255-
data["properties"].append(
256-
{
257-
"name": f.name,
258-
"py_name": pythonize_field_name(f.name),
259-
"number": f.number,
260-
"comment": get_comment(proto_file, path + [2, i]),
261-
"proto_type": int(f.type),
262-
"field_type": field_type,
263-
"field_wraps": field_wraps,
264-
"map_types": map_types,
265-
"type": t,
266-
"zero": zero,
267-
"repeated": repeated,
268-
"packed": packed,
269-
"one_of": one_of,
270-
}
271-
)
272-
# print(f, file=sys.stderr)
273-
274-
template_data["messages"].append(data)
275-
elif isinstance(item, EnumDescriptorProto):
276-
# print(item.name, path, file=sys.stderr)
277-
data.update(
278-
{
279-
"type": "Enum",
280-
"comment": get_comment(proto_file, path),
281-
"entries": [
282-
{
283-
"name": v.name,
284-
"value": v.number,
285-
"comment": get_comment(proto_file, path + [2, i]),
286-
}
287-
for i, v in enumerate(item.value)
288-
],
289-
}
290-
)
291-
292-
template_data["enums"].append(data)
166+
read_protobuf_type(input_package_name, item, path, proto_file, template_data)
293167

294168
for i, service in enumerate(proto_file.service):
295-
# print(service, file=sys.stderr)
296-
297-
data = {
298-
"name": service.name,
299-
"py_name": pythonize_class_name(service.name),
300-
"comment": get_comment(proto_file, [6, i]),
301-
"methods": [],
302-
}
303-
304-
for j, method in enumerate(service.method):
305-
input_message = None
306-
input_type = get_type_reference(
307-
input_package_name, template_data["imports"], method.input_type
308-
).strip('"')
309-
for msg in template_data["messages"]:
310-
if msg["name"] == input_type:
311-
input_message = msg
312-
for field in msg["properties"]:
313-
if field["zero"] == "None":
314-
template_data["typing_imports"].add("Optional")
315-
break
316-
317-
data["methods"].append(
318-
{
319-
"name": method.name,
320-
"py_name": pythonize_method_name(method.name),
321-
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
322-
"route": f"/{input_package_name}.{service.name}/{method.name}",
323-
"input": get_type_reference(
324-
input_package_name, template_data["imports"], method.input_type
325-
).strip('"'),
326-
"input_message": input_message,
327-
"output": get_type_reference(
328-
input_package_name,
329-
template_data["imports"],
330-
method.output_type,
331-
unwrap=False,
332-
),
333-
"client_streaming": method.client_streaming,
334-
"server_streaming": method.server_streaming,
335-
}
336-
)
337-
338-
if method.client_streaming:
339-
template_data["typing_imports"].add("AsyncIterable")
340-
template_data["typing_imports"].add("Iterable")
341-
template_data["typing_imports"].add("Union")
342-
if method.server_streaming:
343-
template_data["typing_imports"].add("AsyncIterator")
344-
345-
template_data["services"].append(data)
169+
read_protobuf_service(i, input_package_name, proto_file, service, template_data)
346170

347171
template_data["imports"] = sorted(template_data["imports"])
348172
template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
@@ -379,6 +203,186 @@ def generate_code(request, response):
379203
print(f"Writing {output_package_name}", file=sys.stderr)
380204

381205

206+
def read_protobuf_service(i, input_package_name, proto_file, service, template_data):
207+
# print(service, file=sys.stderr)
208+
data = {
209+
"name": service.name,
210+
"py_name": pythonize_class_name(service.name),
211+
"comment": get_comment(proto_file, [6, i]),
212+
"methods": [],
213+
}
214+
for j, method in enumerate(service.method):
215+
input_message = None
216+
input_type = get_type_reference(
217+
input_package_name, template_data["imports"], method.input_type
218+
).strip('"')
219+
for msg in template_data["messages"]:
220+
if msg["name"] == input_type:
221+
input_message = msg
222+
for field in msg["properties"]:
223+
if field["zero"] == "None":
224+
template_data["typing_imports"].add("Optional")
225+
break
226+
227+
data["methods"].append(
228+
{
229+
"name": method.name,
230+
"py_name": pythonize_method_name(method.name),
231+
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
232+
"route": f"/{input_package_name}.{service.name}/{method.name}",
233+
"input": get_type_reference(
234+
input_package_name, template_data["imports"], method.input_type
235+
).strip('"'),
236+
"input_message": input_message,
237+
"output": get_type_reference(
238+
input_package_name,
239+
template_data["imports"],
240+
method.output_type,
241+
unwrap=False,
242+
),
243+
"client_streaming": method.client_streaming,
244+
"server_streaming": method.server_streaming,
245+
}
246+
)
247+
248+
if method.client_streaming:
249+
template_data["typing_imports"].add("AsyncIterable")
250+
template_data["typing_imports"].add("Iterable")
251+
template_data["typing_imports"].add("Union")
252+
if method.server_streaming:
253+
template_data["typing_imports"].add("AsyncIterator")
254+
template_data["services"].append(data)
255+
256+
257+
def read_protobuf_type(input_package_name, item, path, proto_file, template_data):
258+
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
259+
if isinstance(item, DescriptorProto):
260+
# print(item, file=sys.stderr)
261+
if item.options.map_entry:
262+
# Skip generated map entry messages since we just use dicts
263+
return
264+
265+
data.update(
266+
{
267+
"type": "Message",
268+
"comment": get_comment(proto_file, path),
269+
"properties": [],
270+
}
271+
)
272+
273+
for i, f in enumerate(item.field):
274+
t = py_type(input_package_name, template_data["imports"], f)
275+
zero = get_py_zero(f.type)
276+
277+
repeated = False
278+
packed = False
279+
280+
field_type = f.Type.Name(f.type).lower()[5:]
281+
282+
field_wraps = ""
283+
match_wrapper = re.match(
284+
r"\.google\.protobuf\.(.+)Value", f.type_name
285+
)
286+
if match_wrapper:
287+
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
288+
if hasattr(betterproto, wrapped_type):
289+
field_wraps = f"betterproto.{wrapped_type}"
290+
291+
map_types = None
292+
if f.type == 11:
293+
# This might be a map...
294+
message_type = f.type_name.split(".").pop().lower()
295+
# message_type = py_type(package)
296+
map_entry = f"{f.name.replace('_', '').lower()}entry"
297+
298+
if message_type == map_entry:
299+
for nested in item.nested_type:
300+
if (
301+
nested.name.replace("_", "").lower()
302+
== map_entry
303+
):
304+
if nested.options.map_entry:
305+
# print("Found a map!", file=sys.stderr)
306+
k = py_type(
307+
input_package_name,
308+
template_data["imports"],
309+
nested.field[0],
310+
)
311+
v = py_type(
312+
input_package_name,
313+
template_data["imports"],
314+
nested.field[1],
315+
)
316+
t = f"Dict[{k}, {v}]"
317+
field_type = "map"
318+
map_types = (
319+
f.Type.Name(nested.field[0].type),
320+
f.Type.Name(nested.field[1].type),
321+
)
322+
template_data["typing_imports"].add("Dict")
323+
324+
if f.label == 3 and field_type != "map":
325+
# Repeated field
326+
repeated = True
327+
t = f"List[{t}]"
328+
zero = "[]"
329+
template_data["typing_imports"].add("List")
330+
331+
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
332+
packed = True
333+
334+
one_of = ""
335+
if f.HasField("oneof_index"):
336+
one_of = item.oneof_decl[f.oneof_index].name
337+
338+
if "Optional[" in t:
339+
template_data["typing_imports"].add("Optional")
340+
341+
if "timedelta" in t:
342+
template_data["datetime_imports"].add("timedelta")
343+
elif "datetime" in t:
344+
template_data["datetime_imports"].add("datetime")
345+
346+
data["properties"].append(
347+
{
348+
"name": f.name,
349+
"py_name": pythonize_field_name(f.name),
350+
"number": f.number,
351+
"comment": get_comment(proto_file, path + [2, i]),
352+
"proto_type": int(f.type),
353+
"field_type": field_type,
354+
"field_wraps": field_wraps,
355+
"map_types": map_types,
356+
"type": t,
357+
"zero": zero,
358+
"repeated": repeated,
359+
"packed": packed,
360+
"one_of": one_of,
361+
}
362+
)
363+
# print(f, file=sys.stderr)
364+
365+
template_data["messages"].append(data)
366+
elif isinstance(item, EnumDescriptorProto):
367+
# print(item.name, path, file=sys.stderr)
368+
data.update(
369+
{
370+
"type": "Enum",
371+
"comment": get_comment(proto_file, path),
372+
"entries": [
373+
{
374+
"name": v.name,
375+
"value": v.number,
376+
"comment": get_comment(proto_file, path + [2, i]),
377+
}
378+
for i, v in enumerate(item.value)
379+
],
380+
}
381+
)
382+
383+
template_data["enums"].append(data)
384+
385+
382386
def main():
383387
"""The plugin's main entry point."""
384388
# Read request message from stdin

0 commit comments

Comments
 (0)