@@ -192,300 +192,52 @@ def get(field: dataclasses.Field) -> "FieldMetadata":
192
192
return field .metadata ["betterproto" ]
193
193
194
194
195
- def dataclass_field (
195
+ def field (
196
196
number : int ,
197
197
proto_type : str ,
198
- default_factory : Callable [[], Any ],
199
198
* ,
199
+ default_factory : Callable [[], Any ] | None = None ,
200
200
map_types : Optional [Tuple [str , str ]] = None ,
201
201
group : Optional [str ] = None ,
202
202
wraps : Optional [str ] = None ,
203
203
optional : bool = False ,
204
204
repeated : bool = False ,
205
- ) -> dataclasses . Field :
205
+ ) -> Any : # Return type is Any to pass type checking
206
206
"""Creates a dataclass field with attached protobuf metadata."""
207
207
if repeated :
208
208
default_factory = list
209
209
210
210
elif optional or group :
211
211
default_factory = type (None )
212
212
213
+ else :
214
+ default_factory = {
215
+ TYPE_ENUM : default_factory ,
216
+ TYPE_BOOL : bool ,
217
+ TYPE_INT32 : int ,
218
+ TYPE_INT64 : int ,
219
+ TYPE_UINT32 : int ,
220
+ TYPE_UINT64 : int ,
221
+ TYPE_SINT32 : int ,
222
+ TYPE_SINT64 : int ,
223
+ TYPE_FLOAT : float ,
224
+ TYPE_DOUBLE : float ,
225
+ TYPE_FIXED32 : int ,
226
+ TYPE_SFIXED32 : int ,
227
+ TYPE_FIXED64 : int ,
228
+ TYPE_SFIXED64 : int ,
229
+ TYPE_STRING : str ,
230
+ TYPE_BYTES : bytes ,
231
+ TYPE_MESSAGE : type (None ),
232
+ TYPE_MAP : dict ,
233
+ }[proto_type ]
234
+
213
235
return dataclasses .field (
214
236
default_factory = default_factory ,
215
237
metadata = {"betterproto" : FieldMetadata (number , proto_type , map_types , group , wraps , optional )},
216
238
)
217
239
218
240
219
- # Note: the fields below return `Any` to prevent type errors in the generated
220
- # data classes since the types won't match with `Field` and they get swapped
221
- # out at runtime. The generated dataclass variables are still typed correctly.
222
-
223
-
224
- def enum_field (
225
- number : int ,
226
- enum_default_value : Callable [[], Enum ],
227
- group : Optional [str ] = None ,
228
- optional : bool = False ,
229
- repeated : bool = False ,
230
- ) -> Any :
231
- return dataclass_field (
232
- number ,
233
- TYPE_ENUM ,
234
- enum_default_value ,
235
- group = group ,
236
- optional = optional ,
237
- repeated = repeated ,
238
- )
239
-
240
-
241
- def bool_field (
242
- number : int ,
243
- group : Optional [str ] = None ,
244
- optional : bool = False ,
245
- repeated : bool = False ,
246
- ) -> Any :
247
- return dataclass_field (
248
- number ,
249
- TYPE_BOOL ,
250
- bool ,
251
- group = group ,
252
- optional = optional ,
253
- repeated = repeated ,
254
- )
255
-
256
-
257
- def int32_field (
258
- number : int ,
259
- group : Optional [str ] = None ,
260
- optional : bool = False ,
261
- repeated : bool = False ,
262
- ) -> Any :
263
- return dataclass_field (number , TYPE_INT32 , int , group = group , optional = optional , repeated = repeated )
264
-
265
-
266
- def int64_field (
267
- number : int ,
268
- group : Optional [str ] = None ,
269
- optional : bool = False ,
270
- repeated : bool = False ,
271
- ) -> Any :
272
- return dataclass_field (number , TYPE_INT64 , int , group = group , optional = optional , repeated = repeated )
273
-
274
-
275
- def uint32_field (
276
- number : int ,
277
- group : Optional [str ] = None ,
278
- optional : bool = False ,
279
- repeated : bool = False ,
280
- ) -> Any :
281
- return dataclass_field (
282
- number ,
283
- TYPE_UINT32 ,
284
- int ,
285
- group = group ,
286
- optional = optional ,
287
- repeated = repeated ,
288
- )
289
-
290
-
291
- def uint64_field (
292
- number : int ,
293
- group : Optional [str ] = None ,
294
- optional : bool = False ,
295
- repeated : bool = False ,
296
- ) -> Any :
297
- return dataclass_field (
298
- number ,
299
- TYPE_UINT64 ,
300
- int ,
301
- group = group ,
302
- optional = optional ,
303
- repeated = repeated ,
304
- )
305
-
306
-
307
- def sint32_field (
308
- number : int ,
309
- group : Optional [str ] = None ,
310
- optional : bool = False ,
311
- repeated : bool = False ,
312
- ) -> Any :
313
- return dataclass_field (
314
- number ,
315
- TYPE_SINT32 ,
316
- int ,
317
- group = group ,
318
- optional = optional ,
319
- repeated = repeated ,
320
- )
321
-
322
-
323
- def sint64_field (
324
- number : int ,
325
- group : Optional [str ] = None ,
326
- optional : bool = False ,
327
- repeated : bool = False ,
328
- ) -> Any :
329
- return dataclass_field (
330
- number ,
331
- TYPE_SINT64 ,
332
- int ,
333
- group = group ,
334
- optional = optional ,
335
- repeated = repeated ,
336
- )
337
-
338
-
339
- def float_field (
340
- number : int ,
341
- group : Optional [str ] = None ,
342
- optional : bool = False ,
343
- repeated : bool = False ,
344
- ) -> Any :
345
- return dataclass_field (
346
- number ,
347
- TYPE_FLOAT ,
348
- float ,
349
- group = group ,
350
- optional = optional ,
351
- repeated = repeated ,
352
- )
353
-
354
-
355
- def double_field (
356
- number : int ,
357
- group : Optional [str ] = None ,
358
- optional : bool = False ,
359
- repeated : bool = False ,
360
- ) -> Any :
361
- return dataclass_field (
362
- number ,
363
- TYPE_DOUBLE ,
364
- float ,
365
- group = group ,
366
- optional = optional ,
367
- repeated = repeated ,
368
- )
369
-
370
-
371
- def fixed32_field (
372
- number : int ,
373
- group : Optional [str ] = None ,
374
- optional : bool = False ,
375
- repeated : bool = False ,
376
- ) -> Any :
377
- return dataclass_field (
378
- number ,
379
- TYPE_FIXED32 ,
380
- float ,
381
- group = group ,
382
- optional = optional ,
383
- repeated = repeated ,
384
- )
385
-
386
-
387
- def fixed64_field (
388
- number : int ,
389
- group : Optional [str ] = None ,
390
- optional : bool = False ,
391
- repeated : bool = False ,
392
- ) -> Any :
393
- return dataclass_field (
394
- number ,
395
- TYPE_FIXED64 ,
396
- float ,
397
- group = group ,
398
- optional = optional ,
399
- repeated = repeated ,
400
- )
401
-
402
-
403
- def sfixed32_field (
404
- number : int ,
405
- group : Optional [str ] = None ,
406
- optional : bool = False ,
407
- repeated : bool = False ,
408
- ) -> Any :
409
- return dataclass_field (
410
- number ,
411
- TYPE_SFIXED32 ,
412
- float ,
413
- group = group ,
414
- optional = optional ,
415
- repeated = repeated ,
416
- )
417
-
418
-
419
- def sfixed64_field (
420
- number : int ,
421
- group : Optional [str ] = None ,
422
- optional : bool = False ,
423
- repeated : bool = False ,
424
- ) -> Any :
425
- return dataclass_field (
426
- number ,
427
- TYPE_SFIXED64 ,
428
- float ,
429
- group = group ,
430
- optional = optional ,
431
- repeated = repeated ,
432
- )
433
-
434
-
435
- def string_field (
436
- number : int ,
437
- group : Optional [str ] = None ,
438
- optional : bool = False ,
439
- repeated : bool = False ,
440
- ) -> Any :
441
- return dataclass_field (
442
- number ,
443
- TYPE_STRING ,
444
- str ,
445
- group = group ,
446
- optional = optional ,
447
- repeated = repeated ,
448
- )
449
-
450
-
451
- def bytes_field (
452
- number : int ,
453
- group : Optional [str ] = None ,
454
- optional : bool = False ,
455
- repeated : bool = False ,
456
- ) -> Any :
457
- return dataclass_field (
458
- number ,
459
- TYPE_BYTES ,
460
- bytes ,
461
- group = group ,
462
- optional = optional ,
463
- repeated = repeated ,
464
- )
465
-
466
-
467
- def message_field (
468
- number : int ,
469
- group : Optional [str ] = None ,
470
- wraps : Optional [str ] = None ,
471
- optional : bool = False ,
472
- repeated : bool = False ,
473
- ) -> Any :
474
- return dataclass_field (
475
- number ,
476
- TYPE_MESSAGE ,
477
- type (None ),
478
- group = group ,
479
- wraps = wraps ,
480
- optional = optional ,
481
- repeated = repeated ,
482
- )
483
-
484
-
485
- def map_field (number : int , key_type : str , value_type : str , group : Optional [str ] = None ) -> Any :
486
- return dataclass_field (number , TYPE_MAP , dict , map_types = (key_type , value_type ), group = group )
487
-
488
-
489
241
def _pack_fmt (proto_type : str ) -> str :
490
242
"""Returns a little-endian format string for reading/writing binary."""
491
243
return {
@@ -774,31 +526,31 @@ def _get_default_gen(cls: Type["Message"], fields: Iterable[dataclasses.Field])
774
526
def _get_cls_by_field (cls : Type ["Message" ], fields : Iterable [dataclasses .Field ]) -> Dict [str , Type ]:
775
527
field_cls = {}
776
528
777
- for field in fields :
778
- meta = FieldMetadata .get (field )
529
+ for field_ in fields :
530
+ meta = FieldMetadata .get (field_ )
779
531
if meta .proto_type == TYPE_MAP :
780
532
assert meta .map_types
781
- kt = cls ._cls_for (field , index = 0 )
782
- vt = cls ._cls_for (field , index = 1 )
783
- field_cls [field .name ] = dataclasses .make_dataclass (
533
+ kt = cls ._cls_for (field_ , index = 0 )
534
+ vt = cls ._cls_for (field_ , index = 1 )
535
+ field_cls [field_ .name ] = dataclasses .make_dataclass (
784
536
"Entry" ,
785
537
[
786
538
(
787
539
"key" ,
788
540
kt ,
789
- dataclass_field (1 , meta .map_types [0 ], default_factory = kt ),
541
+ field (1 , meta .map_types [0 ], default_factory = kt ),
790
542
),
791
543
(
792
544
"value" ,
793
545
vt ,
794
- dataclass_field (2 , meta .map_types [1 ], default_factory = vt ),
546
+ field (2 , meta .map_types [1 ], default_factory = vt ),
795
547
),
796
548
],
797
549
bases = (Message ,),
798
550
)
799
- field_cls [f"{ field .name } .value" ] = vt
551
+ field_cls [f"{ field_ .name } .value" ] = vt
800
552
else :
801
- field_cls [field .name ] = cls ._cls_for (field )
553
+ field_cls [field_ .name ] = cls ._cls_for (field_ )
802
554
803
555
return field_cls
804
556
0 commit comments