From 482b3cf0840919bc526d50dc4115161f93b93738 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Sun, 22 Dec 2024 15:07:55 +0100 Subject: [PATCH 1/5] Use Python 3.10 unions --- pyproject.toml | 2 ++ src/betterproto2_compiler/enum.py | 5 ++- .../grpc/grpclib_client.py | 31 +++++++++---------- .../grpc/util/async_channel.py | 6 ++-- .../lib/pydantic/google/protobuf/__init__.py | 6 ++-- .../lib/std/google/protobuf/__init__.py | 3 +- src/betterproto2_compiler/plugin/models.py | 7 ++--- src/betterproto2_compiler/plugin/parser.py | 7 ++--- .../plugin/typing_compiler.py | 9 +++--- tests/test_module_validation.py | 3 +- tests/util.py | 8 ++--- 11 files changed, 39 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 78d67c34..dccbd4a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,8 @@ select = [ "SIM102", # Simplify return or yield statements "SIM103", # Simplify list/set/dict comprehensions + "UP007", # Use Python 3.10 unions + "I", ] diff --git a/src/betterproto2_compiler/enum.py b/src/betterproto2_compiler/enum.py index 9a1d677b..b92cd3ab 100644 --- a/src/betterproto2_compiler/enum.py +++ b/src/betterproto2_compiler/enum.py @@ -9,7 +9,6 @@ TYPE_CHECKING, Any, Dict, - Optional, Tuple, ) @@ -111,12 +110,12 @@ class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType): inherit from this. Emulates `enum.IntEnum`. """ - name: Optional[str] + name: str | None value: int if not TYPE_CHECKING: - def __new__(cls, *, name: Optional[str], value: int) -> Self: + def __new__(cls, *, name: str | None, value: int) -> Self: self = super().__new__(cls, value) super().__setattr__(self, "name", name) super().__setattr__(self, "value", value) diff --git a/src/betterproto2_compiler/grpc/grpclib_client.py b/src/betterproto2_compiler/grpc/grpclib_client.py index ab24cedb..ff55cf36 100644 --- a/src/betterproto2_compiler/grpc/grpclib_client.py +++ b/src/betterproto2_compiler/grpc/grpclib_client.py @@ -10,7 +10,6 @@ Optional, Tuple, Type, - Union, ) import grpclib.const @@ -25,9 +24,9 @@ ) -Value = Union[str, bytes] -MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]] -MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] +Value = str | bytes +MetadataLike = Mapping[str, Value] | Collection[Tuple[str, Value]] +MessageSource = Iterable["IProtoMessage"] | AsyncIterable["IProtoMessage"] class ServiceStub(ABC): @@ -39,9 +38,9 @@ def __init__( self, channel: "Channel", *, - timeout: Optional[float] = None, + timeout: float | None = None, deadline: Optional["Deadline"] = None, - metadata: Optional[MetadataLike] = None, + metadata: MetadataLike | None = None, ) -> None: self.channel = channel self.timeout = timeout @@ -50,9 +49,9 @@ def __init__( def __resolve_request_kwargs( self, - timeout: Optional[float], + timeout: float | None, deadline: Optional["Deadline"], - metadata: Optional[MetadataLike], + metadata: MetadataLike | None, ): return { "timeout": self.timeout if timeout is None else timeout, @@ -66,9 +65,9 @@ async def _unary_unary( request: "IProtoMessage", response_type: Type["T"], *, - timeout: Optional[float] = None, + timeout: float | None = None, deadline: Optional["Deadline"] = None, - metadata: Optional[MetadataLike] = None, + metadata: MetadataLike | None = None, ) -> "T": """Make a unary request and return the response.""" async with self.channel.request( @@ -89,9 +88,9 @@ async def _unary_stream( request: "IProtoMessage", response_type: Type["T"], *, - timeout: Optional[float] = None, + timeout: float | None = None, deadline: Optional["Deadline"] = None, - metadata: Optional[MetadataLike] = None, + metadata: MetadataLike | None = None, ) -> AsyncIterator["T"]: """Make a unary request and return the stream response iterator.""" async with self.channel.request( @@ -112,9 +111,9 @@ async def _stream_unary( request_type: Type["IProtoMessage"], response_type: Type["T"], *, - timeout: Optional[float] = None, + timeout: float | None = None, deadline: Optional["Deadline"] = None, - metadata: Optional[MetadataLike] = None, + metadata: MetadataLike | None = None, ) -> "T": """Make a stream request and return the response.""" async with self.channel.request( @@ -137,9 +136,9 @@ async def _stream_stream( request_type: Type["IProtoMessage"], response_type: Type["T"], *, - timeout: Optional[float] = None, + timeout: float | None = None, deadline: Optional["Deadline"] = None, - metadata: Optional[MetadataLike] = None, + metadata: MetadataLike | None = None, ) -> AsyncIterator["T"]: """ Make a stream request and return an AsyncIterator to iterate over response diff --git a/src/betterproto2_compiler/grpc/util/async_channel.py b/src/betterproto2_compiler/grpc/util/async_channel.py index 9a9345e0..8ec62942 100644 --- a/src/betterproto2_compiler/grpc/util/async_channel.py +++ b/src/betterproto2_compiler/grpc/util/async_channel.py @@ -3,9 +3,7 @@ AsyncIterable, AsyncIterator, Iterable, - Optional, TypeVar, - Union, ) T = TypeVar("T") @@ -117,7 +115,7 @@ def done(self) -> bool: # receiver per enqueued item. return self._closed and self._queue.qsize() <= self._waiting_receivers - async def send_from(self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False) -> "AsyncChannel[T]": + async def send_from(self, source: Iterable[T] | AsyncIterable[T], close: bool = False) -> "AsyncChannel[T]": """ Iterates the given [Async]Iterable and sends all the resulting items. If close is set to True then subsequent send calls will be rejected with a @@ -150,7 +148,7 @@ async def send(self, item: T) -> "AsyncChannel[T]": await self._queue.put(item) return self - async def receive(self) -> Optional[T]: + async def receive(self) -> T | None: """ Returns the next item from this channel when it becomes available, or None if the channel is closed before another item is sent. diff --git a/src/betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py b/src/betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py index 109a288e..9df1ea64 100644 --- a/src/betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py +++ b/src/betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py @@ -2401,13 +2401,13 @@ class Value(betterproto2_compiler.Message): ) """Represents a null value.""" - number_value: Optional[float] = betterproto2_compiler.double_field(2, optional=True, group="kind") + number_value: float | None = betterproto2_compiler.double_field(2, optional=True, group="kind") """Represents a double value.""" - string_value: Optional[str] = betterproto2_compiler.string_field(3, optional=True, group="kind") + string_value: str | None = betterproto2_compiler.string_field(3, optional=True, group="kind") """Represents a string value.""" - bool_value: Optional[bool] = betterproto2_compiler.bool_field(4, optional=True, group="kind") + bool_value: bool | None = betterproto2_compiler.bool_field(4, optional=True, group="kind") """Represents a boolean value.""" struct_value: Optional["Struct"] = betterproto2_compiler.message_field(5, optional=True, group="kind") diff --git a/src/betterproto2_compiler/lib/std/google/protobuf/__init__.py b/src/betterproto2_compiler/lib/std/google/protobuf/__init__.py index dc118b7f..2c61f62d 100644 --- a/src/betterproto2_compiler/lib/std/google/protobuf/__init__.py +++ b/src/betterproto2_compiler/lib/std/google/protobuf/__init__.py @@ -78,7 +78,6 @@ Dict, List, Mapping, - Optional, ) import betterproto2 @@ -1022,7 +1021,7 @@ class FieldDescriptorProto(betterproto2.Message): TODO(kenton): Base-64 encode? """ - oneof_index: Optional[int] = betterproto2.int32_field(9, optional=True) + oneof_index: int | None = betterproto2.int32_field(9, optional=True) """ If set, gives the index of a oneof in the containing type's oneof_decl list. This field is a member of that oneof. diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index 99520ae8..11fad592 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -40,7 +40,6 @@ Iterable, Iterator, List, - Optional, Set, Type, Union, @@ -416,7 +415,7 @@ def use_builtins(self) -> bool: ) @property - def field_wraps(self) -> Optional[str]: + def field_wraps(self) -> str | None: """Returns betterproto wrapped field type or None.""" match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name) if match_wrapper: @@ -509,8 +508,8 @@ def betterproto_field_args(self) -> List[str]: @dataclass class MapEntryCompiler(FieldCompiler): - py_k_type: Optional[Type] = None - py_v_type: Optional[Type] = None + py_k_type: Type | None = None + py_v_type: Type | None = None proto_k_type: str = "" proto_v_type: str = "" diff --git a/src/betterproto2_compiler/plugin/parser.py b/src/betterproto2_compiler/plugin/parser.py index 38456b78..ca8ac9e9 100644 --- a/src/betterproto2_compiler/plugin/parser.py +++ b/src/betterproto2_compiler/plugin/parser.py @@ -5,7 +5,6 @@ List, Set, Tuple, - Union, ) from betterproto2_compiler.lib.google.protobuf import ( @@ -45,13 +44,13 @@ def traverse( proto_file: FileDescriptorProto, -) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]: +) -> Generator[Tuple[EnumDescriptorProto | DescriptorProto, List[int]], None, None]: # Todo: Keep information about nested hierarchy def _traverse( path: List[int], - items: Union[List[EnumDescriptorProto], List[DescriptorProto]], + items: List[EnumDescriptorProto] | List[DescriptorProto], prefix: str = "", - ) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]: + ) -> Generator[Tuple[EnumDescriptorProto | DescriptorProto, List[int]], None, None]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple diff --git a/src/betterproto2_compiler/plugin/typing_compiler.py b/src/betterproto2_compiler/plugin/typing_compiler.py index aa4f2135..5acc60e2 100644 --- a/src/betterproto2_compiler/plugin/typing_compiler.py +++ b/src/betterproto2_compiler/plugin/typing_compiler.py @@ -7,7 +7,6 @@ from typing import ( Dict, Iterator, - Optional, Set, ) @@ -42,7 +41,7 @@ def async_iterator(self, type_: str) -> str: raise NotImplementedError @abc.abstractmethod - def imports(self) -> Dict[str, Optional[Set[str]]]: + def imports(self) -> Dict[str, Set[str] | None]: """ Returns either the direct import as a key with none as value, or a set of values to import from the key. @@ -93,7 +92,7 @@ def async_iterator(self, type_: str) -> str: self._imports["typing"].add("AsyncIterator") return f"AsyncIterator[{type_}]" - def imports(self) -> Dict[str, Optional[Set[str]]]: + def imports(self) -> Dict[str, Set[str] | None]: return {k: v if v else None for k, v in self._imports.items()} @@ -129,7 +128,7 @@ def async_iterator(self, type_: str) -> str: self._imported = True return f"typing.AsyncIterator[{type_}]" - def imports(self) -> Dict[str, Optional[Set[str]]]: + def imports(self) -> Dict[str, Set[str] | None]: if self._imported: return {"typing": None} return {} @@ -163,5 +162,5 @@ def async_iterator(self, type_: str) -> str: self._imports["collections.abc"].add("AsyncIterator") return f"AsyncIterator[{type_}]" - def imports(self) -> Dict[str, Optional[Set[str]]]: + def imports(self) -> Dict[str, Set[str] | None]: return {k: v if v else None for k, v in self._imports.items()} diff --git a/tests/test_module_validation.py b/tests/test_module_validation.py index 390b8c75..b1b996d2 100644 --- a/tests/test_module_validation.py +++ b/tests/test_module_validation.py @@ -1,6 +1,5 @@ from typing import ( List, - Optional, Set, ) @@ -99,7 +98,7 @@ ), ], ) -def test_module_validator(text: List[str], expected_collisions: Optional[Set[str]]): +def test_module_validator(text: List[str], expected_collisions: Set[str] | None): line_iterator = iter(text) validator = ModuleValidator(line_iterator) valid = validator.validate() diff --git a/tests/util.py b/tests/util.py index 075a434b..c2d7ff3f 100644 --- a/tests/util.py +++ b/tests/util.py @@ -13,9 +13,7 @@ Dict, Generator, List, - Optional, Tuple, - Union, ) os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -39,8 +37,8 @@ def get_directories(path): async def protoc( - path: Union[str, Path], - output_dir: Union[str, Path], + path: str | Path, + output_dir: str | Path, reference: bool = False, pydantic_dataclasses: bool = False, ): @@ -128,7 +126,7 @@ def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[ return result -def find_module(module: ModuleType, predicate: Callable[[ModuleType], bool]) -> Optional[ModuleType]: +def find_module(module: ModuleType, predicate: Callable[[ModuleType], bool]) -> ModuleType | None: """ Recursively search module tree for a module that matches the search predicate. Assumes that the submodules are directories containing __init__.py. From e2e4038fe6788766c52e3668bb81603c53043ed7 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Sun, 22 Dec 2024 15:21:00 +0100 Subject: [PATCH 2/5] Remove useless files for the compiler --- src/betterproto2_compiler/grpc/__init__.py | 0 .../grpc/grpclib_client.py | 171 ---------------- .../grpc/grpclib_server.py | 32 --- .../grpc/util/__init__.py | 0 .../grpc/util/async_channel.py | 188 ------------------ 5 files changed, 391 deletions(-) delete mode 100644 src/betterproto2_compiler/grpc/__init__.py delete mode 100644 src/betterproto2_compiler/grpc/grpclib_client.py delete mode 100644 src/betterproto2_compiler/grpc/grpclib_server.py delete mode 100644 src/betterproto2_compiler/grpc/util/__init__.py delete mode 100644 src/betterproto2_compiler/grpc/util/async_channel.py diff --git a/src/betterproto2_compiler/grpc/__init__.py b/src/betterproto2_compiler/grpc/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/betterproto2_compiler/grpc/grpclib_client.py b/src/betterproto2_compiler/grpc/grpclib_client.py deleted file mode 100644 index ff55cf36..00000000 --- a/src/betterproto2_compiler/grpc/grpclib_client.py +++ /dev/null @@ -1,171 +0,0 @@ -import asyncio -from abc import ABC -from typing import ( - TYPE_CHECKING, - AsyncIterable, - AsyncIterator, - Collection, - Iterable, - Mapping, - Optional, - Tuple, - Type, -) - -import grpclib.const - -if TYPE_CHECKING: - from grpclib.client import Channel - from grpclib.metadata import Deadline - - from .._types import ( - IProtoMessage, - T, - ) - - -Value = str | bytes -MetadataLike = Mapping[str, Value] | Collection[Tuple[str, Value]] -MessageSource = Iterable["IProtoMessage"] | AsyncIterable["IProtoMessage"] - - -class ServiceStub(ABC): - """ - Base class for async gRPC clients. - """ - - def __init__( - self, - channel: "Channel", - *, - timeout: float | None = None, - deadline: Optional["Deadline"] = None, - metadata: MetadataLike | None = None, - ) -> None: - self.channel = channel - self.timeout = timeout - self.deadline = deadline - self.metadata = metadata - - def __resolve_request_kwargs( - self, - timeout: float | None, - deadline: Optional["Deadline"], - metadata: MetadataLike | None, - ): - return { - "timeout": self.timeout if timeout is None else timeout, - "deadline": self.deadline if deadline is None else deadline, - "metadata": self.metadata if metadata is None else metadata, - } - - async def _unary_unary( - self, - route: str, - request: "IProtoMessage", - response_type: Type["T"], - *, - timeout: float | None = None, - deadline: Optional["Deadline"] = None, - metadata: MetadataLike | None = None, - ) -> "T": - """Make a unary request and return the response.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.UNARY_UNARY, - type(request), - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_message(request, end=True) - response = await stream.recv_message() - assert response is not None - return response - - async def _unary_stream( - self, - route: str, - request: "IProtoMessage", - response_type: Type["T"], - *, - timeout: float | None = None, - deadline: Optional["Deadline"] = None, - metadata: MetadataLike | None = None, - ) -> AsyncIterator["T"]: - """Make a unary request and return the stream response iterator.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.UNARY_STREAM, - type(request), - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_message(request, end=True) - async for message in stream: - yield message - - async def _stream_unary( - self, - route: str, - request_iterator: MessageSource, - request_type: Type["IProtoMessage"], - response_type: Type["T"], - *, - timeout: float | None = None, - deadline: Optional["Deadline"] = None, - metadata: MetadataLike | None = None, - ) -> "T": - """Make a stream request and return the response.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.STREAM_UNARY, - request_type, - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_request() - await self._send_messages(stream, request_iterator) - response = await stream.recv_message() - assert response is not None - return response - - async def _stream_stream( - self, - route: str, - request_iterator: MessageSource, - request_type: Type["IProtoMessage"], - response_type: Type["T"], - *, - timeout: float | None = None, - deadline: Optional["Deadline"] = None, - metadata: MetadataLike | None = None, - ) -> AsyncIterator["T"]: - """ - Make a stream request and return an AsyncIterator to iterate over response - messages. - """ - async with self.channel.request( - route, - grpclib.const.Cardinality.STREAM_STREAM, - request_type, - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_request() - sending_task = asyncio.ensure_future(self._send_messages(stream, request_iterator)) - try: - async for response in stream: - yield response - except: - sending_task.cancel() - raise - - @staticmethod - async def _send_messages(stream, messages: MessageSource): - if isinstance(messages, AsyncIterable): - async for message in messages: - await stream.send_message(message) - else: - for message in messages: - await stream.send_message(message) - await stream.end() diff --git a/src/betterproto2_compiler/grpc/grpclib_server.py b/src/betterproto2_compiler/grpc/grpclib_server.py deleted file mode 100644 index 61d4710b..00000000 --- a/src/betterproto2_compiler/grpc/grpclib_server.py +++ /dev/null @@ -1,32 +0,0 @@ -from abc import ABC -from collections.abc import AsyncIterable -from typing import ( - Any, - Callable, -) - -import grpclib -import grpclib.server - - -class ServiceBase(ABC): - """ - Base class for async gRPC servers. - """ - - async def _call_rpc_handler_server_stream( - self, - handler: Callable, - stream: grpclib.server.Stream, - request: Any, - ) -> None: - response_iter = handler(request) - # check if response is actually an AsyncIterator - # this might be false if the method just returns without - # yielding at least once - # in that case, we just interpret it as an empty iterator - if isinstance(response_iter, AsyncIterable): - async for response_message in response_iter: - await stream.send_message(response_message) - else: - response_iter.close() diff --git a/src/betterproto2_compiler/grpc/util/__init__.py b/src/betterproto2_compiler/grpc/util/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/betterproto2_compiler/grpc/util/async_channel.py b/src/betterproto2_compiler/grpc/util/async_channel.py deleted file mode 100644 index 8ec62942..00000000 --- a/src/betterproto2_compiler/grpc/util/async_channel.py +++ /dev/null @@ -1,188 +0,0 @@ -import asyncio -from typing import ( - AsyncIterable, - AsyncIterator, - Iterable, - TypeVar, -) - -T = TypeVar("T") - - -class ChannelClosed(Exception): - """ - An exception raised on an attempt to send through a closed channel - """ - - -class ChannelDone(Exception): - """ - An exception raised on an attempt to send receive from a channel that is both closed - and empty. - """ - - -class AsyncChannel(AsyncIterable[T]): - """ - A buffered async channel for sending items between coroutines with FIFO ordering. - - This makes decoupled bidirectional steaming gRPC requests easy if used like: - - .. code-block:: python - client = GeneratedStub(grpclib_chan) - request_channel = await AsyncChannel() - # We can start be sending all the requests we already have - await request_channel.send_from([RequestObject(...), RequestObject(...)]) - async for response in client.rpc_call(request_channel): - # The response iterator will remain active until the connection is closed - ... - # More items can be sent at any time - await request_channel.send(RequestObject(...)) - ... - # The channel must be closed to complete the gRPC connection - request_channel.close() - - Items can be sent through the channel by either: - - providing an iterable to the send_from method - - passing them to the send method one at a time - - Items can be received from the channel by either: - - iterating over the channel with a for loop to get all items - - calling the receive method to get one item at a time - - If the channel is empty then receivers will wait until either an item appears or the - channel is closed. - - Once the channel is closed then subsequent attempt to send through the channel will - fail with a ChannelClosed exception. - - When th channel is closed and empty then it is done, and further attempts to receive - from it will fail with a ChannelDone exception - - If multiple coroutines receive from the channel concurrently, each item sent will be - received by only one of the receivers. - - :param source: - An optional iterable will items that should be sent through the channel - immediately. - :param buffer_limit: - Limit the number of items that can be buffered in the channel, A value less than - 1 implies no limit. If the channel is full then attempts to send more items will - result in the sender waiting until an item is received from the channel. - :param close: - If set to True then the channel will automatically close after exhausting source - or immediately if no source is provided. - """ - - def __init__(self, *, buffer_limit: int = 0, close: bool = False): - self._queue: asyncio.Queue[T] = asyncio.Queue(buffer_limit) - self._closed = False - self._waiting_receivers: int = 0 - # Track whether flush has been invoked so it can only happen once - self._flushed = False - - def __aiter__(self) -> AsyncIterator[T]: - return self - - async def __anext__(self) -> T: - if self.done(): - raise StopAsyncIteration - self._waiting_receivers += 1 - try: - result = await self._queue.get() - if result is self.__flush: - raise StopAsyncIteration - return result - finally: - self._waiting_receivers -= 1 - self._queue.task_done() - - def closed(self) -> bool: - """ - Returns True if this channel is closed and no-longer accepting new items - """ - return self._closed - - def done(self) -> bool: - """ - Check if this channel is done. - - :return: True if this channel is closed and and has been drained of items in - which case any further attempts to receive an item from this channel will raise - a ChannelDone exception. - """ - # After close the channel is not yet done until there is at least one waiting - # receiver per enqueued item. - return self._closed and self._queue.qsize() <= self._waiting_receivers - - async def send_from(self, source: Iterable[T] | AsyncIterable[T], close: bool = False) -> "AsyncChannel[T]": - """ - Iterates the given [Async]Iterable and sends all the resulting items. - If close is set to True then subsequent send calls will be rejected with a - ChannelClosed exception. - :param source: an iterable of items to send - :param close: - if True then the channel will be closed after the source has been exhausted - - """ - if self._closed: - raise ChannelClosed("Cannot send through a closed channel") - if isinstance(source, AsyncIterable): - async for item in source: - await self._queue.put(item) - else: - for item in source: - await self._queue.put(item) - if close: - # Complete the closing process - self.close() - return self - - async def send(self, item: T) -> "AsyncChannel[T]": - """ - Send a single item over this channel. - :param item: The item to send - """ - if self._closed: - raise ChannelClosed("Cannot send through a closed channel") - await self._queue.put(item) - return self - - async def receive(self) -> T | None: - """ - Returns the next item from this channel when it becomes available, - or None if the channel is closed before another item is sent. - :return: An item from the channel - """ - if self.done(): - raise ChannelDone("Cannot receive from a closed channel") - self._waiting_receivers += 1 - try: - result = await self._queue.get() - if result is self.__flush: - return None - return result - finally: - self._waiting_receivers -= 1 - self._queue.task_done() - - def close(self): - """ - Close this channel to new items - """ - self._closed = True - asyncio.ensure_future(self._flush_queue()) - - async def _flush_queue(self): - """ - To be called after the channel is closed. Pushes a number of self.__flush - objects to the queue to ensure no waiting consumers get deadlocked. - """ - if not self._flushed: - self._flushed = True - deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize()) - for _ in range(deadlocked_receivers): - await self._queue.put(self.__flush) - - # A special signal object for flushing the queue when the channel is closed - __flush = object() From f7d9e301b44194e01354e7e11017eb51f3fe93d0 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Sun, 22 Dec 2024 15:25:19 +0100 Subject: [PATCH 3/5] Add more rules --- pyproject.toml | 4 +- .../compile/importing.py | 23 +++---- src/betterproto2_compiler/enum.py | 4 +- src/betterproto2_compiler/plugin/models.py | 63 +++++++++---------- .../plugin/module_validation.py | 9 +-- src/betterproto2_compiler/plugin/parser.py | 21 +++---- .../plugin/typing_compiler.py | 19 +++--- tests/generate.py | 3 +- tests/test_module_validation.py | 7 +-- tests/util.py | 12 +--- 10 files changed, 63 insertions(+), 102 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dccbd4a5..fa478ba9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ protobuf = "^4" protoc-gen-python_betterproto2 = "betterproto2_compiler.plugin:main" [tool.ruff] -extend-exclude = ["tests/output_*"] +extend-exclude = ["tests/output_*", "src/betterproto2_compiler/lib"] target-version = "py310" line-length = 120 @@ -52,7 +52,7 @@ select = [ "SIM102", # Simplify return or yield statements "SIM103", # Simplify list/set/dict comprehensions - "UP007", # Use Python 3.10 unions + "UP", "I", ] diff --git a/src/betterproto2_compiler/compile/importing.py b/src/betterproto2_compiler/compile/importing.py index 814fff71..d4b3ac78 100644 --- a/src/betterproto2_compiler/compile/importing.py +++ b/src/betterproto2_compiler/compile/importing.py @@ -3,11 +3,6 @@ import os from typing import ( TYPE_CHECKING, - Dict, - List, - Set, - Tuple, - Type, ) from ..casing import safe_snake_case @@ -18,7 +13,7 @@ from ..plugin.models import PluginRequestCompiler from ..plugin.typing_compiler import TypingCompiler -WRAPPER_TYPES: Dict[str, Type] = { +WRAPPER_TYPES: dict[str, type] = { ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, ".google.protobuf.FloatValue": google_protobuf.FloatValue, ".google.protobuf.Int32Value": google_protobuf.Int32Value, @@ -31,7 +26,7 @@ } -def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler") -> Tuple[str, str]: +def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) -> tuple[str, str]: """ Split full source type name into package and type name. E.g. 'root.package.Message' -> ('root.package', 'Message') @@ -77,7 +72,7 @@ def get_type_reference( imports: set, source_type: str, typing_compiler: TypingCompiler, - request: "PluginRequestCompiler", + request: PluginRequestCompiler, unwrap: bool = True, pydantic: bool = False, ) -> str: @@ -98,8 +93,8 @@ def get_type_reference( source_package, source_type = parse_source_type_name(source_type, request) - current_package: List[str] = package.split(".") if package else [] - py_package: List[str] = source_package.split(".") if source_package else [] + current_package: list[str] = package.split(".") if package else [] + py_package: list[str] = source_package.split(".") if source_package else [] py_type: str = pythonize_class_name(source_type) compiling_google_protobuf = current_package == ["google", "protobuf"] @@ -122,7 +117,7 @@ def get_type_reference( return reference_cousin(current_package, imports, py_package, py_type) -def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str: +def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> str: """ Returns a reference to a python type located in the root, i.e. sys.path. """ @@ -139,7 +134,7 @@ def reference_sibling(py_type: str) -> str: return f"{py_type}" -def reference_descendent(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str: +def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: """ Returns a reference to a python type in a package that is a descendent of the current package, and adds the required import that is aliased to avoid name @@ -157,7 +152,7 @@ def reference_descendent(current_package: List[str], imports: Set[str], py_packa return f"{string_import}.{py_type}" -def reference_ancestor(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str: +def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: """ Returns a reference to a python type in a package which is an ancestor to the current package, and adds the required import that is aliased (if possible) to avoid @@ -178,7 +173,7 @@ def reference_ancestor(current_package: List[str], imports: Set[str], py_package return string_alias -def reference_cousin(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str: +def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: """ Returns a reference to a python type in a package that is not descendent, ancestor or sibling, and adds the required import that is aliased to avoid name conflicts. diff --git a/src/betterproto2_compiler/enum.py b/src/betterproto2_compiler/enum.py index b92cd3ab..890060cd 100644 --- a/src/betterproto2_compiler/enum.py +++ b/src/betterproto2_compiler/enum.py @@ -8,8 +8,6 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Tuple, ) if TYPE_CHECKING: @@ -32,7 +30,7 @@ class EnumType(EnumMeta if TYPE_CHECKING else type): _value_map_: Mapping[int, Enum] _member_map_: Mapping[str, Enum] - def __new__(mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> Self: + def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) -> Self: value_map = {} member_map = {} diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index 11fad592..10f6330b 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -31,17 +31,12 @@ import builtins import re +from collections.abc import Iterable, Iterator from dataclasses import ( dataclass, field, ) from typing import ( - Dict, - Iterable, - Iterator, - List, - Set, - Type, Union, ) @@ -145,7 +140,7 @@ def get_comment( proto_file: "FileDescriptorProto", - path: List[int], + path: list[int], ) -> str: for sci_loc in proto_file.source_code_info.location: if list(sci_loc.path) == path: @@ -181,10 +176,10 @@ class ProtoContentBase: source_file: FileDescriptorProto typing_compiler: TypingCompiler - path: List[int] + path: list[int] parent: Union["betterproto2.Message", "OutputTemplate"] - __dataclass_fields__: Dict[str, object] + __dataclass_fields__: dict[str, object] def __post_init__(self) -> None: """Checks that no fake default fields were left as placeholders.""" @@ -224,10 +219,10 @@ def deprecated(self) -> bool: @dataclass class PluginRequestCompiler: plugin_request_obj: CodeGeneratorRequest - output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) + output_packages: dict[str, "OutputTemplate"] = field(default_factory=dict) @property - def all_messages(self) -> List["MessageCompiler"]: + def all_messages(self) -> list["MessageCompiler"]: """All of the messages in this request. Returns @@ -249,11 +244,11 @@ class OutputTemplate: parent_request: PluginRequestCompiler package_proto_obj: FileDescriptorProto - input_files: List[str] = field(default_factory=list) - imports_end: Set[str] = field(default_factory=set) - messages: Dict[str, "MessageCompiler"] = field(default_factory=dict) - enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict) - services: Dict[str, "ServiceCompiler"] = field(default_factory=dict) + input_files: list[str] = field(default_factory=list) + imports_end: set[str] = field(default_factory=set) + messages: dict[str, "MessageCompiler"] = field(default_factory=dict) + enums: dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict) + services: dict[str, "ServiceCompiler"] = field(default_factory=dict) pydantic_dataclasses: bool = False output: bool = True typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler) @@ -289,9 +284,9 @@ class MessageCompiler(ProtoContentBase): typing_compiler: TypingCompiler parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER - path: List[int] = PLACEHOLDER - fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list) - builtins_types: Set[str] = field(default_factory=set) + path: list[int] = PLACEHOLDER + fields: list[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list) + builtins_types: set[str] = field(default_factory=set) def __post_init__(self) -> None: # Add message to output file @@ -327,11 +322,9 @@ def has_oneof_fields(self) -> bool: @property def has_message_field(self) -> bool: return any( - ( - field.proto_obj.type in PROTO_MESSAGE_TYPES - for field in self.fields - if isinstance(field.proto_obj, FieldDescriptorProto) - ) + field.proto_obj.type in PROTO_MESSAGE_TYPES + for field in self.fields + if isinstance(field.proto_obj, FieldDescriptorProto) ) @@ -373,8 +366,8 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: class FieldCompiler(ProtoContentBase): source_file: FileDescriptorProto typing_compiler: TypingCompiler - path: List[int] = PLACEHOLDER - builtins_types: Set[str] = field(default_factory=set) + path: list[int] = PLACEHOLDER + builtins_types: set[str] = field(default_factory=set) parent: MessageCompiler = PLACEHOLDER proto_obj: FieldDescriptorProto = PLACEHOLDER @@ -395,7 +388,7 @@ def get_field_string(self) -> str: return f'{name}: "{self.annotation}" = {betterproto_field_type}' @property - def betterproto_field_args(self) -> List[str]: + def betterproto_field_args(self) -> list[str]: args = [] if self.field_wraps: args.append(f"wraps={self.field_wraps}") @@ -499,7 +492,7 @@ def optional(self) -> bool: return True @property - def betterproto_field_args(self) -> List[str]: + def betterproto_field_args(self) -> list[str]: args = super().betterproto_field_args group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name args.append(f'group="{group}"') @@ -508,8 +501,8 @@ def betterproto_field_args(self) -> List[str]: @dataclass class MapEntryCompiler(FieldCompiler): - py_k_type: Type | None = None - py_v_type: Type | None = None + py_k_type: type | None = None + py_v_type: type | None = None proto_k_type: str = "" proto_v_type: str = "" @@ -547,7 +540,7 @@ def ready(self) -> None: raise ValueError("can't find enum") @property - def betterproto_field_args(self) -> List[str]: + def betterproto_field_args(self) -> list[str]: return [f"betterproto2.{self.proto_k_type}", f"betterproto2.{self.proto_v_type}"] @property @@ -568,7 +561,7 @@ class EnumDefinitionCompiler(MessageCompiler): """Representation of a proto Enum definition.""" proto_obj: EnumDescriptorProto = PLACEHOLDER - entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER + entries: list["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER @dataclass(unsafe_hash=True) class EnumEntry: @@ -596,8 +589,8 @@ class ServiceCompiler(ProtoContentBase): source_file: FileDescriptorProto parent: OutputTemplate = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER - path: List[int] = PLACEHOLDER - methods: List["ServiceMethodCompiler"] = field(default_factory=list) + path: list[int] = PLACEHOLDER + methods: list["ServiceMethodCompiler"] = field(default_factory=list) def __post_init__(self) -> None: # Add service to output file @@ -618,7 +611,7 @@ class ServiceMethodCompiler(ProtoContentBase): source_file: FileDescriptorProto parent: ServiceCompiler proto_obj: MethodDescriptorProto - path: List[int] = PLACEHOLDER + path: list[int] = PLACEHOLDER def __post_init__(self) -> None: # Add method to service diff --git a/src/betterproto2_compiler/plugin/module_validation.py b/src/betterproto2_compiler/plugin/module_validation.py index 19a06a96..16432995 100644 --- a/src/betterproto2_compiler/plugin/module_validation.py +++ b/src/betterproto2_compiler/plugin/module_validation.py @@ -1,15 +1,10 @@ import re from collections import defaultdict +from collections.abc import Iterator from dataclasses import ( dataclass, field, ) -from typing import ( - Dict, - Iterator, - List, - Tuple, -) @dataclass @@ -17,7 +12,7 @@ class ModuleValidator: line_iterator: Iterator[str] line_number: int = field(init=False, default=0) - collisions: Dict[str, List[Tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list)) + collisions: dict[str, list[tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list)) def add_import(self, imp: str, number: int, full_line: str): """ diff --git a/src/betterproto2_compiler/plugin/parser.py b/src/betterproto2_compiler/plugin/parser.py index ca8ac9e9..9dd2cafb 100644 --- a/src/betterproto2_compiler/plugin/parser.py +++ b/src/betterproto2_compiler/plugin/parser.py @@ -1,11 +1,6 @@ import pathlib import sys -from typing import ( - Generator, - List, - Set, - Tuple, -) +from collections.abc import Generator from betterproto2_compiler.lib.google.protobuf import ( DescriptorProto, @@ -44,13 +39,13 @@ def traverse( proto_file: FileDescriptorProto, -) -> Generator[Tuple[EnumDescriptorProto | DescriptorProto, List[int]], None, None]: +) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]: # Todo: Keep information about nested hierarchy def _traverse( - path: List[int], - items: List[EnumDescriptorProto] | List[DescriptorProto], + path: list[int], + items: list[EnumDescriptorProto] | list[DescriptorProto], prefix: str = "", - ) -> Generator[Tuple[EnumDescriptorProto | DescriptorProto, List[int]], None, None]: + ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple @@ -143,7 +138,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: service.ready() # Generate output files - output_paths: Set[pathlib.Path] = set() + output_paths: set[pathlib.Path] = set() for output_package_name, output_package in request_data.output_packages.items(): if not output_package.output: continue @@ -182,7 +177,7 @@ def _make_one_of_field_compiler( source_file: "FileDescriptorProto", parent: MessageCompiler, proto_obj: "FieldDescriptorProto", - path: List[int], + path: list[int], ) -> FieldCompiler: return OneOfFieldCompiler( source_file=source_file, @@ -195,7 +190,7 @@ def _make_one_of_field_compiler( def read_protobuf_type( item: DescriptorProto, - path: List[int], + path: list[int], source_file: "FileDescriptorProto", output_package: OutputTemplate, ) -> None: diff --git a/src/betterproto2_compiler/plugin/typing_compiler.py b/src/betterproto2_compiler/plugin/typing_compiler.py index 5acc60e2..fd1e120c 100644 --- a/src/betterproto2_compiler/plugin/typing_compiler.py +++ b/src/betterproto2_compiler/plugin/typing_compiler.py @@ -1,14 +1,11 @@ import abc +import builtins from collections import defaultdict +from collections.abc import Iterator from dataclasses import ( dataclass, field, ) -from typing import ( - Dict, - Iterator, - Set, -) class TypingCompiler(metaclass=abc.ABCMeta): @@ -41,7 +38,7 @@ def async_iterator(self, type_: str) -> str: raise NotImplementedError @abc.abstractmethod - def imports(self) -> Dict[str, Set[str] | None]: + def imports(self) -> builtins.dict[str, set[str] | None]: """ Returns either the direct import as a key with none as value, or a set of values to import from the key. @@ -62,7 +59,7 @@ def import_lines(self) -> Iterator: @dataclass class DirectImportTypingCompiler(TypingCompiler): - _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) + _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set)) def optional(self, type_: str) -> str: self._imports["typing"].add("Optional") @@ -92,7 +89,7 @@ def async_iterator(self, type_: str) -> str: self._imports["typing"].add("AsyncIterator") return f"AsyncIterator[{type_}]" - def imports(self) -> Dict[str, Set[str] | None]: + def imports(self) -> builtins.dict[str, set[str] | None]: return {k: v if v else None for k, v in self._imports.items()} @@ -128,7 +125,7 @@ def async_iterator(self, type_: str) -> str: self._imported = True return f"typing.AsyncIterator[{type_}]" - def imports(self) -> Dict[str, Set[str] | None]: + def imports(self) -> builtins.dict[str, set[str] | None]: if self._imported: return {"typing": None} return {} @@ -136,7 +133,7 @@ def imports(self) -> Dict[str, Set[str] | None]: @dataclass class NoTyping310TypingCompiler(TypingCompiler): - _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) + _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set)) def optional(self, type_: str) -> str: return f"{type_} | None" @@ -162,5 +159,5 @@ def async_iterator(self, type_: str) -> str: self._imports["collections.abc"].add("AsyncIterator") return f"AsyncIterator[{type_}]" - def imports(self) -> Dict[str, Set[str] | None]: + def imports(self) -> builtins.dict[str, set[str] | None]: return {k: v if v else None for k, v in self._imports.items()} diff --git a/tests/generate.py b/tests/generate.py index 67dad859..e30a1794 100644 --- a/tests/generate.py +++ b/tests/generate.py @@ -4,7 +4,6 @@ import shutil import sys from pathlib import Path -from typing import Set from tests.util import ( get_directories, @@ -28,7 +27,7 @@ def clear_directory(dir_path: Path): file_or_directory.unlink() -async def generate(whitelist: Set[str], verbose: bool): +async def generate(whitelist: set[str], verbose: bool): test_case_names = set(get_directories(inputs_path)) - {"__pycache__"} path_whitelist = set() diff --git a/tests/test_module_validation.py b/tests/test_module_validation.py index b1b996d2..6995a590 100644 --- a/tests/test_module_validation.py +++ b/tests/test_module_validation.py @@ -1,8 +1,3 @@ -from typing import ( - List, - Set, -) - import pytest from betterproto2.plugin.module_validation import ModuleValidator @@ -98,7 +93,7 @@ ), ], ) -def test_module_validator(text: List[str], expected_collisions: Set[str] | None): +def test_module_validator(text: list[str], expected_collisions: set[str] | None): line_iterator = iter(text) validator = ModuleValidator(line_iterator) valid = validator.validate() diff --git a/tests/util.py b/tests/util.py index c2d7ff3f..b31030e4 100644 --- a/tests/util.py +++ b/tests/util.py @@ -5,16 +5,10 @@ import platform import sys import tempfile +from collections.abc import Callable, Generator from dataclasses import dataclass from pathlib import Path from types import ModuleType -from typing import ( - Callable, - Dict, - Generator, - List, - Tuple, -) os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -98,11 +92,11 @@ class TestCaseJsonFile: test_name: str file_name: str - def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]) -> bool: + def belongs_to(self, non_symmetrical_json: dict[str, tuple[str, ...]]) -> bool: return self.file_name in non_symmetrical_json.get(self.test_name, ()) -def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[TestCaseJsonFile]: +def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> list[TestCaseJsonFile]: """ :return: A list of all files found in "{inputs_path}/test_case_name" with names matching From 62be72d6b6dd48b56cd14a622f2892fb50a22b67 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Sun, 22 Dec 2024 15:32:14 +0100 Subject: [PATCH 4/5] Fix the poe check command --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fa478ba9..f6e71c95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,8 +80,8 @@ sequence = ["_format", "_sort-imports"] help = "Format the source code, and sort the imports" [tool.poe.tasks.check] -sequence = ["_check-format", "_check-imports"] -help = "Check that the source code is formatted and the imports sorted" +sequence = ["_check-format", "_check-ruff-lint"] +help = "Check that the source code is formatted and the code passes the linter" [tool.poe.tasks._format] cmd = "ruff format src tests" @@ -95,9 +95,9 @@ help = "Sort the imports" cmd = "ruff format --diff src tests" help = "Check that the source code is formatted" -[tool.poe.tasks._check-imports] -cmd = "ruff check --select I src tests" -help = "Check that the imports are sorted" +[tool.poe.tasks._check-ruff-lint] +cmd = "ruff check src tests" +help = "Check the code with the Ruff linter" [tool.poe.tasks.generate_lib] cmd = """ From d5c1c15362749534d810f90f7944d8d0eb6f28e5 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Sun, 22 Dec 2024 15:33:47 +0100 Subject: [PATCH 5/5] Fix trailing commas --- pyproject.toml | 2 ++ src/betterproto2_compiler/enum.py | 2 +- src/betterproto2_compiler/plugin/compiler.py | 2 +- src/betterproto2_compiler/plugin/models.py | 3 ++- src/betterproto2_compiler/plugin/parser.py | 5 +++-- tests/generate.py | 2 +- tests/test_typing_compiler.py | 2 +- tests/util.py | 6 ++++-- 8 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f6e71c95..771cb0c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ select = [ "UP", "I", + + "COM812", # Trailing commas ] diff --git a/src/betterproto2_compiler/enum.py b/src/betterproto2_compiler/enum.py index 890060cd..7a6a5a71 100644 --- a/src/betterproto2_compiler/enum.py +++ b/src/betterproto2_compiler/enum.py @@ -37,7 +37,7 @@ def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) new_mcs = type( f"{name}Type", tuple( - dict.fromkeys([base.__class__ for base in bases if base.__class__ is not type] + [EnumType, type]) + dict.fromkeys([base.__class__ for base in bases if base.__class__ is not type] + [EnumType, type]), ), # reorder the bases so EnumType and type are last to avoid conflicts {"_value_map_": value_map, "_member_map_": member_map}, ) diff --git a/src/betterproto2_compiler/plugin/compiler.py b/src/betterproto2_compiler/plugin/compiler.py index e8d261c1..780bed82 100644 --- a/src/betterproto2_compiler/plugin/compiler.py +++ b/src/betterproto2_compiler/plugin/compiler.py @@ -14,7 +14,7 @@ "Please ensure that you've installed betterproto as " '`pip install "betterproto[compiler]"` so that compiler dependencies ' "are included." - "\033[0m" + "\033[0m", ) raise SystemExit(1) diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index 10f6330b..1e44b307 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -420,7 +420,8 @@ def field_wraps(self) -> str | None: @property def repeated(self) -> bool: return self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED and not is_map( - self.proto_obj, self.parent + self.proto_obj, + self.parent, ) @property diff --git a/src/betterproto2_compiler/plugin/parser.py b/src/betterproto2_compiler/plugin/parser.py index 9dd2cafb..a2d9cbb3 100644 --- a/src/betterproto2_compiler/plugin/parser.py +++ b/src/betterproto2_compiler/plugin/parser.py @@ -76,7 +76,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: if output_package_name not in request_data.output_packages: # Create a new output if there is no output for this package request_data.output_packages[output_package_name] = OutputTemplate( - parent_request=request_data, package_proto_obj=proto_file + parent_request=request_data, + package_proto_obj=proto_file, ) # Add this input file to the output corresponding to this package request_data.output_packages[output_package_name].input_files.append(proto_file) @@ -152,7 +153,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: name=str(output_path), # Render and then format the output file content=outputfile_compiler(output_file=output_package), - ) + ), ) # Make each output directory a package with __init__ file diff --git a/tests/generate.py b/tests/generate.py index e30a1794..f2a2db0b 100644 --- a/tests/generate.py +++ b/tests/generate.py @@ -149,7 +149,7 @@ async def generate_test_case_output(test_case_input_path: Path, test_case_name: "", "NAMES One or more test-case names to generate classes for.", " python generate.py bool double enums", - ) + ), ) diff --git a/tests/test_typing_compiler.py b/tests/test_typing_compiler.py index e9157f40..1f8a2c66 100644 --- a/tests/test_typing_compiler.py +++ b/tests/test_typing_compiler.py @@ -30,7 +30,7 @@ def test_direct_import_typing_compiler(): "Iterable", "AsyncIterable", "AsyncIterator", - } + }, } diff --git a/tests/util.py b/tests/util.py index b31030e4..c7e1bf0f 100644 --- a/tests/util.py +++ b/tests/util.py @@ -51,7 +51,7 @@ async def protoc( "@echo off", f"\nchdir {os.getcwd()}", f"\n{sys.executable} -u {plugin_path.as_posix()}", - ] + ], ) tf.flush() @@ -80,7 +80,9 @@ async def protoc( *[p.as_posix() for p in path.glob("*.proto")], ] proc = await asyncio.create_subprocess_exec( - *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + *command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() return stdout, stderr, proc.returncode