Skip to content

Commit 040dbf6

Browse files
committed
Fix more issues
1 parent d7f8731 commit 040dbf6

File tree

2 files changed

+69
-88
lines changed

2 files changed

+69
-88
lines changed

src/betterproto/plugin/models.py

Lines changed: 55 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,30 @@
2929
reference to `A` to `B`'s `fields` attribute.
3030
"""
3131

32+
from __future__ import annotations
3233

3334
import builtins
3435
import re
3536
import textwrap
37+
from collections.abc import (
38+
Iterable,
39+
Iterator,
40+
)
3641
from dataclasses import (
3742
dataclass,
3843
field,
3944
)
40-
from typing import (
41-
Dict,
42-
Iterable,
43-
Iterator,
44-
List,
45-
Optional,
46-
Set,
47-
Type,
48-
Union,
49-
)
45+
from typing import Any
5046

5147
import betterproto
5248
from betterproto import which_one_of
53-
from betterproto.casing import sanitize_name
5449
from betterproto.compile.importing import (
5550
get_type_reference,
5651
parse_source_type_name,
5752
)
5853
from betterproto.compile.naming import (
5954
pythonize_class_name,
55+
pythonize_enum_member_name,
6056
pythonize_field_name,
6157
pythonize_method_name,
6258
)
@@ -69,24 +65,14 @@
6965
FieldDescriptorProtoType,
7066
FileDescriptorProto,
7167
MethodDescriptorProto,
68+
ServiceDescriptorProto,
7269
)
7370
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
7471

75-
from ..compile.importing import (
76-
get_type_reference,
77-
parse_source_type_name,
78-
)
79-
from ..compile.naming import (
80-
pythonize_class_name,
81-
pythonize_enum_member_name,
82-
pythonize_field_name,
83-
pythonize_method_name,
84-
)
85-
8672

8773
# Create a unique placeholder to deal with
8874
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
89-
PLACEHOLDER = object()
75+
PLACEHOLDER: Any = object()
9076

9177
# Organize proto types into categories
9278
PROTO_FLOAT_TYPES = (
@@ -152,7 +138,7 @@ def monkey_patch_oneof_index():
152138

153139

154140
def get_comment(
155-
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
141+
proto_file: FileDescriptorProto, path: list[int], indent: int = 4
156142
) -> str:
157143
pad = " " * indent
158144
for sci_loc in proto_file.source_code_info.location:
@@ -176,11 +162,11 @@ class ProtoContentBase:
176162
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
177163

178164
source_file: FileDescriptorProto
179-
path: List[int]
165+
path: list[int]
180166
comment_indent: int = 4
181-
parent: Union["betterproto.Message", "OutputTemplate"]
167+
parent: betterproto.Message | OutputTemplate
182168

183-
__dataclass_fields__: Dict[str, object]
169+
__dataclass_fields__: dict[str, object]
184170

185171
def __post_init__(self) -> None:
186172
"""Checks that no fake default fields were left as placeholders."""
@@ -189,14 +175,14 @@ def __post_init__(self) -> None:
189175
raise ValueError(f"`{field_name}` is a required field.")
190176

191177
@property
192-
def output_file(self) -> "OutputTemplate":
178+
def output_file(self) -> OutputTemplate:
193179
current = self
194180
while not isinstance(current, OutputTemplate):
195181
current = current.parent
196182
return current
197183

198184
@property
199-
def request(self) -> "PluginRequestCompiler":
185+
def request(self) -> PluginRequestCompiler:
200186
current = self
201187
while not isinstance(current, OutputTemplate):
202188
current = current.parent
@@ -215,10 +201,10 @@ def comment(self) -> str:
215201
@dataclass
216202
class PluginRequestCompiler:
217203
plugin_request_obj: CodeGeneratorRequest
218-
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
204+
output_packages: dict[str, OutputTemplate] = field(default_factory=dict)
219205

220206
@property
221-
def all_messages(self) -> List["MessageCompiler"]:
207+
def all_messages(self) -> list[MessageCompiler]:
222208
"""All of the messages in this request.
223209
224210
Returns
@@ -242,16 +228,16 @@ class OutputTemplate:
242228

243229
parent_request: PluginRequestCompiler
244230
package_proto_obj: FileDescriptorProto
245-
input_files: List[str] = field(default_factory=list)
246-
imports: Set[str] = field(default_factory=set)
247-
datetime_imports: Set[str] = field(default_factory=set)
248-
typing_imports: Set[str] = field(default_factory=set)
249-
pydantic_imports: Set[str] = field(default_factory=set)
231+
input_files: list[FileDescriptorProto] = field(default_factory=list)
232+
imports: set[str] = field(default_factory=set)
233+
datetime_imports: set[str] = field(default_factory=set)
234+
typing_imports: set[str] = field(default_factory=set)
235+
pydantic_imports: set[str] = field(default_factory=set)
250236
builtins_import: bool = False
251-
messages: List["MessageCompiler"] = field(default_factory=list)
252-
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
253-
services: List["ServiceCompiler"] = field(default_factory=list)
254-
imports_type_checking_only: Set[str] = field(default_factory=set)
237+
messages: list[MessageCompiler] = field(default_factory=list)
238+
enums: list[EnumDefinitionCompiler] = field(default_factory=list)
239+
services: list[ServiceCompiler] = field(default_factory=list)
240+
imports_type_checking_only: set[str] = field(default_factory=set)
255241
pydantic_dataclasses: bool = False
256242
output: bool = True
257243

