11
11
from google .protobuf .compiler .plugin_pb2 import CodeGeneratorRequest
12
12
13
13
import betterproto
14
- from betterproto .compile .importing import get_type_reference
14
+ from betterproto .compile .importing import get_type_reference , parse_source_type_name
15
15
from betterproto .compile .naming import (
16
16
pythonize_class_name ,
17
17
pythonize_field_name ,
18
18
pythonize_method_name ,
19
19
)
20
+ from betterproto .lib .google .protobuf import ServiceDescriptorProto
20
21
21
22
try :
22
23
# betterproto[compiler] specific dependencies
@@ -76,11 +77,12 @@ def get_py_zero(type_num: int) -> Union[str, float]:
76
77
return zero
77
78
78
79
79
- # Todo: Keep information about nested hierarchy
80
80
def traverse (proto_file ):
81
+ # Todo: Keep information about nested hierarchy
81
82
def _traverse (path , items , prefix = "" ):
82
83
for i , item in enumerate (items ):
83
- # Adjust the name since we flatten the heirarchy.
84
+ # Adjust the name since we flatten the hierarchy.
85
+ # Todo: don't change the name, but include full name in returned tuple
84
86
item .name = next_prefix = prefix + item .name
85
87
yield item , path + [i ]
86
88
@@ -162,17 +164,21 @@ def generate_code(request, response):
162
164
output_package_content ["template_data" ] = template_data
163
165
164
166
# Read Messages and Enums
167
+ output_types = []
165
168
for output_package_name , output_package_content in output_package_files .items ():
166
169
for proto_file in output_package_content ["files" ]:
167
170
for item , path in traverse (proto_file ):
168
- read_protobuf_object (item , path , proto_file , output_package_content )
171
+ type_data = read_protobuf_type (
172
+ item , path , proto_file , output_package_content
173
+ )
174
+ output_types .append (type_data )
169
175
170
176
# Read Services
171
177
for output_package_name , output_package_content in output_package_files .items ():
172
178
for proto_file in output_package_content ["files" ]:
173
179
for index , service in enumerate (proto_file .service ):
174
180
read_protobuf_service (
175
- service , index , proto_file , output_package_content
181
+ service , index , proto_file , output_package_content , output_types
176
182
)
177
183
178
184
# Render files
@@ -214,63 +220,31 @@ def generate_code(request, response):
214
220
print (f"Writing { output_package_name } " , file = sys .stderr )
215
221
216
222
217
- def read_protobuf_service (service : DescriptorProto , index , proto_file , content ):
218
- input_package_name = content ["input_package" ]
219
- template_data = content ["template_data" ]
220
- # print(service, file=sys.stderr)
221
- data = {
222
- "name" : service .name ,
223
- "py_name" : pythonize_class_name (service .name ),
224
- "comment" : get_comment (proto_file , [6 , index ]),
225
- "methods" : [],
226
- }
227
- for j , method in enumerate (service .method ):
228
- input_message = None
229
- input_type = get_type_reference (
230
- input_package_name , template_data ["imports" ], method .input_type
231
- ).strip ('"' )
232
- for msg in template_data ["messages" ]:
233
- if msg ["name" ] == input_type :
234
- input_message = msg
235
- for field in msg ["properties" ]:
236
- if field ["zero" ] == "None" :
237
- template_data ["typing_imports" ].add ("Optional" )
238
- break
223
+ def lookup_method_input_type (method , types ):
224
+ package , name = parse_source_type_name (method .input_type )
239
225
240
- data ["methods" ].append (
241
- {
242
- "name" : method .name ,
243
- "py_name" : pythonize_method_name (method .name ),
244
- "comment" : get_comment (proto_file , [6 , index , 2 , j ], indent = 8 ),
245
- "route" : f"/{ input_package_name } .{ service .name } /{ method .name } " ,
246
- "input" : get_type_reference (
247
- input_package_name , template_data ["imports" ], method .input_type
248
- ).strip ('"' ),
249
- "input_message" : input_message ,
250
- "output" : get_type_reference (
251
- input_package_name ,
252
- template_data ["imports" ],
253
- method .output_type ,
254
- unwrap = False ,
255
- ),
256
- "client_streaming" : method .client_streaming ,
257
- "server_streaming" : method .server_streaming ,
258
- }
259
- )
226
+ for known_type in types :
227
+ if known_type ["type" ] != "Message" :
228
+ continue
260
229
261
- if method . client_streaming :
262
- template_data [ "typing_imports" ]. add ( "AsyncIterable" )
263
- template_data [ "typing_imports" ]. add ( "Iterable" )
264
- template_data [ "typing_imports" ]. add ( "Union" )
265
- if method . server_streaming :
266
- template_data [ "typing_imports" ]. add ( "AsyncIterator" )
267
- template_data [ "services" ]. append ( data )
230
+ # Nested types are currently flattened without dots.
231
+ # Todo: keep a fully quantified name in types, that is comparable with method.input_type
232
+ if (
233
+ package == known_type [ "package" ]
234
+ and name . replace ( "." , "" ) == known_type [ "name" ]
235
+ ):
236
+ return known_type
268
237
269
238
270
- def read_protobuf_object (item : DescriptorProto , path : List [int ], proto_file , content ):
239
+ def read_protobuf_type (item : DescriptorProto , path : List [int ], proto_file , content ):
271
240
input_package_name = content ["input_package" ]
272
241
template_data = content ["template_data" ]
273
- data = {"name" : item .name , "py_name" : pythonize_class_name (item .name )}
242
+ data = {
243
+ "name" : item .name ,
244
+ "py_name" : pythonize_class_name (item .name ),
245
+ "descriptor" : item ,
246
+ "package" : input_package_name ,
247
+ }
274
248
if isinstance (item , DescriptorProto ):
275
249
# print(item, file=sys.stderr)
276
250
if item .options .map_entry :
@@ -373,6 +347,7 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con
373
347
# print(f, file=sys.stderr)
374
348
375
349
template_data ["messages" ].append (data )
350
+ return data
376
351
elif isinstance (item , EnumDescriptorProto ):
377
352
# print(item.name, path, file=sys.stderr)
378
353
data .update (
@@ -391,6 +366,57 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con
391
366
)
392
367
393
368
template_data ["enums" ].append (data )
369
+ return data
370
+
371
+
372
+ def read_protobuf_service (
373
+ service : ServiceDescriptorProto , index , proto_file , content , output_types
374
+ ):
375
+ input_package_name = content ["input_package" ]
376
+ template_data = content ["template_data" ]
377
+ # print(service, file=sys.stderr)
378
+ data = {
379
+ "name" : service .name ,
380
+ "py_name" : pythonize_class_name (service .name ),
381
+ "comment" : get_comment (proto_file , [6 , index ]),
382
+ "methods" : [],
383
+ }
384
+ for j , method in enumerate (service .method ):
385
+ method_input_message = lookup_method_input_type (method , output_types )
386
+
387
+ if method_input_message :
388
+ for field in method_input_message ["properties" ]:
389
+ if field ["zero" ] == "None" :
390
+ template_data ["typing_imports" ].add ("Optional" )
391
+
392
+ data ["methods" ].append (
393
+ {
394
+ "name" : method .name ,
395
+ "py_name" : pythonize_method_name (method .name ),
396
+ "comment" : get_comment (proto_file , [6 , index , 2 , j ], indent = 8 ),
397
+ "route" : f"/{ input_package_name } .{ service .name } /{ method .name } " ,
398
+ "input" : get_type_reference (
399
+ input_package_name , template_data ["imports" ], method .input_type
400
+ ).strip ('"' ),
401
+ "input_message" : method_input_message ,
402
+ "output" : get_type_reference (
403
+ input_package_name ,
404
+ template_data ["imports" ],
405
+ method .output_type ,
406
+ unwrap = False ,
407
+ ),
408
+ "client_streaming" : method .client_streaming ,
409
+ "server_streaming" : method .server_streaming ,
410
+ }
411
+ )
412
+
413
+ if method .client_streaming :
414
+ template_data ["typing_imports" ].add ("AsyncIterable" )
415
+ template_data ["typing_imports" ].add ("Iterable" )
416
+ template_data ["typing_imports" ].add ("Union" )
417
+ if method .server_streaming :
418
+ template_data ["typing_imports" ].add ("AsyncIterator" )
419
+ template_data ["services" ].append (data )
394
420
395
421
396
422
def main ():
0 commit comments