@@ -163,186 +163,10 @@ def generate_code(request, response):
163
163
for proto_file in output_package_content ["files" ]:
164
164
item : DescriptorProto
165
165
for item , path in traverse (proto_file ):
166
- data = {"name" : item .name , "py_name" : pythonize_class_name (item .name )}
167
-
168
- if isinstance (item , DescriptorProto ):
169
- # print(item, file=sys.stderr)
170
- if item .options .map_entry :
171
- # Skip generated map entry messages since we just use dicts
172
- continue
173
-
174
- data .update (
175
- {
176
- "type" : "Message" ,
177
- "comment" : get_comment (proto_file , path ),
178
- "properties" : [],
179
- }
180
- )
181
-
182
- for i , f in enumerate (item .field ):
183
- t = py_type (input_package_name , template_data ["imports" ], f )
184
- zero = get_py_zero (f .type )
185
-
186
- repeated = False
187
- packed = False
188
-
189
- field_type = f .Type .Name (f .type ).lower ()[5 :]
190
-
191
- field_wraps = ""
192
- match_wrapper = re .match (
193
- r"\.google\.protobuf\.(.+)Value" , f .type_name
194
- )
195
- if match_wrapper :
196
- wrapped_type = "TYPE_" + match_wrapper .group (1 ).upper ()
197
- if hasattr (betterproto , wrapped_type ):
198
- field_wraps = f"betterproto.{ wrapped_type } "
199
-
200
- map_types = None
201
- if f .type == 11 :
202
- # This might be a map...
203
- message_type = f .type_name .split ("." ).pop ().lower ()
204
- # message_type = py_type(package)
205
- map_entry = f"{ f .name .replace ('_' , '' ).lower ()} entry"
206
-
207
- if message_type == map_entry :
208
- for nested in item .nested_type :
209
- if (
210
- nested .name .replace ("_" , "" ).lower ()
211
- == map_entry
212
- ):
213
- if nested .options .map_entry :
214
- # print("Found a map!", file=sys.stderr)
215
- k = py_type (
216
- input_package_name ,
217
- template_data ["imports" ],
218
- nested .field [0 ],
219
- )
220
- v = py_type (
221
- input_package_name ,
222
- template_data ["imports" ],
223
- nested .field [1 ],
224
- )
225
- t = f"Dict[{ k } , { v } ]"
226
- field_type = "map"
227
- map_types = (
228
- f .Type .Name (nested .field [0 ].type ),
229
- f .Type .Name (nested .field [1 ].type ),
230
- )
231
- template_data ["typing_imports" ].add ("Dict" )
232
-
233
- if f .label == 3 and field_type != "map" :
234
- # Repeated field
235
- repeated = True
236
- t = f"List[{ t } ]"
237
- zero = "[]"
238
- template_data ["typing_imports" ].add ("List" )
239
-
240
- if f .type in [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 13 , 15 , 16 , 17 , 18 ]:
241
- packed = True
242
-
243
- one_of = ""
244
- if f .HasField ("oneof_index" ):
245
- one_of = item .oneof_decl [f .oneof_index ].name
246
-
247
- if "Optional[" in t :
248
- template_data ["typing_imports" ].add ("Optional" )
249
-
250
- if "timedelta" in t :
251
- template_data ["datetime_imports" ].add ("timedelta" )
252
- elif "datetime" in t :
253
- template_data ["datetime_imports" ].add ("datetime" )
254
-
255
- data ["properties" ].append (
256
- {
257
- "name" : f .name ,
258
- "py_name" : pythonize_field_name (f .name ),
259
- "number" : f .number ,
260
- "comment" : get_comment (proto_file , path + [2 , i ]),
261
- "proto_type" : int (f .type ),
262
- "field_type" : field_type ,
263
- "field_wraps" : field_wraps ,
264
- "map_types" : map_types ,
265
- "type" : t ,
266
- "zero" : zero ,
267
- "repeated" : repeated ,
268
- "packed" : packed ,
269
- "one_of" : one_of ,
270
- }
271
- )
272
- # print(f, file=sys.stderr)
273
-
274
- template_data ["messages" ].append (data )
275
- elif isinstance (item , EnumDescriptorProto ):
276
- # print(item.name, path, file=sys.stderr)
277
- data .update (
278
- {
279
- "type" : "Enum" ,
280
- "comment" : get_comment (proto_file , path ),
281
- "entries" : [
282
- {
283
- "name" : v .name ,
284
- "value" : v .number ,
285
- "comment" : get_comment (proto_file , path + [2 , i ]),
286
- }
287
- for i , v in enumerate (item .value )
288
- ],
289
- }
290
- )
291
-
292
- template_data ["enums" ].append (data )
166
+ read_protobuf_type (input_package_name , item , path , proto_file , template_data )
293
167
294
168
for i , service in enumerate (proto_file .service ):
295
- # print(service, file=sys.stderr)
296
-
297
- data = {
298
- "name" : service .name ,
299
- "py_name" : pythonize_class_name (service .name ),
300
- "comment" : get_comment (proto_file , [6 , i ]),
301
- "methods" : [],
302
- }
303
-
304
- for j , method in enumerate (service .method ):
305
- input_message = None
306
- input_type = get_type_reference (
307
- input_package_name , template_data ["imports" ], method .input_type
308
- ).strip ('"' )
309
- for msg in template_data ["messages" ]:
310
- if msg ["name" ] == input_type :
311
- input_message = msg
312
- for field in msg ["properties" ]:
313
- if field ["zero" ] == "None" :
314
- template_data ["typing_imports" ].add ("Optional" )
315
- break
316
-
317
- data ["methods" ].append (
318
- {
319
- "name" : method .name ,
320
- "py_name" : pythonize_method_name (method .name ),
321
- "comment" : get_comment (proto_file , [6 , i , 2 , j ], indent = 8 ),
322
- "route" : f"/{ input_package_name } .{ service .name } /{ method .name } " ,
323
- "input" : get_type_reference (
324
- input_package_name , template_data ["imports" ], method .input_type
325
- ).strip ('"' ),
326
- "input_message" : input_message ,
327
- "output" : get_type_reference (
328
- input_package_name ,
329
- template_data ["imports" ],
330
- method .output_type ,
331
- unwrap = False ,
332
- ),
333
- "client_streaming" : method .client_streaming ,
334
- "server_streaming" : method .server_streaming ,
335
- }
336
- )
337
-
338
- if method .client_streaming :
339
- template_data ["typing_imports" ].add ("AsyncIterable" )
340
- template_data ["typing_imports" ].add ("Iterable" )
341
- template_data ["typing_imports" ].add ("Union" )
342
- if method .server_streaming :
343
- template_data ["typing_imports" ].add ("AsyncIterator" )
344
-
345
- template_data ["services" ].append (data )
169
+ read_protobuf_service (i , input_package_name , proto_file , service , template_data )
346
170
347
171
template_data ["imports" ] = sorted (template_data ["imports" ])
348
172
template_data ["datetime_imports" ] = sorted (template_data ["datetime_imports" ])
@@ -379,6 +203,186 @@ def generate_code(request, response):
379
203
print (f"Writing { output_package_name } " , file = sys .stderr )
380
204
381
205
206
+ def read_protobuf_service (i , input_package_name , proto_file , service , template_data ):
207
+ # print(service, file=sys.stderr)
208
+ data = {
209
+ "name" : service .name ,
210
+ "py_name" : pythonize_class_name (service .name ),
211
+ "comment" : get_comment (proto_file , [6 , i ]),
212
+ "methods" : [],
213
+ }
214
+ for j , method in enumerate (service .method ):
215
+ input_message = None
216
+ input_type = get_type_reference (
217
+ input_package_name , template_data ["imports" ], method .input_type
218
+ ).strip ('"' )
219
+ for msg in template_data ["messages" ]:
220
+ if msg ["name" ] == input_type :
221
+ input_message = msg
222
+ for field in msg ["properties" ]:
223
+ if field ["zero" ] == "None" :
224
+ template_data ["typing_imports" ].add ("Optional" )
225
+ break
226
+
227
+ data ["methods" ].append (
228
+ {
229
+ "name" : method .name ,
230
+ "py_name" : pythonize_method_name (method .name ),
231
+ "comment" : get_comment (proto_file , [6 , i , 2 , j ], indent = 8 ),
232
+ "route" : f"/{ input_package_name } .{ service .name } /{ method .name } " ,
233
+ "input" : get_type_reference (
234
+ input_package_name , template_data ["imports" ], method .input_type
235
+ ).strip ('"' ),
236
+ "input_message" : input_message ,
237
+ "output" : get_type_reference (
238
+ input_package_name ,
239
+ template_data ["imports" ],
240
+ method .output_type ,
241
+ unwrap = False ,
242
+ ),
243
+ "client_streaming" : method .client_streaming ,
244
+ "server_streaming" : method .server_streaming ,
245
+ }
246
+ )
247
+
248
+ if method .client_streaming :
249
+ template_data ["typing_imports" ].add ("AsyncIterable" )
250
+ template_data ["typing_imports" ].add ("Iterable" )
251
+ template_data ["typing_imports" ].add ("Union" )
252
+ if method .server_streaming :
253
+ template_data ["typing_imports" ].add ("AsyncIterator" )
254
+ template_data ["services" ].append (data )
255
+
256
+
257
+ def read_protobuf_type (input_package_name , item , path , proto_file , template_data ):
258
+ data = {"name" : item .name , "py_name" : pythonize_class_name (item .name )}
259
+ if isinstance (item , DescriptorProto ):
260
+ # print(item, file=sys.stderr)
261
+ if item .options .map_entry :
262
+ # Skip generated map entry messages since we just use dicts
263
+ return
264
+
265
+ data .update (
266
+ {
267
+ "type" : "Message" ,
268
+ "comment" : get_comment (proto_file , path ),
269
+ "properties" : [],
270
+ }
271
+ )
272
+
273
+ for i , f in enumerate (item .field ):
274
+ t = py_type (input_package_name , template_data ["imports" ], f )
275
+ zero = get_py_zero (f .type )
276
+
277
+ repeated = False
278
+ packed = False
279
+
280
+ field_type = f .Type .Name (f .type ).lower ()[5 :]
281
+
282
+ field_wraps = ""
283
+ match_wrapper = re .match (
284
+ r"\.google\.protobuf\.(.+)Value" , f .type_name
285
+ )
286
+ if match_wrapper :
287
+ wrapped_type = "TYPE_" + match_wrapper .group (1 ).upper ()
288
+ if hasattr (betterproto , wrapped_type ):
289
+ field_wraps = f"betterproto.{ wrapped_type } "
290
+
291
+ map_types = None
292
+ if f .type == 11 :
293
+ # This might be a map...
294
+ message_type = f .type_name .split ("." ).pop ().lower ()
295
+ # message_type = py_type(package)
296
+ map_entry = f"{ f .name .replace ('_' , '' ).lower ()} entry"
297
+
298
+ if message_type == map_entry :
299
+ for nested in item .nested_type :
300
+ if (
301
+ nested .name .replace ("_" , "" ).lower ()
302
+ == map_entry
303
+ ):
304
+ if nested .options .map_entry :
305
+ # print("Found a map!", file=sys.stderr)
306
+ k = py_type (
307
+ input_package_name ,
308
+ template_data ["imports" ],
309
+ nested .field [0 ],
310
+ )
311
+ v = py_type (
312
+ input_package_name ,
313
+ template_data ["imports" ],
314
+ nested .field [1 ],
315
+ )
316
+ t = f"Dict[{ k } , { v } ]"
317
+ field_type = "map"
318
+ map_types = (
319
+ f .Type .Name (nested .field [0 ].type ),
320
+ f .Type .Name (nested .field [1 ].type ),
321
+ )
322
+ template_data ["typing_imports" ].add ("Dict" )
323
+
324
+ if f .label == 3 and field_type != "map" :
325
+ # Repeated field
326
+ repeated = True
327
+ t = f"List[{ t } ]"
328
+ zero = "[]"
329
+ template_data ["typing_imports" ].add ("List" )
330
+
331
+ if f .type in [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 13 , 15 , 16 , 17 , 18 ]:
332
+ packed = True
333
+
334
+ one_of = ""
335
+ if f .HasField ("oneof_index" ):
336
+ one_of = item .oneof_decl [f .oneof_index ].name
337
+
338
+ if "Optional[" in t :
339
+ template_data ["typing_imports" ].add ("Optional" )
340
+
341
+ if "timedelta" in t :
342
+ template_data ["datetime_imports" ].add ("timedelta" )
343
+ elif "datetime" in t :
344
+ template_data ["datetime_imports" ].add ("datetime" )
345
+
346
+ data ["properties" ].append (
347
+ {
348
+ "name" : f .name ,
349
+ "py_name" : pythonize_field_name (f .name ),
350
+ "number" : f .number ,
351
+ "comment" : get_comment (proto_file , path + [2 , i ]),
352
+ "proto_type" : int (f .type ),
353
+ "field_type" : field_type ,
354
+ "field_wraps" : field_wraps ,
355
+ "map_types" : map_types ,
356
+ "type" : t ,
357
+ "zero" : zero ,
358
+ "repeated" : repeated ,
359
+ "packed" : packed ,
360
+ "one_of" : one_of ,
361
+ }
362
+ )
363
+ # print(f, file=sys.stderr)
364
+
365
+ template_data ["messages" ].append (data )
366
+ elif isinstance (item , EnumDescriptorProto ):
367
+ # print(item.name, path, file=sys.stderr)
368
+ data .update (
369
+ {
370
+ "type" : "Enum" ,
371
+ "comment" : get_comment (proto_file , path ),
372
+ "entries" : [
373
+ {
374
+ "name" : v .name ,
375
+ "value" : v .number ,
376
+ "comment" : get_comment (proto_file , path + [2 , i ]),
377
+ }
378
+ for i , v in enumerate (item .value )
379
+ ],
380
+ }
381
+ )
382
+
383
+ template_data ["enums" ].append (data )
384
+
385
+
382
386
def main ():
383
387
"""The plugin's main entry point."""
384
388
# Read request message from stdin
0 commit comments