@@ -278,7 +264,7 @@ def input_filenames(self) -> Iterable[str]:
278264
return sorted(f.name for f in self.input_files)
279265

280266
@property
281-
def python_module_imports(self) -> Set[str]:
267+
def python_module_imports(self) -> set[str]:
282268
imports = set()
283269
if any(x for x in self.messages if any(x.deprecated_fields)):
284270
imports.add("warnings")
@@ -292,14 +278,12 @@ class MessageCompiler(ProtoContentBase):
292278
"""Representation of a protobuf message."""
293279

294280
source_file: FileDescriptorProto
295-
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
281+
parent: MessageCompiler | OutputTemplate = PLACEHOLDER
296282
proto_obj: DescriptorProto = PLACEHOLDER
297-
path: List[int] = PLACEHOLDER
298-
fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(
299-
default_factory=list
300-
)
283+
path: list[int] = PLACEHOLDER
284+
fields: list[FieldCompiler | MessageCompiler] = field(default_factory=list)
301285
deprecated: bool = field(default=False, init=False)
302-
builtins_types: Set[str] = field(default_factory=set)
286+
builtins_types: set[str] = field(default_factory=set)
303287

304288
def __post_init__(self) -> None:
305289
# Add message to output file
@@ -319,6 +303,10 @@ def proto_name(self) -> str:
319303
def py_name(self) -> str:
320304
return pythonize_class_name(self.proto_name)
321305

306+
@property
307+
def repeated(self) -> bool:
308+
raise NotImplementedError
309+
322310
@property
323311
def annotation(self) -> str:
324312
if self.repeated:
@@ -349,11 +337,12 @@ def has_message_field(self) -> bool:
349337

350338

351339
def is_map(
352-
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
340+
proto_field_obj: FieldDescriptorProto,
341+
parent_message: DescriptorProto | MessageCompiler,
353342
) -> bool:
354343
"""True if proto_field_obj is a map, otherwise False."""
355344
if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE:
356-
if not hasattr(parent_message, "nested_type"):
345+
if not isinstance(parent_message, DescriptorProto):
357346
return False
358347

359348
# This might be a map...
@@ -416,16 +405,16 @@ def get_field_string(self, indent: int = 4) -> str:
416405
return f"{name}{annotations} = {betterproto_field_type}"
417406

418407
@property
419-
def betterproto_field_args(self) -> List[str]:
408+
def betterproto_field_args(self) -> list[str]:
420409
args = []
421410
if self.field_wraps:
422411
args.append(f"wraps={self.field_wraps}")
423412
if self.optional:
424-
args.append(f"optional=True")
413+
args.append("optional=True")
425414
return args
426415

427416
@property
428-
def datetime_imports(self) -> Set[str]:
417+
def datetime_imports(self) -> set[str]:
429418
imports = set()
430419
annotation = self.annotation
431420
# FIXME: false positives - e.g. `MyDatetimedelta`
@@ -436,7 +425,7 @@ def datetime_imports(self) -> Set[str]:
436425
return imports
437426

438427
@property
439-
def typing_imports(self) -> Set[str]:
428+
def typing_imports(self) -> set[str]:
440429
imports = set()
441430
annotation = self.annotation
442431
if "Optional[" in annotation:
@@ -448,7 +437,7 @@ def typing_imports(self) -> Set[str]:
448437
return imports
449438

450439
@property
451-
def pydantic_imports(self) -> Set[str]:
440+
def pydantic_imports(self) -> set[str]:
452441
return set()
453442

454443
@property
@@ -464,7 +453,7 @@ def add_imports_to(self, output_file: OutputTemplate) -> None:
464453
output_file.builtins_import = output_file.builtins_import or self.use_builtins
465454

