diff --git a/pyproject.toml b/pyproject.toml
index 78d67c34..771cb0c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,7 @@ protobuf = "^4"
 protoc-gen-python_betterproto2 = "betterproto2_compiler.plugin:main"

 [tool.ruff]
-extend-exclude = ["tests/output_*"]
+extend-exclude = ["tests/output_*", "src/betterproto2_compiler/lib"]
 target-version = "py310"
 line-length = 120

@@ -52,7 +52,11 @@ select = [
     "SIM102", # Simplify return or yield statements
     "SIM103", # Simplify list/set/dict comprehensions
+    "UP",
+    "I",
+
+    "COM812", # Trailing commas
 ]

@@ -78,8 +82,8 @@ sequence = ["_format", "_sort-imports"]
 help = "Format the source code, and sort the imports"

 [tool.poe.tasks.check]
-sequence = ["_check-format", "_check-imports"]
-help = "Check that the source code is formatted and the imports sorted"
+sequence = ["_check-format", "_check-ruff-lint"]
+help = "Check that the source code is formatted and the code passes the linter"

 [tool.poe.tasks._format]
 cmd = "ruff format src tests"
@@ -93,9 +97,9 @@ help = "Sort the imports"
 cmd = "ruff format --diff src tests"
 help = "Check that the source code is formatted"

-[tool.poe.tasks._check-imports]
-cmd = "ruff check --select I src tests"
-help = "Check that the imports are sorted"
+[tool.poe.tasks._check-ruff-lint]
+cmd = "ruff check src tests"
+help = "Check the code with the Ruff linter"

 [tool.poe.tasks.generate_lib]
 cmd = """
diff --git a/src/betterproto2_compiler/compile/importing.py b/src/betterproto2_compiler/compile/importing.py
index 814fff71..d4b3ac78 100644
--- a/src/betterproto2_compiler/compile/importing.py
+++ b/src/betterproto2_compiler/compile/importing.py
@@ -3,11 +3,6 @@ import os
 from typing import (
     TYPE_CHECKING,
-    Dict,
-    List,
-    Set,
-    Tuple,
-    Type,
 )

 from ..casing import safe_snake_case
@@ -18,7 +13,7 @@ from ..plugin.models import PluginRequestCompiler
 from ..plugin.typing_compiler import TypingCompiler

-WRAPPER_TYPES: Dict[str, Type] = {
+WRAPPER_TYPES: dict[str, type] = {
     ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
     ".google.protobuf.FloatValue": google_protobuf.FloatValue,
     ".google.protobuf.Int32Value": google_protobuf.Int32Value,
@@ -31,7 +26,7 @@
 }


-def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler") -> Tuple[str, str]:
+def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) -> tuple[str, str]:
     """
     Split full source type name into package and type name. E.g.
     'root.package.Message' -> ('root.package', 'Message')
@@ -77,7 +72,7 @@ def get_type_reference(
     imports: set,
     source_type: str,
     typing_compiler: TypingCompiler,
-    request: "PluginRequestCompiler",
+    request: PluginRequestCompiler,
     unwrap: bool = True,
     pydantic: bool = False,
 ) -> str:
@@ -98,8 +93,8 @@ def get_type_reference(
     source_package, source_type = parse_source_type_name(source_type, request)

-    current_package: List[str] = package.split(".") if package else []
-    py_package: List[str] = source_package.split(".") if source_package else []
+    current_package: list[str] = package.split(".") if package else []
+    py_package: list[str] = source_package.split(".") if source_package else []
     py_type: str = pythonize_class_name(source_type)

     compiling_google_protobuf = current_package == ["google", "protobuf"]
@@ -122,7 +117,7 @@ def get_type_reference(
     return reference_cousin(current_package, imports, py_package, py_type)


-def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str:
+def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> str:
     """
     Returns a reference to a python type located in the root, i.e. sys.path.
     """
@@ -139,7 +134,7 @@ def reference_sibling(py_type: str) -> str:
     return f"{py_type}"


-def reference_descendent(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
+def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
     """
     Returns a reference to a python type in a package that is a descendent of the
     current package, and adds the required import that is aliased to avoid name
@@ -157,7 +152,7 @@ def reference_descendent(current_package: List[str], imports: Set[str], py_packa
     return f"{string_import}.{py_type}"


-def reference_ancestor(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
+def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
     """
     Returns a reference to a python type in a package which is an ancestor to the
     current package, and adds the required import that is aliased (if possible) to avoid
@@ -178,7 +173,7 @@ def reference_ancestor(current_package: List[str], imports: Set[str], py_package
     return string_alias


-def reference_cousin(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
+def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
     """
     Returns a reference to a python type in a package that is not descendent, ancestor
     or sibling, and adds the required import that is aliased to avoid name conflicts.
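Note: the annotation changes in this file (and in the files below) follow one mechanical pattern, matching the newly enabled ruff `UP` (pyupgrade) rules: `typing.Dict`/`List`/`Set`/`Tuple`/`Type` become the builtin generics of PEP 585, and `Optional[X]` becomes `X | None` (PEP 604). Both are safe here because `pyproject.toml` pins `target-version = "py310"`. A minimal sketch of the before/after style — the `lookup` function and its names are hypothetical, for illustration only:

```python
# Before (typing-module generics, pre-PEP 585/604):
#     from typing import Dict, List, Optional, Tuple
#     def lookup(table: Dict[str, List[int]], key: Optional[str]) -> Tuple[int, ...]: ...

# After (builtin generics and union syntax, valid on Python 3.10+, no imports needed):
def lookup(table: dict[str, list[int]], key: str | None) -> tuple[int, ...]:
    return tuple(table.get(key, [])) if key is not None else ()
```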
diff --git a/src/betterproto2_compiler/enum.py b/src/betterproto2_compiler/enum.py
index 9a1d677b..7a6a5a71 100644
--- a/src/betterproto2_compiler/enum.py
+++ b/src/betterproto2_compiler/enum.py
@@ -8,9 +8,6 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    Dict,
-    Optional,
-    Tuple,
 )

 if TYPE_CHECKING:
@@ -33,14 +30,14 @@ class EnumType(EnumMeta if TYPE_CHECKING else type):
     _value_map_: Mapping[int, Enum]
     _member_map_: Mapping[str, Enum]

-    def __new__(mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> Self:
+    def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) -> Self:
         value_map = {}
         member_map = {}

         new_mcs = type(
             f"{name}Type",
             tuple(
-                dict.fromkeys([base.__class__ for base in bases if base.__class__ is not type] + [EnumType, type])
+                dict.fromkeys([base.__class__ for base in bases if base.__class__ is not type] + [EnumType, type]),
             ),  # reorder the bases so EnumType and type are last to avoid conflicts
             {"_value_map_": value_map, "_member_map_": member_map},
         )
@@ -111,12 +108,12 @@ class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
     inherit from this. Emulates `enum.IntEnum`.
     """

-    name: Optional[str]
+    name: str | None
     value: int

     if not TYPE_CHECKING:

-        def __new__(cls, *, name: Optional[str], value: int) -> Self:
+        def __new__(cls, *, name: str | None, value: int) -> Self:
             self = super().__new__(cls, value)
             super().__setattr__(self, "name", name)
             super().__setattr__(self, "value", value)
diff --git a/src/betterproto2_compiler/grpc/__init__.py b/src/betterproto2_compiler/grpc/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/betterproto2_compiler/grpc/grpclib_client.py b/src/betterproto2_compiler/grpc/grpclib_client.py
deleted file mode 100644
index ab24cedb..00000000
--- a/src/betterproto2_compiler/grpc/grpclib_client.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import asyncio
-from abc import ABC
-from typing import (
-    TYPE_CHECKING,
-    AsyncIterable,
-    AsyncIterator,
-    Collection,
-    Iterable,
-    Mapping,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-)
-
-import grpclib.const
-
-if TYPE_CHECKING:
-    from grpclib.client import Channel
-    from grpclib.metadata import Deadline
-
-    from .._types import (
-        IProtoMessage,
-        T,
-    )
-
-
-Value = Union[str, bytes]
-MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
-MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
-
-
-class ServiceStub(ABC):
-    """
-    Base class for async gRPC clients.
-    """
-
-    def __init__(
-        self,
-        channel: "Channel",
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> None:
-        self.channel = channel
-        self.timeout = timeout
-        self.deadline = deadline
-        self.metadata = metadata
-
-    def __resolve_request_kwargs(
-        self,
-        timeout: Optional[float],
-        deadline: Optional["Deadline"],
-        metadata: Optional[MetadataLike],
-    ):
-        return {
-            "timeout": self.timeout if timeout is None else timeout,
-            "deadline": self.deadline if deadline is None else deadline,
-            "metadata": self.metadata if metadata is None else metadata,
-        }
-
-    async def _unary_unary(
-        self,
-        route: str,
-        request: "IProtoMessage",
-        response_type: Type["T"],
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> "T":
-        """Make a unary request and return the response."""
-        async with self.channel.request(
-            route,
-            grpclib.const.Cardinality.UNARY_UNARY,
-            type(request),
-            response_type,
-            **self.__resolve_request_kwargs(timeout, deadline, metadata),
-        ) as stream:
-            await stream.send_message(request, end=True)
-            response = await stream.recv_message()
-            assert response is not None
-            return response
-
-    async def _unary_stream(
-        self,
-        route: str,
-        request: "IProtoMessage",
-        response_type: Type["T"],
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> AsyncIterator["T"]:
-        """Make a unary request and return the stream response iterator."""
-        async with self.channel.request(
-            route,
-            grpclib.const.Cardinality.UNARY_STREAM,
-            type(request),
-            response_type,
-            **self.__resolve_request_kwargs(timeout, deadline, metadata),
-        ) as stream:
-            await stream.send_message(request, end=True)
-            async for message in stream:
-                yield message
-
-    async def _stream_unary(
-        self,
-        route: str,
-        request_iterator: MessageSource,
-        request_type: Type["IProtoMessage"],
-        response_type: Type["T"],
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> "T":
-        """Make a stream request and return the response."""
-        async with self.channel.request(
-            route,
-            grpclib.const.Cardinality.STREAM_UNARY,
-            request_type,
-            response_type,
-            **self.__resolve_request_kwargs(timeout, deadline, metadata),
-        ) as stream:
-            await stream.send_request()
-            await self._send_messages(stream, request_iterator)
-            response = await stream.recv_message()
-            assert response is not None
-            return response
-
-    async def _stream_stream(
-        self,
-        route: str,
-        request_iterator: MessageSource,
-        request_type: Type["IProtoMessage"],
-        response_type: Type["T"],
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> AsyncIterator["T"]:
-        """
-        Make a stream request and return an AsyncIterator to iterate over response
-        messages.
-        """
-        async with self.channel.request(
-            route,
-            grpclib.const.Cardinality.STREAM_STREAM,
-            request_type,
-            response_type,
-            **self.__resolve_request_kwargs(timeout, deadline, metadata),
-        ) as stream:
-            await stream.send_request()
-            sending_task = asyncio.ensure_future(self._send_messages(stream, request_iterator))
-            try:
-                async for response in stream:
-                    yield response
-            except:
-                sending_task.cancel()
-                raise
-
-    @staticmethod
-    async def _send_messages(stream, messages: MessageSource):
-        if isinstance(messages, AsyncIterable):
-            async for message in messages:
-                await stream.send_message(message)
-        else:
-            for message in messages:
-                await stream.send_message(message)
-        await stream.end()
diff --git a/src/betterproto2_compiler/grpc/grpclib_server.py b/src/betterproto2_compiler/grpc/grpclib_server.py
deleted file mode 100644
index 61d4710b..00000000
--- a/src/betterproto2_compiler/grpc/grpclib_server.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from abc import ABC
-from collections.abc import AsyncIterable
-from typing import (
-    Any,
-    Callable,
-)
-
-import grpclib
-import grpclib.server
-
-
-class ServiceBase(ABC):
-    """
-    Base class for async gRPC servers.
-    """
-
-    async def _call_rpc_handler_server_stream(
-        self,
-        handler: Callable,
-        stream: grpclib.server.Stream,
-        request: Any,
-    ) -> None:
-        response_iter = handler(request)
-        # check if response is actually an AsyncIterator
-        # this might be false if the method just returns without
-        # yielding at least once
-        # in that case, we just interpret it as an empty iterator
-        if isinstance(response_iter, AsyncIterable):
-            async for response_message in response_iter:
-                await stream.send_message(response_message)
-        else:
-            response_iter.close()
diff --git a/src/betterproto2_compiler/grpc/util/__init__.py b/src/betterproto2_compiler/grpc/util/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/betterproto2_compiler/grpc/util/async_channel.py b/src/betterproto2_compiler/grpc/util/async_channel.py
deleted file mode 100644
index 9a9345e0..00000000
--- a/src/betterproto2_compiler/grpc/util/async_channel.py
+++ /dev/null
@@ -1,190 +0,0 @@
-import asyncio
-from typing import (
-    AsyncIterable,
-    AsyncIterator,
-    Iterable,
-    Optional,
-    TypeVar,
-    Union,
-)
-
-T = TypeVar("T")
-
-
-class ChannelClosed(Exception):
-    """
-    An exception raised on an attempt to send through a closed channel
-    """
-
-
-class ChannelDone(Exception):
-    """
-    An exception raised on an attempt to send receive from a channel that is both closed
-    and empty.
-    """
-
-
-class AsyncChannel(AsyncIterable[T]):
-    """
-    A buffered async channel for sending items between coroutines with FIFO ordering.
-
-    This makes decoupled bidirectional steaming gRPC requests easy if used like:
-
-    .. code-block:: python
-        client = GeneratedStub(grpclib_chan)
-        request_channel = await AsyncChannel()
-        # We can start be sending all the requests we already have
-        await request_channel.send_from([RequestObject(...), RequestObject(...)])
-        async for response in client.rpc_call(request_channel):
-            # The response iterator will remain active until the connection is closed
-            ...
-            # More items can be sent at any time
-            await request_channel.send(RequestObject(...))
-            ...
-        # The channel must be closed to complete the gRPC connection
-        request_channel.close()
-
-    Items can be sent through the channel by either:
-    - providing an iterable to the send_from method
-    - passing them to the send method one at a time
-
-    Items can be received from the channel by either:
-    - iterating over the channel with a for loop to get all items
-    - calling the receive method to get one item at a time
-
-    If the channel is empty then receivers will wait until either an item appears or the
-    channel is closed.
-
-    Once the channel is closed then subsequent attempt to send through the channel will
-    fail with a ChannelClosed exception.
-
-    When th channel is closed and empty then it is done, and further attempts to receive
-    from it will fail with a ChannelDone exception
-
-    If multiple coroutines receive from the channel concurrently, each item sent will be
-    received by only one of the receivers.
-
-    :param source:
-        An optional iterable will items that should be sent through the channel
-        immediately.
-    :param buffer_limit:
-        Limit the number of items that can be buffered in the channel, A value less than
-        1 implies no limit. If the channel is full then attempts to send more items will
-        result in the sender waiting until an item is received from the channel.
-    :param close:
-        If set to True then the channel will automatically close after exhausting source
-        or immediately if no source is provided.
-    """
-
-    def __init__(self, *, buffer_limit: int = 0, close: bool = False):
-        self._queue: asyncio.Queue[T] = asyncio.Queue(buffer_limit)
-        self._closed = False
-        self._waiting_receivers: int = 0
-        # Track whether flush has been invoked so it can only happen once
-        self._flushed = False
-
-    def __aiter__(self) -> AsyncIterator[T]:
-        return self
-
-    async def __anext__(self) -> T:
-        if self.done():
-            raise StopAsyncIteration
-        self._waiting_receivers += 1
-        try:
-            result = await self._queue.get()
-            if result is self.__flush:
-                raise StopAsyncIteration
-            return result
-        finally:
-            self._waiting_receivers -= 1
-            self._queue.task_done()
-
-    def closed(self) -> bool:
-        """
-        Returns True if this channel is closed and no-longer accepting new items
-        """
-        return self._closed
-
-    def done(self) -> bool:
-        """
-        Check if this channel is done.
-
-        :return: True if this channel is closed and and has been drained of items in
-        which case any further attempts to receive an item from this channel will raise
-        a ChannelDone exception.
-        """
-        # After close the channel is not yet done until there is at least one waiting
-        # receiver per enqueued item.
-        return self._closed and self._queue.qsize() <= self._waiting_receivers
-
-    async def send_from(self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False) -> "AsyncChannel[T]":
-        """
-        Iterates the given [Async]Iterable and sends all the resulting items.
-        If close is set to True then subsequent send calls will be rejected with a
-        ChannelClosed exception.
-        :param source: an iterable of items to send
-        :param close:
-            if True then the channel will be closed after the source has been exhausted
-
-        """
-        if self._closed:
-            raise ChannelClosed("Cannot send through a closed channel")
-        if isinstance(source, AsyncIterable):
-            async for item in source:
-                await self._queue.put(item)
-        else:
-            for item in source:
-                await self._queue.put(item)
-        if close:
-            # Complete the closing process
-            self.close()
-        return self
-
-    async def send(self, item: T) -> "AsyncChannel[T]":
-        """
-        Send a single item over this channel.
-        :param item: The item to send
-        """
-        if self._closed:
-            raise ChannelClosed("Cannot send through a closed channel")
-        await self._queue.put(item)
-        return self
-
-    async def receive(self) -> Optional[T]:
-        """
-        Returns the next item from this channel when it becomes available,
-        or None if the channel is closed before another item is sent.
-        :return: An item from the channel
-        """
-        if self.done():
-            raise ChannelDone("Cannot receive from a closed channel")
-        self._waiting_receivers += 1
-        try:
-            result = await self._queue.get()
-            if result is self.__flush:
-                return None
-            return result
-        finally:
-            self._waiting_receivers -= 1
-            self._queue.task_done()
-
-    def close(self):
-        """
-        Close this channel to new items
-        """
-        self._closed = True
-        asyncio.ensure_future(self._flush_queue())
-
-    async def _flush_queue(self):
-        """
-        To be called after the channel is closed. Pushes a number of self.__flush
-        objects to the queue to ensure no waiting consumers get deadlocked.
-        """
-        if not self._flushed:
-            self._flushed = True
-            deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize())
-            for _ in range(deadlocked_receivers):
-                await self._queue.put(self.__flush)
-
-    # A special signal object for flushing the queue when the channel is closed
-    __flush = object()
diff --git a/src/betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py b/src/betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py
index 109a288e..9df1ea64 100644
--- a/src/betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py
+++ b/src/betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py
@@ -2401,13 +2401,13 @@ class Value(betterproto2_compiler.Message):
     )
     """Represents a null value."""

-    number_value: Optional[float] = betterproto2_compiler.double_field(2, optional=True, group="kind")
+    number_value: float | None = betterproto2_compiler.double_field(2, optional=True, group="kind")
     """Represents a double value."""

-    string_value: Optional[str] = betterproto2_compiler.string_field(3, optional=True, group="kind")
+    string_value: str | None = betterproto2_compiler.string_field(3, optional=True, group="kind")
     """Represents a string value."""

-    bool_value: Optional[bool] = betterproto2_compiler.bool_field(4, optional=True, group="kind")
+    bool_value: bool | None = betterproto2_compiler.bool_field(4, optional=True, group="kind")
     """Represents a boolean value."""

     struct_value: Optional["Struct"] = betterproto2_compiler.message_field(5, optional=True, group="kind")
diff --git a/src/betterproto2_compiler/lib/std/google/protobuf/__init__.py b/src/betterproto2_compiler/lib/std/google/protobuf/__init__.py
index dc118b7f..2c61f62d 100644
--- a/src/betterproto2_compiler/lib/std/google/protobuf/__init__.py
+++ b/src/betterproto2_compiler/lib/std/google/protobuf/__init__.py
@@ -78,7 +78,6 @@
     Dict,
     List,
     Mapping,
-    Optional,
 )

 import betterproto2
@@ -1022,7 +1021,7 @@ class FieldDescriptorProto(betterproto2.Message):
     TODO(kenton): Base-64 encode?
     """

-    oneof_index: Optional[int] = betterproto2.int32_field(9, optional=True)
+    oneof_index: int | None = betterproto2.int32_field(9, optional=True)
     """
     If set, gives the index of a oneof in the containing type's oneof_decl
     list. This field is a member of that oneof.
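One detail worth noting in the pydantic hunk above: the scalar fields move to `X | None`, but `struct_value` keeps `Optional["Struct"]`. A plausible reason (an inference, not stated in the PR) is that `|` cannot be applied to a quoted forward reference: annotations in a class body are evaluated eagerly unless `from __future__ import annotations` is active, and a bare string does not support the `|` operator. A minimal sketch with hypothetical class names:

```python
from typing import Optional

class Value:
    number_value: float | None = None        # fine: both operands are types (3.10+)
    struct_value: Optional["Struct"] = None  # fine: typing defers the ForwardRef
    # struct_value: "Struct" | None          # TypeError: unsupported operand type(s)
    # struct_value: "Struct | None"          # also fine: the whole annotation is a string

class Struct:
    pass
```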
diff --git a/src/betterproto2_compiler/plugin/compiler.py b/src/betterproto2_compiler/plugin/compiler.py
index e8d261c1..780bed82 100644
--- a/src/betterproto2_compiler/plugin/compiler.py
+++ b/src/betterproto2_compiler/plugin/compiler.py
@@ -14,7 +14,7 @@
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
-        "\033[0m"
+        "\033[0m",
     )
     raise SystemExit(1)

diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py
index 99520ae8..1e44b307 100644
--- a/src/betterproto2_compiler/plugin/models.py
+++ b/src/betterproto2_compiler/plugin/models.py
@@ -31,18 +31,12 @@
 import builtins
 import re
+from collections.abc import Iterable, Iterator
 from dataclasses import (
     dataclass,
     field,
 )
 from typing import (
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Type,
     Union,
 )

@@ -146,7 +140,7 @@

 def get_comment(
     proto_file: "FileDescriptorProto",
-    path: List[int],
+    path: list[int],
 ) -> str:
     for sci_loc in proto_file.source_code_info.location:
         if list(sci_loc.path) == path:
@@ -182,10 +176,10 @@ class ProtoContentBase:

     source_file: FileDescriptorProto
     typing_compiler: TypingCompiler
-    path: List[int]
+    path: list[int]
     parent: Union["betterproto2.Message", "OutputTemplate"]

-    __dataclass_fields__: Dict[str, object]
+    __dataclass_fields__: dict[str, object]

     def __post_init__(self) -> None:
         """Checks that no fake default fields were left as placeholders."""
@@ -225,10 +219,10 @@ def deprecated(self) -> bool:

 @dataclass
 class PluginRequestCompiler:
     plugin_request_obj: CodeGeneratorRequest
-    output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
+    output_packages: dict[str, "OutputTemplate"] = field(default_factory=dict)

     @property
-    def all_messages(self) -> List["MessageCompiler"]:
+    def all_messages(self) -> list["MessageCompiler"]:
         """All of the messages in this request.

         Returns
@@ -250,11 +244,11 @@ class OutputTemplate:

     parent_request: PluginRequestCompiler
     package_proto_obj: FileDescriptorProto
-    input_files: List[str] = field(default_factory=list)
-    imports_end: Set[str] = field(default_factory=set)
-    messages: Dict[str, "MessageCompiler"] = field(default_factory=dict)
-    enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
-    services: Dict[str, "ServiceCompiler"] = field(default_factory=dict)
+    input_files: list[str] = field(default_factory=list)
+    imports_end: set[str] = field(default_factory=set)
+    messages: dict[str, "MessageCompiler"] = field(default_factory=dict)
+    enums: dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
+    services: dict[str, "ServiceCompiler"] = field(default_factory=dict)
     pydantic_dataclasses: bool = False
     output: bool = True
     typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
@@ -290,9 +284,9 @@ class MessageCompiler(ProtoContentBase):
     typing_compiler: TypingCompiler
     parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
     proto_obj: DescriptorProto = PLACEHOLDER
-    path: List[int] = PLACEHOLDER
-    fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list)
-    builtins_types: Set[str] = field(default_factory=set)
+    path: list[int] = PLACEHOLDER
+    fields: list[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list)
+    builtins_types: set[str] = field(default_factory=set)

     def __post_init__(self) -> None:
         # Add message to output file
@@ -328,11 +322,9 @@ def has_oneof_fields(self) -> bool:

     @property
     def has_message_field(self) -> bool:
         return any(
-            (
-                field.proto_obj.type in PROTO_MESSAGE_TYPES
-                for field in self.fields
-                if isinstance(field.proto_obj, FieldDescriptorProto)
-            )
+            field.proto_obj.type in PROTO_MESSAGE_TYPES
+            for field in self.fields
+            if isinstance(field.proto_obj, FieldDescriptorProto)
         )

@@ -374,8 +366,8 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
 class FieldCompiler(ProtoContentBase):
     source_file: FileDescriptorProto
     typing_compiler: TypingCompiler
-    path: List[int] = PLACEHOLDER
-    builtins_types: Set[str] = field(default_factory=set)
+    path: list[int] = PLACEHOLDER
+    builtins_types: set[str] = field(default_factory=set)
     parent: MessageCompiler = PLACEHOLDER
     proto_obj: FieldDescriptorProto = PLACEHOLDER

@@ -396,7 +388,7 @@ def get_field_string(self) -> str:
         return f'{name}: "{self.annotation}" = {betterproto_field_type}'

     @property
-    def betterproto_field_args(self) -> List[str]:
+    def betterproto_field_args(self) -> list[str]:
         args = []
         if self.field_wraps:
             args.append(f"wraps={self.field_wraps}")
@@ -416,7 +408,7 @@ def use_builtins(self) -> bool:
         )

     @property
-    def field_wraps(self) -> Optional[str]:
+    def field_wraps(self) -> str | None:
         """Returns betterproto wrapped field type or None."""
         match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name)
         if match_wrapper:
@@ -428,7 +420,8 @@ def field_wraps(self) -> Optional[str]:
     @property
     def repeated(self) -> bool:
         return self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED and not is_map(
-            self.proto_obj, self.parent
+            self.proto_obj,
+            self.parent,
         )

     @property
@@ -500,7 +493,7 @@ def optional(self) -> bool:
         return True

     @property
-    def betterproto_field_args(self) -> List[str]:
+    def betterproto_field_args(self) -> list[str]:
         args = super().betterproto_field_args
         group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
         args.append(f'group="{group}"')
@@ -509,8 +502,8 @@ def betterproto_field_args(self) -> List[str]:

 @dataclass
 class MapEntryCompiler(FieldCompiler):
-    py_k_type: Optional[Type] = None
-    py_v_type: Optional[Type] = None
+    py_k_type: type | None = None
+    py_v_type: type | None = None
     proto_k_type: str = ""
     proto_v_type: str = ""

@@ -548,7 +541,7 @@ def ready(self) -> None:
         raise ValueError("can't find enum")

     @property
-    def betterproto_field_args(self) -> List[str]:
+    def betterproto_field_args(self) -> list[str]:
         return [f"betterproto2.{self.proto_k_type}", f"betterproto2.{self.proto_v_type}"]

     @property
@@ -569,7 +562,7 @@ class EnumDefinitionCompiler(MessageCompiler):
     """Representation of a proto Enum definition."""

     proto_obj: EnumDescriptorProto = PLACEHOLDER
-    entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
+    entries: list["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER

     @dataclass(unsafe_hash=True)
     class EnumEntry:
@@ -597,8 +590,8 @@ class ServiceCompiler(ProtoContentBase):
     source_file: FileDescriptorProto
     parent: OutputTemplate = PLACEHOLDER
     proto_obj: DescriptorProto = PLACEHOLDER
-    path: List[int] = PLACEHOLDER
-    methods: List["ServiceMethodCompiler"] = field(default_factory=list)
+    path: list[int] = PLACEHOLDER
+    methods: list["ServiceMethodCompiler"] = field(default_factory=list)

     def __post_init__(self) -> None:
         # Add service to output file
@@ -619,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase):
     source_file: FileDescriptorProto
     parent: ServiceCompiler
     proto_obj: MethodDescriptorProto
-    path: List[int] = PLACEHOLDER
+    path: list[int] = PLACEHOLDER

     def __post_init__(self) -> None:
         # Add method to service
diff --git a/src/betterproto2_compiler/plugin/module_validation.py b/src/betterproto2_compiler/plugin/module_validation.py
index 19a06a96..16432995 100644
--- a/src/betterproto2_compiler/plugin/module_validation.py
+++ b/src/betterproto2_compiler/plugin/module_validation.py
@@ -1,15 +1,10 @@
 import re
 from collections import defaultdict
+from collections.abc import Iterator
 from dataclasses import (
     dataclass,
     field,
 )
-from typing import (
-    Dict,
-    Iterator,
-    List,
-    Tuple,
-)


 @dataclass
@@ -17,7 +12,7 @@ class ModuleValidator:
     line_iterator: Iterator[str]
     line_number: int = field(init=False, default=0)

-    collisions: Dict[str, List[Tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list))
+    collisions: dict[str, list[tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list))

     def add_import(self, imp: str, number: int, full_line: str):
         """
diff --git a/src/betterproto2_compiler/plugin/parser.py b/src/betterproto2_compiler/plugin/parser.py
index 38456b78..a2d9cbb3 100644
--- a/src/betterproto2_compiler/plugin/parser.py
+++ b/src/betterproto2_compiler/plugin/parser.py
@@ -1,12 +1,6 @@
 import pathlib
 import sys
-from typing import (
-    Generator,
-    List,
-    Set,
-    Tuple,
-    Union,
-)
+from collections.abc import Generator

 from betterproto2_compiler.lib.google.protobuf import (
     DescriptorProto,
@@ -45,13 +39,13 @@

 def traverse(
     proto_file: FileDescriptorProto,
-) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]:
+) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
     # Todo: Keep information about nested hierarchy
     def _traverse(
-        path: List[int],
-        items: Union[List[EnumDescriptorProto], List[DescriptorProto]],
+        path: list[int],
+        items: list[EnumDescriptorProto] | list[DescriptorProto],
         prefix: str = "",
-    ) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]:
+    ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
         for i, item in enumerate(items):
             # Adjust the name since we flatten the hierarchy.
             # Todo: don't change the name, but include full name in returned tuple
@@ -82,7 +76,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
         if output_package_name not in request_data.output_packages:
             # Create a new output if there is no output for this package
             request_data.output_packages[output_package_name] = OutputTemplate(
-                parent_request=request_data, package_proto_obj=proto_file
+                parent_request=request_data,
+                package_proto_obj=proto_file,
             )
         # Add this input file to the output corresponding to this package
         request_data.output_packages[output_package_name].input_files.append(proto_file)
@@ -144,7 +139,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
             service.ready()

     # Generate output files
-    output_paths: Set[pathlib.Path] = set()
+    output_paths: set[pathlib.Path] = set()
     for output_package_name, output_package in request_data.output_packages.items():
         if not output_package.output:
             continue
@@ -158,7 +153,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
                 name=str(output_path),
                 # Render and then format the output file
                 content=outputfile_compiler(output_file=output_package),
-            )
+            ),
         )

     # Make each output directory a package with __init__ file
@@ -183,7 +178,7 @@ def _make_one_of_field_compiler(
     source_file: "FileDescriptorProto",
     parent: MessageCompiler,
     proto_obj: "FieldDescriptorProto",
-    path: List[int],
+    path: list[int],
 ) -> FieldCompiler:
     return OneOfFieldCompiler(
         source_file=source_file,
@@ -196,7 +191,7 @@ def _make_one_of_field_compiler(

 def read_protobuf_type(
     item: DescriptorProto,
-    path: List[int],
+    path: list[int],
     source_file: "FileDescriptorProto",
     output_package: OutputTemplate,
 ) -> None:
diff --git a/src/betterproto2_compiler/plugin/typing_compiler.py b/src/betterproto2_compiler/plugin/typing_compiler.py
index aa4f2135..fd1e120c 100644
--- a/src/betterproto2_compiler/plugin/typing_compiler.py
+++ b/src/betterproto2_compiler/plugin/typing_compiler.py
@@ -1,15 +1,11 @@
 import abc
+import builtins
 from collections import defaultdict
+from collections.abc import Iterator
 from dataclasses import (
     dataclass,
     field,
 )
-from typing import (
-    Dict,
-    Iterator,
-    Optional,
-    Set,
-)


 class TypingCompiler(metaclass=abc.ABCMeta):
@@ -42,7 +38,7 @@ def async_iterator(self, type_: str) -> str:
         raise NotImplementedError

     @abc.abstractmethod
-    def imports(self) -> Dict[str, Optional[Set[str]]]:
+    def imports(self) -> builtins.dict[str, set[str] | None]:
         """
         Returns either the direct import as a key with none as value, or a set of
         values to import from the key.
@@ -63,7 +59,7 @@ def import_lines(self) -> Iterator:

 @dataclass
 class DirectImportTypingCompiler(TypingCompiler):
-    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
+    _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))

     def optional(self, type_: str) -> str:
         self._imports["typing"].add("Optional")
@@ -93,7 +89,7 @@ def async_iterator(self, type_: str) -> str:
         self._imports["typing"].add("AsyncIterator")
         return f"AsyncIterator[{type_}]"

-    def imports(self) -> Dict[str, Optional[Set[str]]]:
+    def imports(self) -> builtins.dict[str, set[str] | None]:
         return {k: v if v else None for k, v in self._imports.items()}

@@ -129,7 +125,7 @@ def async_iterator(self, type_: str) -> str:
         self._imported = True
         return f"typing.AsyncIterator[{type_}]"

-    def imports(self) -> Dict[str, Optional[Set[str]]]:
+    def imports(self) -> builtins.dict[str, set[str] | None]:
         if self._imported:
             return {"typing": None}
         return {}
@@ -137,7 +133,7 @@ def imports(self) -> Dict[str, Optional[Set[str]]]:

 @dataclass
 class NoTyping310TypingCompiler(TypingCompiler):
-    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
+    _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))

     def optional(self, type_: str) -> str:
         return f"{type_} | None"
@@ -163,5 +159,5 @@ def async_iterator(self, type_: str) -> str:
         self._imports["collections.abc"].add("AsyncIterator")
         return f"AsyncIterator[{type_}]"

-    def imports(self) -> Dict[str, Optional[Set[str]]]:
+    def imports(self) -> builtins.dict[str, set[str] | None]:
         return {k: v if v else None for k, v in self._imports.items()}
diff --git a/tests/generate.py b/tests/generate.py
index 67dad859..f2a2db0b 100644
--- a/tests/generate.py
+++ b/tests/generate.py
@@ -4,7 +4,6 @@
 import shutil
 import sys
 from pathlib import Path
-from typing import Set

 from tests.util import (
     get_directories,
@@ -28,7 +27,7 @@ def clear_directory(dir_path: Path):
             file_or_directory.unlink()


-async def generate(whitelist: Set[str], verbose: bool):
+async def generate(whitelist: set[str], verbose: bool):
     test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}

     path_whitelist = set()
@@ -150,7 +149,7 @@ async def generate_test_case_output(test_case_input_path: Path, test_case_name:
             "",
             "NAMES One or more test-case names to generate classes for.",
             " python generate.py bool double enums",
-        )
+        ),
     )
diff --git a/tests/test_module_validation.py b/tests/test_module_validation.py
index 390b8c75..6995a590 100644
--- a/tests/test_module_validation.py
+++ b/tests/test_module_validation.py
@@ -1,9 +1,3 @@
-from typing import (
-    List,
-    Optional,
-    Set,
-)
-
 import pytest

 from betterproto2.plugin.module_validation import ModuleValidator
@@ -99,7 +93,7 @@
         ),
     ],
 )
-def test_module_validator(text: List[str], expected_collisions: Optional[Set[str]]):
+def test_module_validator(text: list[str], expected_collisions: set[str] | None):
     line_iterator = iter(text)
     validator = ModuleValidator(line_iterator)
     valid = validator.validate()
diff --git a/tests/test_typing_compiler.py b/tests/test_typing_compiler.py
index e9157f40..1f8a2c66 100644
--- a/tests/test_typing_compiler.py
+++ b/tests/test_typing_compiler.py
@@ -30,7 +30,7 @@ def test_direct_import_typing_compiler():
             "Iterable",
             "AsyncIterable",
             "AsyncIterator",
-        }
+        },
     }
diff --git a/tests/util.py b/tests/util.py
index 075a434b..c7e1bf0f 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -5,18 +5,10 @@
 import platform
 import sys
 import tempfile
+from collections.abc import Callable, Generator
 from dataclasses import dataclass
 from pathlib import Path
 from types import ModuleType
-from typing import (
-    Callable,
-    Dict,
-    Generator,
-    List,
-    Optional,
-    Tuple,
-    Union,
-)

 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@@ -39,8 +31,8 @@ def get_directories(path):


 async def protoc(
-    path: Union[str, Path],
-    output_dir: Union[str, Path],
+    path: str | Path,
+    output_dir: str | Path,
     reference: bool = False,
     pydantic_dataclasses: bool = False,
 ):
@@ -59,7 +51,7 @@ async def protoc(
                 "@echo off",
                 f"\nchdir {os.getcwd()}",
                 f"\n{sys.executable} -u {plugin_path.as_posix()}",
-            ]
+            ],
         )
         tf.flush()
@@ -88,7 +80,9 @@ async def protoc(
         *[p.as_posix() for p in path.glob("*.proto")],
     ]
     proc = await asyncio.create_subprocess_exec(
-        *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
+        *command,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()
    return stdout, stderr, proc.returncode
@@ -100,11 +94,11 @@ class TestCaseJsonFile:
     test_name: str
     file_name: str

-    def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]) -> bool:
+    def belongs_to(self, non_symmetrical_json: dict[str, tuple[str, ...]]) -> bool:
         return self.file_name in non_symmetrical_json.get(self.test_name, ())


-def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[TestCaseJsonFile]:
+def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> list[TestCaseJsonFile]:
     """
     :return: A list of all files found in "{inputs_path}/test_case_name" with names matching
@@ -128,7 +122,7 @@ def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[
     return result


-def find_module(module: ModuleType, predicate: Callable[[ModuleType], bool]) -> Optional[ModuleType]:
+def find_module(module: ModuleType, predicate: Callable[[ModuleType], bool]) -> ModuleType | None:
     """
     Recursively search module tree for a module that matches the search predicate.
     Assumes that the submodules are directories containing __init__.py.
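The `find_module` docstring above describes a recursive, predicate-driven search over a package tree. For context, a minimal sketch of that strategy, assuming only what the docstring states (submodules are directories containing `__init__.py`); this is an illustration, not the repository's implementation:

```python
import importlib
from collections.abc import Callable
from pathlib import Path
from types import ModuleType


def find_module_sketch(module: ModuleType, predicate: Callable[[ModuleType], bool]) -> ModuleType | None:
    # Check the current module first, then recurse into subpackages.
    if predicate(module):
        return module
    if not hasattr(module, "__path__"):  # plain modules have no submodules
        return None
    for init_py in Path(module.__path__[0]).glob("*/__init__.py"):
        child = importlib.import_module(f"{module.__name__}.{init_py.parent.name}")
        if (found := find_module_sketch(child, predicate)) is not None:
            return found
    return None
```

Depth-first order means the first match in the tree wins, which is enough for the tests' use case of locating one generated module by name.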