29
29
reference to `A` to `B`'s `fields` attribute.
30
30
"""
31
31
32
+ from __future__ import annotations
32
33
33
34
import builtins
34
35
import re
35
36
import textwrap
37
+ from collections .abc import (
38
+ Iterable ,
39
+ Iterator ,
40
+ )
36
41
from dataclasses import (
37
42
dataclass ,
38
43
field ,
39
44
)
40
- from typing import (
41
- Dict ,
42
- Iterable ,
43
- Iterator ,
44
- List ,
45
- Optional ,
46
- Set ,
47
- Type ,
48
- Union ,
49
- )
45
+ from typing import Any
50
46
51
47
import betterproto
52
48
from betterproto import which_one_of
53
- from betterproto .casing import sanitize_name
54
49
from betterproto .compile .importing import (
55
50
get_type_reference ,
56
51
parse_source_type_name ,
57
52
)
58
53
from betterproto .compile .naming import (
59
54
pythonize_class_name ,
55
+ pythonize_enum_member_name ,
60
56
pythonize_field_name ,
61
57
pythonize_method_name ,
62
58
)
69
65
FieldDescriptorProtoType ,
70
66
FileDescriptorProto ,
71
67
MethodDescriptorProto ,
68
+ ServiceDescriptorProto ,
72
69
)
73
70
from betterproto .lib .google .protobuf .compiler import CodeGeneratorRequest
74
71
75
- from ..compile .importing import (
76
- get_type_reference ,
77
- parse_source_type_name ,
78
- )
79
- from ..compile .naming import (
80
- pythonize_class_name ,
81
- pythonize_enum_member_name ,
82
- pythonize_field_name ,
83
- pythonize_method_name ,
84
- )
85
-
86
72
87
73
# Create a unique placeholder to deal with
88
74
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
89
- PLACEHOLDER = object ()
75
+ PLACEHOLDER : Any = object ()
90
76
91
77
# Organize proto types into categories
92
78
PROTO_FLOAT_TYPES = (
@@ -152,7 +138,7 @@ def monkey_patch_oneof_index():
152
138
153
139
154
140
def get_comment (
155
- proto_file : " FileDescriptorProto" , path : List [int ], indent : int = 4
141
+ proto_file : FileDescriptorProto , path : list [int ], indent : int = 4
156
142
) -> str :
157
143
pad = " " * indent
158
144
for sci_loc in proto_file .source_code_info .location :
@@ -176,11 +162,11 @@ class ProtoContentBase:
176
162
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
177
163
178
164
source_file : FileDescriptorProto
179
- path : List [int ]
165
+ path : list [int ]
180
166
comment_indent : int = 4
181
- parent : Union [ " betterproto.Message" , " OutputTemplate" ]
167
+ parent : betterproto .Message | OutputTemplate
182
168
183
- __dataclass_fields__ : Dict [str , object ]
169
+ __dataclass_fields__ : dict [str , object ]
184
170
185
171
def __post_init__ (self ) -> None :
186
172
"""Checks that no fake default fields were left as placeholders."""
@@ -189,14 +175,14 @@ def __post_init__(self) -> None:
189
175
raise ValueError (f"`{ field_name } ` is a required field." )
190
176
191
177
@property
192
- def output_file (self ) -> " OutputTemplate" :
178
+ def output_file (self ) -> OutputTemplate :
193
179
current = self
194
180
while not isinstance (current , OutputTemplate ):
195
181
current = current .parent
196
182
return current
197
183
198
184
@property
199
- def request (self ) -> " PluginRequestCompiler" :
185
+ def request (self ) -> PluginRequestCompiler :
200
186
current = self
201
187
while not isinstance (current , OutputTemplate ):
202
188
current = current .parent
@@ -215,10 +201,10 @@ def comment(self) -> str:
215
201
@dataclass
216
202
class PluginRequestCompiler :
217
203
plugin_request_obj : CodeGeneratorRequest
218
- output_packages : Dict [str , " OutputTemplate" ] = field (default_factory = dict )
204
+ output_packages : dict [str , OutputTemplate ] = field (default_factory = dict )
219
205
220
206
@property
221
- def all_messages (self ) -> List [ " MessageCompiler" ]:
207
+ def all_messages (self ) -> list [ MessageCompiler ]:
222
208
"""All of the messages in this request.
223
209
224
210
Returns
@@ -242,16 +228,16 @@ class OutputTemplate:
242
228
243
229
parent_request : PluginRequestCompiler
244
230
package_proto_obj : FileDescriptorProto
245
- input_files : List [ str ] = field (default_factory = list )
246
- imports : Set [str ] = field (default_factory = set )
247
- datetime_imports : Set [str ] = field (default_factory = set )
248
- typing_imports : Set [str ] = field (default_factory = set )
249
- pydantic_imports : Set [str ] = field (default_factory = set )
231
+ input_files : list [ FileDescriptorProto ] = field (default_factory = list )
232
+ imports : set [str ] = field (default_factory = set )
233
+ datetime_imports : set [str ] = field (default_factory = set )
234
+ typing_imports : set [str ] = field (default_factory = set )
235
+ pydantic_imports : set [str ] = field (default_factory = set )
250
236
builtins_import : bool = False
251
- messages : List [ " MessageCompiler" ] = field (default_factory = list )
252
- enums : List [ " EnumDefinitionCompiler" ] = field (default_factory = list )
253
- services : List [ " ServiceCompiler" ] = field (default_factory = list )
254
- imports_type_checking_only : Set [str ] = field (default_factory = set )
237
+ messages : list [ MessageCompiler ] = field (default_factory = list )
238
+ enums : list [ EnumDefinitionCompiler ] = field (default_factory = list )
239
+ services : list [ ServiceCompiler ] = field (default_factory = list )
240
+ imports_type_checking_only : set [str ] = field (default_factory = set )
255
241
pydantic_dataclasses : bool = False
256
242
output : bool = True
257
243
@@ -278,7 +264,7 @@ def input_filenames(self) -> Iterable[str]:
278
264
return sorted (f .name for f in self .input_files )
279
265
280
266
@property
281
- def python_module_imports (self ) -> Set [str ]:
267
+ def python_module_imports (self ) -> set [str ]:
282
268
imports = set ()
283
269
if any (x for x in self .messages if any (x .deprecated_fields )):
284
270
imports .add ("warnings" )
@@ -292,14 +278,12 @@ class MessageCompiler(ProtoContentBase):
292
278
"""Representation of a protobuf message."""
293
279
294
280
source_file : FileDescriptorProto
295
- parent : Union [ " MessageCompiler" , OutputTemplate ] = PLACEHOLDER
281
+ parent : MessageCompiler | OutputTemplate = PLACEHOLDER
296
282
proto_obj : DescriptorProto = PLACEHOLDER
297
- path : List [int ] = PLACEHOLDER
298
- fields : List [Union ["FieldCompiler" , "MessageCompiler" ]] = field (
299
- default_factory = list
300
- )
283
+ path : list [int ] = PLACEHOLDER
284
+ fields : list [FieldCompiler | MessageCompiler ] = field (default_factory = list )
301
285
deprecated : bool = field (default = False , init = False )
302
- builtins_types : Set [str ] = field (default_factory = set )
286
+ builtins_types : set [str ] = field (default_factory = set )
303
287
304
288
def __post_init__ (self ) -> None :
305
289
# Add message to output file
@@ -319,6 +303,10 @@ def proto_name(self) -> str:
319
303
def py_name (self ) -> str :
320
304
return pythonize_class_name (self .proto_name )
321
305
306
+ @property
307
+ def repeated (self ) -> bool :
308
+ raise NotImplementedError
309
+
322
310
@property
323
311
def annotation (self ) -> str :
324
312
if self .repeated :
@@ -349,11 +337,12 @@ def has_message_field(self) -> bool:
349
337
350
338
351
339
def is_map (
352
- proto_field_obj : FieldDescriptorProto , parent_message : DescriptorProto
340
+ proto_field_obj : FieldDescriptorProto ,
341
+ parent_message : DescriptorProto | MessageCompiler ,
353
342
) -> bool :
354
343
"""True if proto_field_obj is a map, otherwise False."""
355
344
if proto_field_obj .type == FieldDescriptorProtoType .TYPE_MESSAGE :
356
- if not hasattr (parent_message , "nested_type" ):
345
+ if not isinstance (parent_message , DescriptorProto ):
357
346
return False
358
347
359
348
# This might be a map...
@@ -416,16 +405,16 @@ def get_field_string(self, indent: int = 4) -> str:
416
405
return f"{ name } { annotations } = { betterproto_field_type } "
417
406
418
407
@property
419
- def betterproto_field_args (self ) -> List [str ]:
408
+ def betterproto_field_args (self ) -> list [str ]:
420
409
args = []
421
410
if self .field_wraps :
422
411
args .append (f"wraps={ self .field_wraps } " )
423
412
if self .optional :
424
- args .append (f "optional=True" )
413
+ args .append ("optional=True" )
425
414
return args
426
415
427
416
@property
428
- def datetime_imports (self ) -> Set [str ]:
417
+ def datetime_imports (self ) -> set [str ]:
429
418
imports = set ()
430
419
annotation = self .annotation
431
420
# FIXME: false positives - e.g. `MyDatetimedelta`
@@ -436,7 +425,7 @@ def datetime_imports(self) -> Set[str]:
436
425
return imports
437
426
438
427
@property
439
- def typing_imports (self ) -> Set [str ]:
428
+ def typing_imports (self ) -> set [str ]:
440
429
imports = set ()
441
430
annotation = self .annotation
442
431
if "Optional[" in annotation :
@@ -448,7 +437,7 @@ def typing_imports(self) -> Set[str]:
448
437
return imports
449
438
450
439
@property
451
- def pydantic_imports (self ) -> Set [str ]:
440
+ def pydantic_imports (self ) -> set [str ]:
452
441
return set ()
453
442
454
443
@property
@@ -464,7 +453,7 @@ def add_imports_to(self, output_file: OutputTemplate) -> None:
464
453
output_file .builtins_import = output_file .builtins_import or self .use_builtins
465
454
466
455
@property
467
- def field_wraps (self ) -> Optional [ str ] :
456
+ def field_wraps (self ) -> str | None :
468
457
"""Returns betterproto wrapped field type or None."""
469
458
match_wrapper = re .match (
470
459
r"\.google\.protobuf\.(.+)Value$" , self .proto_obj .type_name
@@ -582,7 +571,7 @@ def annotation(self) -> str:
582
571
@dataclass
583
572
class OneOfFieldCompiler (FieldCompiler ):
584
573
@property
585
- def betterproto_field_args (self ) -> List [str ]:
574
+ def betterproto_field_args (self ) -> list [str ]:
586
575
args = super ().betterproto_field_args
587
576
group = self .parent .proto_obj .oneof_decl [self .proto_obj .oneof_index ].name
588
577
args .append (f'group="{ group } "' )
@@ -599,14 +588,14 @@ def optional(self) -> bool:
599
588
return True
600
589
601
590
@property
602
- def pydantic_imports (self ) -> Set [str ]:
591
+ def pydantic_imports (self ) -> set [str ]:
603
592
return {"root_validator" }
604
593
605
594
606
595
@dataclass
607
596
class MapEntryCompiler (FieldCompiler ):
608
- py_k_type : Type = PLACEHOLDER
609
- py_v_type : Type = PLACEHOLDER
597
+ py_k_type : str = PLACEHOLDER
598
+ py_v_type : str = PLACEHOLDER
610
599
proto_k_type : str = PLACEHOLDER
611
600
proto_v_type : str = PLACEHOLDER
612
601
@@ -636,7 +625,7 @@ def __post_init__(self) -> None:
636
625
super ().__post_init__ () # call FieldCompiler-> MessageCompiler __post_init__
637
626
638
627
@property
639
- def betterproto_field_args (self ) -> List [str ]:
628
+ def betterproto_field_args (self ) -> list [str ]:
640
629
return [f"betterproto.{ self .proto_k_type } " , f"betterproto.{ self .proto_v_type } " ]
641
630
642
631
@property
@@ -657,7 +646,7 @@ class EnumDefinitionCompiler(MessageCompiler):
657
646
"""Representation of a proto Enum definition."""
658
647
659
648
proto_obj : EnumDescriptorProto = PLACEHOLDER
660
- entries : List [ " EnumDefinitionCompiler.EnumEntry" ] = PLACEHOLDER
649
+ entries : list [ EnumDefinitionCompiler .EnumEntry ] = PLACEHOLDER
661
650
662
651
@dataclass (unsafe_hash = True )
663
652
class EnumEntry :
@@ -695,9 +684,9 @@ def default_value_string(self) -> str:
695
684
@dataclass
696
685
class ServiceCompiler (ProtoContentBase ):
697
686
parent : OutputTemplate = PLACEHOLDER
698
- proto_obj : DescriptorProto = PLACEHOLDER
699
- path : List [int ] = PLACEHOLDER
700
- methods : List [ " ServiceMethodCompiler" ] = field (default_factory = list )
687
+ proto_obj : ServiceDescriptorProto = PLACEHOLDER
688
+ path : list [int ] = PLACEHOLDER
689
+ methods : list [ ServiceMethodCompiler ] = field (default_factory = list )
701
690
702
691
def __post_init__ (self ) -> None :
703
692
# Add service to output file
@@ -718,7 +707,7 @@ def py_name(self) -> str:
718
707
class ServiceMethodCompiler (ProtoContentBase ):
719
708
parent : ServiceCompiler
720
709
proto_obj : MethodDescriptorProto
721
- path : List [int ] = PLACEHOLDER
710
+ path : list [int ] = PLACEHOLDER
722
711
comment_indent : int = 8
723
712
724
713
def __post_init__ (self ) -> None :
@@ -769,7 +758,7 @@ def route(self) -> str:
769
758
return f"/{ package_part } { self .parent .proto_name } /{ self .proto_name } "
770
759
771
760
@property
772
- def py_input_message (self ) -> Optional [ MessageCompiler ] :
761
+ def py_input_message (self ) -> MessageCompiler | None :
773
762
"""Find the input message object.
774
763
775
764
Returns
0 commit comments