1
1
#!/usr/bin/env python
2
-
2
+ import collections
3
3
import itertools
4
4
import os .path
5
5
import pathlib
8
8
import textwrap
9
9
from typing import List , Union
10
10
11
+ from google .protobuf .compiler .plugin_pb2 import CodeGeneratorRequest
12
+
11
13
import betterproto
12
14
from betterproto .compile .importing import get_type_reference
13
15
from betterproto .compile .naming import (
@@ -129,29 +131,27 @@ def generate_code(request, response):
129
131
)
130
132
template = env .get_template ("template.py.j2" )
131
133
132
- output_map = {}
134
+ # Gather output packages
135
+ output_package_files = collections .defaultdict ()
133
136
for proto_file in request .proto_file :
134
137
if (
135
138
proto_file .package == "google.protobuf"
136
139
and "INCLUDE_GOOGLE" not in plugin_options
137
140
):
138
141
continue
139
142
140
- output_file = str (pathlib .Path (* proto_file .package .split ("." ), "__init__.py" ))
141
-
142
- if output_file not in output_map :
143
- output_map [output_file ] = {"package" : proto_file .package , "files" : []}
144
- output_map [output_file ]["files" ].append (proto_file )
145
-
146
- # TODO: Figure out how to handle gRPC request/response messages and add
147
- # processing below for Service.
148
-
149
- for filename , options in output_map .items ():
150
- package = options ["package" ]
151
- # print(package, filename, file=sys.stderr)
152
- output = {
153
- "package" : package ,
154
- "files" : [f .name for f in options ["files" ]],
143
+ output_package = proto_file .package
144
+ output_package_files .setdefault (
145
+ output_package , {"input_package" : proto_file .package , "files" : []}
146
+ )
147
+ output_package_files [output_package ]["files" ].append (proto_file )
148
+
149
+ output_paths = set ()
150
+ for output_package_name , output_package_content in output_package_files .items ():
151
+ input_package_name = output_package_content ["input_package" ]
152
+ template_data = {
153
+ "input_package" : input_package_name ,
154
+ "files" : [f .name for f in output_package_content ["files" ]],
155
155
"imports" : set (),
156
156
"datetime_imports" : set (),
157
157
"typing_imports" : set (),
@@ -160,7 +160,7 @@ def generate_code(request, response):
160
160
"services" : [],
161
161
}
162
162
163
- for proto_file in options ["files" ]:
163
+ for proto_file in output_package_content ["files" ]:
164
164
item : DescriptorProto
165
165
for item , path in traverse (proto_file ):
166
166
data = {"name" : item .name , "py_name" : pythonize_class_name (item .name )}
@@ -180,7 +180,7 @@ def generate_code(request, response):
180
180
)
181
181
182
182
for i , f in enumerate (item .field ):
183
- t = py_type (package , output ["imports" ], f )
183
+ t = py_type (input_package_name , template_data ["imports" ], f )
184
184
zero = get_py_zero (f .type )
185
185
186
186
repeated = False
@@ -213,13 +213,13 @@ def generate_code(request, response):
213
213
if nested .options .map_entry :
214
214
# print("Found a map!", file=sys.stderr)
215
215
k = py_type (
216
- package ,
217
- output ["imports" ],
216
+ input_package_name ,
217
+ template_data ["imports" ],
218
218
nested .field [0 ],
219
219
)
220
220
v = py_type (
221
- package ,
222
- output ["imports" ],
221
+ input_package_name ,
222
+ template_data ["imports" ],
223
223
nested .field [1 ],
224
224
)
225
225
t = f"Dict[{ k } , { v } ]"
@@ -228,14 +228,14 @@ def generate_code(request, response):
228
228
f .Type .Name (nested .field [0 ].type ),
229
229
f .Type .Name (nested .field [1 ].type ),
230
230
)
231
- output ["typing_imports" ].add ("Dict" )
231
+ template_data ["typing_imports" ].add ("Dict" )
232
232
233
233
if f .label == 3 and field_type != "map" :
234
234
# Repeated field
235
235
repeated = True
236
236
t = f"List[{ t } ]"
237
237
zero = "[]"
238
- output ["typing_imports" ].add ("List" )
238
+ template_data ["typing_imports" ].add ("List" )
239
239
240
240
if f .type in [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 13 , 15 , 16 , 17 , 18 ]:
241
241
packed = True
@@ -245,12 +245,12 @@ def generate_code(request, response):
245
245
one_of = item .oneof_decl [f .oneof_index ].name
246
246
247
247
if "Optional[" in t :
248
- output ["typing_imports" ].add ("Optional" )
248
+ template_data ["typing_imports" ].add ("Optional" )
249
249
250
250
if "timedelta" in t :
251
- output ["datetime_imports" ].add ("timedelta" )
251
+ template_data ["datetime_imports" ].add ("timedelta" )
252
252
elif "datetime" in t :
253
- output ["datetime_imports" ].add ("datetime" )
253
+ template_data ["datetime_imports" ].add ("datetime" )
254
254
255
255
data ["properties" ].append (
256
256
{
@@ -271,7 +271,7 @@ def generate_code(request, response):
271
271
)
272
272
# print(f, file=sys.stderr)
273
273
274
- output ["messages" ].append (data )
274
+ template_data ["messages" ].append (data )
275
275
elif isinstance (item , EnumDescriptorProto ):
276
276
# print(item.name, path, file=sys.stderr)
277
277
data .update (
@@ -289,7 +289,7 @@ def generate_code(request, response):
289
289
}
290
290
)
291
291
292
- output ["enums" ].append (data )
292
+ template_data ["enums" ].append (data )
293
293
294
294
for i , service in enumerate (proto_file .service ):
295
295
# print(service, file=sys.stderr)
@@ -304,29 +304,29 @@ def generate_code(request, response):
304
304
for j , method in enumerate (service .method ):
305
305
input_message = None
306
306
input_type = get_type_reference (
307
- package , output ["imports" ], method .input_type
307
+ input_package_name , template_data ["imports" ], method .input_type
308
308
).strip ('"' )
309
- for msg in output ["messages" ]:
309
+ for msg in template_data ["messages" ]:
310
310
if msg ["name" ] == input_type :
311
311
input_message = msg
312
312
for field in msg ["properties" ]:
313
313
if field ["zero" ] == "None" :
314
- output ["typing_imports" ].add ("Optional" )
314
+ template_data ["typing_imports" ].add ("Optional" )
315
315
break
316
316
317
317
data ["methods" ].append (
318
318
{
319
319
"name" : method .name ,
320
320
"py_name" : pythonize_method_name (method .name ),
321
321
"comment" : get_comment (proto_file , [6 , i , 2 , j ], indent = 8 ),
322
- "route" : f"/{ package } .{ service .name } /{ method .name } " ,
322
+ "route" : f"/{ input_package_name } .{ service .name } /{ method .name } " ,
323
323
"input" : get_type_reference (
324
- package , output ["imports" ], method .input_type
324
+ input_package_name , template_data ["imports" ], method .input_type
325
325
).strip ('"' ),
326
326
"input_message" : input_message ,
327
327
"output" : get_type_reference (
328
- package ,
329
- output ["imports" ],
328
+ input_package_name ,
329
+ template_data ["imports" ],
330
330
method .output_type ,
331
331
unwrap = False ,
332
332
),
@@ -336,30 +336,32 @@ def generate_code(request, response):
336
336
)
337
337
338
338
if method .client_streaming :
339
- output ["typing_imports" ].add ("AsyncIterable" )
340
- output ["typing_imports" ].add ("Iterable" )
341
- output ["typing_imports" ].add ("Union" )
339
+ template_data ["typing_imports" ].add ("AsyncIterable" )
340
+ template_data ["typing_imports" ].add ("Iterable" )
341
+ template_data ["typing_imports" ].add ("Union" )
342
342
if method .server_streaming :
343
- output ["typing_imports" ].add ("AsyncIterator" )
343
+ template_data ["typing_imports" ].add ("AsyncIterator" )
344
344
345
- output ["services" ].append (data )
345
+ template_data ["services" ].append (data )
346
346
347
- output ["imports" ] = sorted (output ["imports" ])
348
- output ["datetime_imports" ] = sorted (output ["datetime_imports" ])
349
- output ["typing_imports" ] = sorted (output ["typing_imports" ])
347
+ template_data ["imports" ] = sorted (template_data ["imports" ])
348
+ template_data ["datetime_imports" ] = sorted (template_data ["datetime_imports" ])
349
+ template_data ["typing_imports" ] = sorted (template_data ["typing_imports" ])
350
350
351
351
# Fill response
352
+ output_path = pathlib .Path (* output_package_name .split ("." ), "__init__.py" )
353
+ output_paths .add (output_path )
354
+
352
355
f = response .file .add ()
353
- f .name = filename
356
+ f .name = str ( output_path )
354
357
355
358
# Render and then format the output file.
356
359
f .content = black .format_str (
357
- template .render (description = output ),
360
+ template .render (description = template_data ),
358
361
mode = black .FileMode (target_versions = set ([black .TargetVersion .PY37 ])),
359
362
)
360
363
361
364
# Make each output directory a package with __init__ file
362
- output_paths = set (pathlib .Path (path ) for path in output_map .keys ())
363
365
init_files = (
364
366
set (
365
367
directory .joinpath ("__init__.py" )
@@ -373,8 +375,8 @@ def generate_code(request, response):
373
375
init = response .file .add ()
374
376
init .name = str (init_file )
375
377
376
- for filename in sorted (output_paths .union (init_files )):
377
- print (f"Writing { filename } " , file = sys .stderr )
378
+ for output_package_name in sorted (output_paths .union (init_files )):
379
+ print (f"Writing { output_package_name } " , file = sys .stderr )
378
380
379
381
380
382
def main ():
0 commit comments