466455
@property
467-
def field_wraps(self) -> Optional[str]:
456+
def field_wraps(self) -> str | None:
468457
"""Returns betterproto wrapped field type or None."""
469458
match_wrapper = re.match(
470459
r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name
@@ -582,7 +571,7 @@ def annotation(self) -> str:
582571
@dataclass
583572
class OneOfFieldCompiler(FieldCompiler):
584573
@property
585-
def betterproto_field_args(self) -> List[str]:
574+
def betterproto_field_args(self) -> list[str]:
586575
args = super().betterproto_field_args
587576
group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
588577
args.append(f'group="{group}"')
@@ -599,14 +588,14 @@ def optional(self) -> bool:
599588
return True
600589

601590
@property
602-
def pydantic_imports(self) -> Set[str]:
591+
def pydantic_imports(self) -> set[str]:
603592
return {"root_validator"}
604593

605594

606595
@dataclass
607596
class MapEntryCompiler(FieldCompiler):
608-
py_k_type: Type = PLACEHOLDER
609-
py_v_type: Type = PLACEHOLDER
597+
py_k_type: str = PLACEHOLDER
598+
py_v_type: str = PLACEHOLDER
610599
proto_k_type: str = PLACEHOLDER
611600
proto_v_type: str = PLACEHOLDER
612601

@@ -636,7 +625,7 @@ def __post_init__(self) -> None:
636625
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
637626

638627
@property
639-
def betterproto_field_args(self) -> List[str]:
628+
def betterproto_field_args(self) -> list[str]:
640629
return [f"betterproto.{self.proto_k_type}", f"betterproto.{self.proto_v_type}"]
641630

642631
@property
@@ -657,7 +646,7 @@ class EnumDefinitionCompiler(MessageCompiler):
657646
"""Representation of a proto Enum definition."""
658647

659648
proto_obj: EnumDescriptorProto = PLACEHOLDER
660-
entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
649+
entries: list[EnumDefinitionCompiler.EnumEntry] = PLACEHOLDER
661650

662651
@dataclass(unsafe_hash=True)
663652
class EnumEntry:
@@ -695,9 +684,9 @@ def default_value_string(self) -> str:
695684
@dataclass
696685
class ServiceCompiler(ProtoContentBase):
697686
parent: OutputTemplate = PLACEHOLDER
698-
proto_obj: DescriptorProto = PLACEHOLDER
699-
path: List[int] = PLACEHOLDER
700-
methods: List["ServiceMethodCompiler"] = field(default_factory=list)
687+
proto_obj: ServiceDescriptorProto = PLACEHOLDER
688+
path: list[int] = PLACEHOLDER
689+
methods: list[ServiceMethodCompiler] = field(default_factory=list)
701690

702691
def __post_init__(self) -> None:
703692
# Add service to output file
@@ -718,7 +707,7 @@ def py_name(self) -> str:
718707
class ServiceMethodCompiler(ProtoContentBase):
719708
parent: ServiceCompiler
720709
proto_obj: MethodDescriptorProto
721-
path: List[int] = PLACEHOLDER
710+
path: list[int] = PLACEHOLDER
722711
comment_indent: int = 8
723712

724713
def __post_init__(self) -> None:
@@ -769,7 +758,7 @@ def route(self) -> str:
769758
return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"
770759

771760
@property
772-
def py_input_message(self) -> Optional[MessageCompiler]:
761+
def py_input_message(self) -> MessageCompiler | None:
773762
"""Find the input message object.
774763
775764
Returns

src/betterproto/plugin/parser.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1+
from __future__ import annotations
2+
13
import pathlib
24
import sys
3-
from typing import (
4-
Generator,
5-
List,
6-
Set,
7-
Tuple,
8-
Union,
9-
)
5+
from collections.abc import Generator
106

117
from betterproto.lib.google.protobuf import (
128
DescriptorProto,
@@ -41,17 +37,13 @@
4137

4238
def traverse(
4339
proto_file: FileDescriptorProto,
44-
) -> Generator[
45-
Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
46-
]:
40+
) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
4741
# Todo: Keep information about nested hierarchy
4842
def _traverse(
49-
path: List[int],
50-
items: Union[List[EnumDescriptorProto], List[DescriptorProto]],
43+
path: list[int],
44+
items: list[EnumDescriptorProto] | list[DescriptorProto],
5145
prefix: str = "",
52-
) -> Generator[
53-
Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
54-
]:
46+
) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
5547
for i, item in enumerate(items):
5648
# Adjust the name since we flatten the hierarchy.
5749
# Todo: don't change the name, but include full name in returned tuple
@@ -118,7 +110,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
118110
read_protobuf_service(service, index, output_package)
119111

120112
# Generate output files
121-
output_paths: Set[pathlib.Path] = set()
113+
output_paths: set[pathlib.Path] = set()
122114
for output_package_name, output_package in request_data.output_packages.items():
123115
if not output_package.output:
124116
continue
@@ -154,10 +146,10 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
154146

155147
def _make_one_of_field_compiler(
156148
output_package: OutputTemplate,
157-
source_file: "FileDescriptorProto",
149+
source_file: FileDescriptorProto,
158150
parent: MessageCompiler,
159-
proto_obj: "FieldDescriptorProto",
160-
path: List[int],
151+
proto_obj: FieldDescriptorProto,
152+
path: list[int],
161153
) -> FieldCompiler:
162154
pydantic = output_package.pydantic_dataclasses
163155
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
@@ -170,9 +162,9 @@ def _make_one_of_field_compiler(
170162

171163

172164
def read_protobuf_type(
173-
item: DescriptorProto,
174-
path: List[int],
175-
source_file: "FileDescriptorProto",
165+
item: DescriptorProto | EnumDescriptorProto,
166+
path: list[int],
167+
source_file: FileDescriptorProto,
176168
output_package: OutputTemplate,
177169
) -> None:
178170
if isinstance(item, DescriptorProto):

0 commit comments

Comments
 (0)