diff --git a/README.md b/README.md index db8bfe274..e640dcd2e 100644 --- a/README.md +++ b/README.md @@ -38,11 +38,12 @@ This project exists because I am unhappy with the state of the official Google p - Uses `SerializeToString()` rather than the built-in `__bytes__()` - Special wrapped types don't use Python's `None` - Timestamp/duration types don't use Python's built-in `datetime` module + This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical. ## Installation -First, install the package. Note that the `[compiler]` feature flag tells it to install extra dependencies only needed by the `protoc` plugin: +First, install the package. Note that the `[compiler]` feature flag tells it to install extra dependencies only needed by the code generator: ```sh # Install both the library and compiler @@ -71,18 +72,10 @@ message Greeting { } ``` -You can run the following to invoke protoc directly: +To compile the protobuf you would run the following: ```sh -mkdir lib -protoc -I . --python_betterproto_out=lib example.proto -``` - -or run the following to invoke protoc via grpcio-tools: - -```sh -pip install grpcio-tools -python -m grpc_tools.protoc -I . --python_betterproto_out=lib example.proto +betterproto compile example.proto --output=lib ``` This will generate `lib/hello/__init__.py` which looks like: @@ -160,12 +153,6 @@ service Echo { } ``` -Generate echo proto file: - -``` -python -m grpc_tools.protoc -I . --python_betterproto_out=. 
echo.proto -``` - A client can be implemented as follows: ```python import asyncio @@ -175,16 +162,13 @@ from grpclib.client import Channel async def main(): - channel = Channel(host="127.0.0.1", port=50051) - service = echo.EchoStub(channel) - response = await service.echo(echo.EchoRequest(value="hello", extra_times=1)) - print(response) - - async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)): - print(response) - - # don't forget to close the channel when done! - channel.close() + async with Channel(host="127.0.0.1", port=50051) as channel: + service = echo.EchoStub(channel) + response = await service.echo(echo.EchoRequest(value="hello", extra_times=1)) + print(response) + + async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)): + print(response) if __name__ == "__main__": @@ -278,23 +262,23 @@ You can use `betterproto.which_one_of(message, group_name)` to determine which o ```py >>> test = Test() >>> betterproto.which_one_of(test, "foo") -["", None] +("", None) >>> test.on = True >>> betterproto.which_one_of(test, "foo") -["on", True] +("on", True) # Setting one member of the group resets the others. >>> test.count = 57 >>> betterproto.which_one_of(test, "foo") -["count", 57] +("count", 57) >>> test.on False # Default (zero) values also work. >>> test.name = "" >>> betterproto.which_one_of(test, "foo") -["name", ""] +("name", "") >>> test.count 0 >>> test.on @@ -310,7 +294,7 @@ Again this is a little different than the official Google code generator: # New way (this project) >>> betterproto.which_one_of(message, "group") -["foo", "foo's value"] +("foo", "foo's value") ``` ### Well-Known Google Types @@ -445,7 +429,7 @@ poe full-test Betterproto includes compiled versions for Google's well-known types at [betterproto/lib/google](betterproto/lib/google). Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests. 
-Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`. +Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`. Assuming your `google.protobuf` source files (included with all releases of `protoc`) are located in `/usr/local/include`, you can regenerate them as follows: diff --git a/docs/migrating.rst b/docs/migrating.rst index 0f18eac5f..f6680fb0c 100644 --- a/docs/migrating.rst +++ b/docs/migrating.rst @@ -4,7 +4,7 @@ Migrating Guide Google's protocolbuffers ------------------------ -betterproto has a mostly 1 to 1 drop in replacement for Google's protocolbuffers (after +betterproto has a mostly 1 to 1 drop in replacement for Google's protocol buffers (after regenerating your protobufs of course) although there are some minor differences. .. note:: diff --git a/docs/quick-start.rst b/docs/quick-start.rst index 73598c6fc..34abeb411 100644 --- a/docs/quick-start.rst +++ b/docs/quick-start.rst @@ -40,22 +40,11 @@ Given you installed the compiler and have a proto file, e.g ``example.proto``: string message = 1; } -To compile the proto you would run the following: - -You can run the following to invoke protoc directly: +To compile the protobuf you would run the following: .. code-block:: sh - mkdir hello - protoc -I . --python_betterproto_out=lib example.proto - -or run the following to invoke protoc via grpcio-tools: - -.. code-block:: sh - - pip install grpcio-tools - python -m grpc_tools.protoc -I . 
--python_betterproto_out=lib example.proto - + betterproto compile example.proto --output=lib This will generate ``lib/__init__.py`` which looks like: @@ -141,16 +130,13 @@ The generated client can be used like so: async def main(): - channel = Channel(host="127.0.0.1", port=50051) - service = echo.EchoStub(channel) - response = await service.echo(value="hello", extra_times=1) - print(response) - - async for response in service.echo_stream(value="hello", extra_times=1): + async with Channel(host="127.0.0.1", port=50051) as channel: + service = echo.EchoStub(channel) + response = await service.echo(value="hello", extra_times=1) print(response) - # don't forget to close the channel when you're done! - channel.close() + async for response in service.echo_stream(value="hello", extra_times=1): + print(response) asyncio.run(main()) # python 3.7 only diff --git a/pyproject.toml b/pyproject.toml index 27e745c02..69bcf7409 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ black = { version = ">=19.3b0", optional = true } dataclasses = { version = "^0.7", python = ">=3.6, <3.7" } grpclib = "^0.4.1" jinja2 = { version = ">=2.11.2", optional = true } +typer = { version = "^0.3.2", optional = true } +rich = { version = "^11.2.0", optional = true } python-dateutil = "^2.8" [tool.poetry.dev-dependencies] @@ -36,13 +38,14 @@ sphinx = "3.1.2" sphinx-rtd-theme = "0.5.0" tomlkit = "^0.7.0" tox = "^3.15.1" - +# protobuf_parser = "1.0.0" [tool.poetry.scripts] +betterproto = "betterproto:__main__.main" protoc-gen-python_betterproto = "betterproto.plugin:main" [tool.poetry.extras] -compiler = ["black", "jinja2"] +compiler = ["black", "jinja2", "typer", "rich", "protobuf_parser"] # Dev workflow tasks @@ -81,12 +84,7 @@ help = "Clean out generated files from the workspace" [tool.poe.tasks.generate_lib] cmd = """ -protoc - --plugin=protoc-gen-custom=src/betterproto/plugin/main.py - --custom_opt=INCLUDE_GOOGLE - --custom_out=src/betterproto/lib - -I /usr/local/include/ - 
/usr/local/include/google/protobuf/**/*.proto +betterproto compile /usr/local/include/google/protobuf/**/*.proto """ help = "Regenerate the types in betterproto.lib.google" diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index d5465791b..e06a9ec4b 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -9,7 +9,6 @@ from abc import ABC from base64 import b64decode, b64encode from datetime import datetime, timedelta, timezone -from dateutil.parser import isoparse from typing import ( Any, Callable, @@ -25,12 +24,13 @@ get_type_hints, ) +from dateutil.parser import isoparse + from ._types import T from ._version import __version__ from .casing import camel_case, safe_snake_case, snake_case from .grpc.grpclib_client import ServiceStub - # Proto 3 data types TYPE_ENUM = "enum" TYPE_BOOL = "bool" diff --git a/src/betterproto/__main__.py b/src/betterproto/__main__.py new file mode 100644 index 000000000..d225536f5 --- /dev/null +++ b/src/betterproto/__main__.py @@ -0,0 +1,7 @@ +from .plugin.exception_hook import install_exception_hook + +install_exception_hook() +from .plugin.cli import app as main + +if __name__ == "__main__": + main() diff --git a/src/betterproto/_types.py b/src/betterproto/_types.py index 26b734406..7b4e9b1a8 100644 --- a/src/betterproto/_types.py +++ b/src/betterproto/_types.py @@ -2,6 +2,7 @@ if TYPE_CHECKING: from grpclib._typing import IProtoMessage + from . 
import Message # Bound type variable to allow methods to return `self` of subclasses diff --git a/src/betterproto/plugin/__init__.py b/src/betterproto/plugin/__init__.py index c28a133f2..e69de29bb 100644 --- a/src/betterproto/plugin/__init__.py +++ b/src/betterproto/plugin/__init__.py @@ -1 +0,0 @@ -from .main import main diff --git a/src/betterproto/plugin/__main__.py b/src/betterproto/plugin/__main__.py index bd95daead..40e2b013f 100644 --- a/src/betterproto/plugin/__main__.py +++ b/src/betterproto/plugin/__main__.py @@ -1,4 +1,4 @@ from .main import main - -main() +if __name__ == "__main__": + main() diff --git a/src/betterproto/plugin/cli/__init__.py b/src/betterproto/plugin/cli/__init__.py new file mode 100644 index 000000000..f32d6de6b --- /dev/null +++ b/src/betterproto/plugin/cli/__init__.py @@ -0,0 +1,6 @@ +VERBOSE = False + +from black.const import DEFAULT_LINE_LENGTH as DEFAULT_LINE_LENGTH + +from .commands import app as app +from .runner import compile_protobufs as compile_protobufs diff --git a/src/betterproto/plugin/cli/commands.py b/src/betterproto/plugin/cli/commands.py new file mode 100644 index 000000000..f3e652934 --- /dev/null +++ b/src/betterproto/plugin/cli/commands.py @@ -0,0 +1,115 @@ +import sys +from pathlib import Path +from typing import List, Optional + +import protobuf_parser +import rich +import typer +from rich.syntax import Syntax + +from ... import __version__ +from ..models import monkey_patch_oneof_index +from . 
import DEFAULT_LINE_LENGTH, VERBOSE, utils +from .runner import compile_protobufs + +monkey_patch_oneof_index() +app = typer.Typer() + + +@app.callback(context_settings={"help_option_names": ["-h", "--help"]}) +def callback(ctx: typer.Context) -> None: + """The callback for all things betterproto""" + if ctx.invoked_subcommand is None: + rich.print(ctx.get_help()) + + +@app.command() +def version(ctx: typer.Context) -> None: + rich.print("betterproto version:", __version__) + + +@app.command(context_settings={"help_option_names": ["-h", "--help"]}) +@utils.run_sync +async def compile( + verbose: bool = typer.Option( + VERBOSE, "-v", "--verbose", help="Whether or not to be verbose" + ), + line_length: int = typer.Option( + DEFAULT_LINE_LENGTH, + "-l", + "--line-length", + help="The line length to format with", + ), + generate_services: bool = typer.Option( + True, help="Whether or not to generate servicer stubs" + ), + output: Optional[Path] = typer.Option( + None, + help="The name of the output directory", + file_okay=False, + allow_dash=True, + ), + paths: List[Path] = typer.Argument( + ..., + help="The protobuf files to compile", + exists=True, + allow_dash=True, + readable=False, + ), +) -> None: + """The recommended way to compile your protobuf files.""" + files = utils.get_files(paths) + + if not files: + return rich.print("[bold]No files found to compile") + + for output_path, protos in files.items(): + output = output or (Path(output_path.parent.name) / output_path.name).resolve() + output.mkdir(exist_ok=True, parents=True) + + results = await compile_protobufs( + *protos, + output=output, + verbose=verbose, + generate_services=generate_services, + line_length=line_length, + from_cli=True, + ) + + for result in results: + for error in result.errors: + if error.message.startswith("Syntax error"): + rich.print( + f"[red]File {str(result.file)}:\n", + Syntax.from_path( + str(result.file), + line_numbers=True, + line_range=(max(error.line - 5, 0), error.line), + 
), + f"{' ' * (error.column + 3)}^\nSyntaxError: {error.message}[red]", + file=sys.stderr, + ) + elif isinstance(error, protobuf_parser.Warning): + rich.print(f"Warning: {error}", file=sys.stderr) + else: + failed_files = "\n".join(f" - {file}" for file in protos) + rich.print( + f"[red]Protoc failed to generate outputs for:\n\n" + f"{failed_files}\n\nSee the output for the issue:\n{error}[red]", + file=sys.stderr, + ) + + # has_warnings = all(isinstance(e, Warning) for e in errors) + # if not errors or has_warnings: + if True: + rich.print( + f"[bold green]Finished generating output for " + f"{len(protos)} file{'s' if len(protos) != 1 else ''}, " + f"output is in {output.as_posix()}" + ) + + # if errors: + # if not has_warnings: + # exit(2) + # exit(1) + exit(0) diff --git a/src/betterproto/plugin/cli/runner.py b/src/betterproto/plugin/cli/runner.py new file mode 100644 index 000000000..6c10eac68 --- /dev/null +++ b/src/betterproto/plugin/cli/runner.py @@ -0,0 +1,69 @@ +import asyncio +from typing import TYPE_CHECKING, Any, Sequence + +import protobuf_parser + +from ...lib.google.protobuf import FileDescriptorProto +from ...lib.google.protobuf.compiler import ( + CodeGeneratorRequest, + CodeGeneratorResponseFile, +) +from ..parser import generate_code +from . import utils + +if TYPE_CHECKING: + from pathlib import Path + + +def write_file(output: "Path", file: CodeGeneratorResponseFile) -> None: + path = output.joinpath(file.name).resolve() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(file.content) + + +async def compile_protobufs( + *files: "Path", + output: "Path", + use_betterproto: bool = True, + **kwargs: Any, +) -> Sequence[protobuf_parser.ParseResult["Path"]]: + """ + A programmatic way to compile protobufs. + + Parameters + ---------- + *files + The locations of the protobuf files to be generated. + output + The output directory. + **kwargs: + Any keyword arguments to pass to generate_code. 
+ + Returns + ------- + A of exceptions from protoc. + """ + if use_betterproto: + results = await utils.to_thread(protobuf_parser.parse, *files) + request = CodeGeneratorRequest( + proto_file=[ + FileDescriptorProto().parse(result.parsed) for result in results + ] + ) + + # Generate code + response = await utils.to_thread(generate_code, request, **kwargs) + + await asyncio.gather( + *(utils.to_thread(write_file, output, file) for file in response.file) + ) + return results + else: + errors = await utils.to_thread( + protobuf_parser.run, + *(f'"{file.as_posix()}"' for file in files), + proto_path=files[0].parent.as_posix(), + python_out=output.as_posix(), + ) + + return [] diff --git a/src/betterproto/plugin/cli/utils.py b/src/betterproto/plugin/cli/utils.py new file mode 100644 index 000000000..f02e9bd2d --- /dev/null +++ b/src/betterproto/plugin/cli/utils.py @@ -0,0 +1,63 @@ +import asyncio +import functools +import sys +from collections.abc import Mapping +from collections import defaultdict +from pathlib import Path +from typing import Any, Awaitable, Callable, Iterable, List, Optional, Set, TypeVar + +from typing_extensions import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + + +def get_files(paths: List[Path]) -> "Mapping[Path, Set[Path]]": + """Return a list of files ready for :func:`generate_command`""" + + new_paths: "defaultdict[Path, Set[Path]]" = defaultdict(set) + for path in paths: + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + + if path.is_dir(): + new_paths[path].update( + sorted(path.glob("*.proto")) + ) # ensure order for files when debugging compilation errors + else: + new_paths[path.parent].add(path) + + return dict(new_paths) + + +def run_sync(func: Callable[P, Awaitable[T]]) -> Callable[P, T]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + coro = func(*args, **kwargs) + + if hasattr(asyncio, "run"): + return asyncio.run(coro) + + loop = asyncio.new_event_loop() + try: + return 
loop.run_until_complete(coro) + finally: + loop.close() + + return wrapper + + +def find(predicate: Callable[[T], bool], iterable: Iterable[T]) -> Optional[T]: + for i in iterable: + if predicate(i): + return i + + +if sys.version_info >= (3, 9): + to_thread = asyncio.to_thread +else: + + async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: + loop = asyncio.get_running_loop() + func_call = functools.partial(func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) diff --git a/src/betterproto/plugin/compiler.py b/src/betterproto/plugin/compiler.py index ba2284e63..fe0fdc073 100644 --- a/src/betterproto/plugin/compiler.py +++ b/src/betterproto/plugin/compiler.py @@ -1,24 +1,16 @@ import os.path -try: - # betterproto[compiler] specific dependencies - import black - import jinja2 -except ImportError as err: - print( - "\033[31m" - f"Unable to import `{err.name}` from betterproto plugin! " - "Please ensure that you've installed betterproto as " - '`pip install "betterproto[compiler]"` so that compiler dependencies ' - "are included." 
- "\033[0m" - ) - raise SystemExit(1) +import black +import jinja2 +from black.const import DEFAULT_LINE_LENGTH +from black.mode import Mode, TargetVersion from .models import OutputTemplate -def outputfile_compiler(output_file: OutputTemplate) -> str: +def outputfile_compiler( + output_file: OutputTemplate, line_length: int = DEFAULT_LINE_LENGTH +) -> str: templates_folder = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "templates") @@ -33,5 +25,5 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: return black.format_str( template.render(output_file=output_file), - mode=black.Mode(), + mode=Mode(line_length=line_length, target_versions={TargetVersion.PY37}), ) diff --git a/src/betterproto/plugin/exception_hook.py b/src/betterproto/plugin/exception_hook.py new file mode 100644 index 000000000..ed2043b5d --- /dev/null +++ b/src/betterproto/plugin/exception_hook.py @@ -0,0 +1,50 @@ +import sys +import traceback +from pathlib import Path +from types import TracebackType +from typing import Type + +IMPORT_ERROR_MESSAGE = ( + "Unable to import `{0.name}` from betterproto plugin! Please ensure that you've " + 'installed betterproto as `pip install "betterproto[compiler]"` so that compiler ' + "dependencies are included." +) + +STDLIB_MODULES = getattr( + sys, + "builtin_module_names", + [ + p.stem + for p in Path(traceback.__file__).parent.iterdir() + if p.suffix == ".py" or p.is_dir() + ], +) + + +def import_exception_hook( + type: Type[BaseException], value: ImportError, tb: TracebackType +) -> None: + """Set an exception hook to automatically print: + + "Unable to import `x` from betterproto plugin! Please ensure that you've installed + betterproto as `pip install "betterproto[compiler]"` so that compiler dependencies are + included." 
+ + if the module imported is not found and the exception is raised in this sub module + """ + bottom_frame = list(traceback.walk_tb(tb))[-1][0] + module = bottom_frame.f_globals.get("__name__", "__main__") + if ( + not module.startswith(__name__) + or not isinstance(value, ImportError) + or value.name in STDLIB_MODULES + or (value.name or "").startswith("betterproto") + ): + return sys.__excepthook__(type, value, tb) + + print(f"\033[31m{IMPORT_ERROR_MESSAGE.format(value)}\033[0m", file=sys.stderr) + exit(1) + + +def install_exception_hook(): + sys.excepthook = import_exception_hook diff --git a/src/betterproto/plugin/main.py b/src/betterproto/plugin/main.py index 8982321f1..d19266cdd 100755 --- a/src/betterproto/plugin/main.py +++ b/src/betterproto/plugin/main.py @@ -1,15 +1,16 @@ #!/usr/bin/env python -import os import sys -from betterproto.lib.google.protobuf.compiler import ( - CodeGeneratorRequest, - CodeGeneratorResponse, -) +from .exception_hook import install_exception_hook -from betterproto.plugin.parser import generate_code -from betterproto.plugin.models import monkey_patch_oneof_index +install_exception_hook() + +import rich + +from ..lib.google.protobuf.compiler import CodeGeneratorRequest +from .models import monkey_patch_oneof_index +from .parser import generate_code def main() -> None: @@ -21,12 +22,14 @@ def main() -> None: monkey_patch_oneof_index() # Parse request - request = CodeGeneratorRequest() - request.parse(data) + request = CodeGeneratorRequest().parse(data) - dump_file = os.getenv("BETTERPROTO_DUMP") - if dump_file: - dump_request(dump_file, request) + rich.print( + "Direct invocation of the protoc plugin is depreciated over using the CLI\n" + "To do so you just need to type:\n" + f"betterproto compile {' '.join(request.file_to_generate)}", + file=sys.stderr, + ) # Generate code response = generate_code(request) @@ -36,18 +39,3 @@ def main() -> None: # Write to stdout sys.stdout.buffer.write(output) - - -def dump_request(dump_file: str, 
request: CodeGeneratorRequest) -> None: - """ - For developers: Supports running plugin.py standalone so its possible to debug it. - Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. - Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. - """ - with open(str(dump_file), "wb") as fh: - sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") - fh.write(request.SerializeToString()) - - -if __name__ == "__main__": - main() diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index c2fccfcd2..19b1f64f5 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -29,12 +29,11 @@ reference to `A` to `B`'s `fields` attribute. """ - import builtins import re import textwrap from dataclasses import dataclass, field -from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union +from typing import Dict, Iterator, List, Optional, Set, Type, Union import betterproto from betterproto import which_one_of @@ -57,6 +56,7 @@ ) from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest +from .. 
import Message, which_one_of from ..casing import sanitize_name from ..compile.importing import get_type_reference, parse_source_type_name from ..compile.naming import ( @@ -64,6 +64,18 @@ pythonize_field_name, pythonize_method_name, ) +from ..lib.google.protobuf import ( + DescriptorProto, + EnumDescriptorProto, + Field, + FieldDescriptorProto, + FieldDescriptorProtoLabel, + FieldDescriptorProtoType, + FileDescriptorProto, + MethodDescriptorProto, + ServiceDescriptorProto, +) +from ..lib.google.protobuf.compiler import CodeGeneratorRequest # Create a unique placeholder to deal with # https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses @@ -159,7 +171,7 @@ class ProtoContentBase: source_file: FileDescriptorProto path: List[int] comment_indent: int = 4 - parent: Union["betterproto.Message", "OutputTemplate"] + parent: Union["Message", "OutputTemplate"] __dataclass_fields__: Dict[str, object] @@ -224,7 +236,7 @@ class OutputTemplate: parent_request: PluginRequestCompiler package_proto_obj: FileDescriptorProto - input_files: List[str] = field(default_factory=list) + input_files: List[FileDescriptorProto] = field(default_factory=list) imports: Set[str] = field(default_factory=set) datetime_imports: Set[str] = field(default_factory=set) typing_imports: Set[str] = field(default_factory=set) @@ -245,12 +257,12 @@ def package(self) -> str: return self.package_proto_obj.package @property - def input_filenames(self) -> Iterable[str]: + def input_filenames(self) -> List[str]: """Names of the input files used to build this output. Returns ------- - Iterable[str] + List[str] Names of the input files used to build this output. 
""" return sorted(f.name for f in self.input_files) @@ -634,7 +646,7 @@ def default_value_string(self) -> str: @dataclass class ServiceCompiler(ProtoContentBase): parent: OutputTemplate = PLACEHOLDER - proto_obj: DescriptorProto = PLACEHOLDER + proto_obj: ServiceDescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER methods: List["ServiceMethodCompiler"] = field(default_factory=list) diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 21a2caf14..68dc6af1d 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -1,20 +1,23 @@ -from betterproto.lib.google.protobuf import ( +import itertools +import pathlib +from contextlib import AbstractContextManager +from typing import Any, Iterator, List, Sequence, Set, Tuple, TypeAlias, Union + +from black.const import DEFAULT_LINE_LENGTH +from rich.progress import Progress + +from ..lib.google.protobuf import ( DescriptorProto, EnumDescriptorProto, - FieldDescriptorProto, FileDescriptorProto, ServiceDescriptorProto, ) -from betterproto.lib.google.protobuf.compiler import ( +from ..lib.google.protobuf.compiler import ( CodeGeneratorRequest, CodeGeneratorResponse, CodeGeneratorResponseFeature, CodeGeneratorResponseFile, ) -import itertools -import pathlib -import sys -from typing import Iterator, List, Set, Tuple, TYPE_CHECKING, Union from .compiler import outputfile_compiler from .models import ( EnumDefinitionCompiler, @@ -30,17 +33,18 @@ is_oneof, ) -if TYPE_CHECKING: - from google.protobuf.descriptor import Descriptor +TraverseType: TypeAlias = ( + "Tuple[Union[DescriptorProto, EnumDescriptorProto], List[int]]" +) -def traverse( - proto_file: FieldDescriptorProto, -) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": +def traverse(proto_file: FileDescriptorProto) -> "itertools.chain[TraverseType]": # Todo: Keep information about nested hierarchy def _traverse( - path: List[int], items: List["EnumDescriptorProto"], prefix="" - ) -> 
Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]: + path: List[int], + items: Sequence[Union[DescriptorProto, EnumDescriptorProto]], + prefix: str = "", + ) -> Iterator[TraverseType]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple @@ -53,93 +57,164 @@ def _traverse( yield enum, path + [i, 4] if item.nested_type: - for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix): - yield n, p + yield from _traverse(path + [i, 3], item.nested_type, next_prefix) return itertools.chain( _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) ) -def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: - response = CodeGeneratorResponse() +class NoopProgress(AbstractContextManager): + def add_task(self, *args: Any, **kwargs: Any) -> None: + ... + + def update(self, *args: Any, **kwargs: Any) -> None: + ... + + def __exit__(self, *args: Any) -> None: + ... + + +def generate_code( + request: CodeGeneratorRequest, + *, + include_google: bool = False, + line_length: int = DEFAULT_LINE_LENGTH, + generate_services: bool = True, + verbose: bool = False, + from_cli: bool = False, +) -> CodeGeneratorResponse: + """Generate the protobuf response file for a given request. + + Parameters + ---------- + request: :class:`.CodeGeneratorRequest` + The request to generate the protobufs from. + include_google: :class:`bool` + Whether or not to include the google protobufs in the response files. + line_length: :class:`int` + The line length to pass to black for formatting. + generate_services: :class:`bool` + Whether or not to include services. + verbose: :class:`bool` + Whether or not to run the plugin in verbose mode. + from_cli: :class:`bool` + Whether or not the plugin is being ran from the CLI. + Returns + ------- + :class:`.CodeGeneratorResponse` + The response for the request. 
+ """ + + response = CodeGeneratorResponse() plugin_options = request.parameter.split(",") if request.parameter else [] response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL + include_google = "INCLUDE_GOOGLE" in plugin_options or include_google request_data = PluginRequestCompiler(plugin_request_obj=request) # Gather output packages - for proto_file in request.proto_file: - if ( - proto_file.package == "google.protobuf" - and "INCLUDE_GOOGLE" not in plugin_options - ): - # If not INCLUDE_GOOGLE, - # skip re-compiling Google's well-known types - continue - - output_package_name = proto_file.package - if output_package_name not in request_data.output_packages: - # Create a new output if there is no output for this package - request_data.output_packages[output_package_name] = OutputTemplate( - parent_request=request_data, package_proto_obj=proto_file + + with Progress(transient=True) if from_cli else NoopProgress() as progress: + reading_progress_bar = progress.add_task( + "[green]Reading protobuf files...", total=len(request.proto_file) + ) + for proto_file in request.proto_file: + if proto_file.package == "google.protobuf" and not include_google: + # If not include_google skip re-compiling Google's well-known types + continue + + output_package_name = proto_file.package + if output_package_name not in request_data.output_packages: + # Create a new output if there is no output for this package + request_data.output_packages[output_package_name] = OutputTemplate( + parent_request=request_data, package_proto_obj=proto_file + ) + # Add this input file to the output corresponding to this package + request_data.output_packages[output_package_name].input_files.append( + proto_file ) - # Add this input file to the output corresponding to this package - request_data.output_packages[output_package_name].input_files.append(proto_file) + progress.update(reading_progress_bar, advance=1) # Read Messages and Enums # We need to read Messages before 
Services in so that we can # get the references to input/output messages for each service - for output_package_name, output_package in request_data.output_packages.items(): - for proto_input_file in output_package.input_files: - for item, path in traverse(proto_input_file): - read_protobuf_type( - source_file=proto_input_file, - item=item, - path=path, - output_package=output_package, - ) + with Progress(transient=True) if from_cli else NoopProgress() as progress: + parsing_progress_bar = progress.add_task( + "[green]Parsing protobuf enums and messages...", + total=sum( + len(message.package_proto_obj.enum_type) + + len(message.package_proto_obj.message_type) + for message in request_data.output_packages.values() + ), + ) + for output_package_name, output_package in request_data.output_packages.items(): + for proto_input_file in output_package.input_files: + for item, path in traverse(proto_input_file): + read_protobuf_type( + source_file=proto_input_file, + item=item, + path=path, + output_package=output_package, + ) + progress.update(parsing_progress_bar, advance=1) # Read Services - for output_package_name, output_package in request_data.output_packages.items(): - for proto_input_file in output_package.input_files: - for index, service in enumerate(proto_input_file.service): - read_protobuf_service(service, index, output_package) + if generate_services: + with Progress(transient=True) if from_cli else NoopProgress() as progress: + parsing_progress_bar = progress.add_task( + "[green]Parsing protobuf services...", + total=sum( + len(message.package_proto_obj.service) + for message in request_data.output_packages.values() + ), + ) + for ( + output_package_name, + output_package, + ) in request_data.output_packages.items(): + for proto_input_file in output_package.input_files: + for index, service in enumerate(proto_input_file.service): + read_protobuf_service(service, index, output_package) + progress.update(parsing_progress_bar, advance=1) # Generate output files 
output_paths: Set[pathlib.Path] = set() - for output_package_name, output_package in request_data.output_packages.items(): + with Progress(transient=True) if from_cli else NoopProgress() as progress: + compiling_progress_bar = progress.add_task( + "[green]Compiling protobuf files...", + total=len(request_data.output_packages), + ) + for output_package_name, output_package in request_data.output_packages.items(): - # Add files to the response object - output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") - output_paths.add(output_path) + # Add files to the response object + output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") + output_paths.add(output_path) - response.file.append( - CodeGeneratorResponseFile( - name=str(output_path), - # Render and then format the output file - content=outputfile_compiler(output_file=output_package), + response.file.append( + CodeGeneratorResponseFile( + name=str(output_path), + # Render and then format the output file + content=outputfile_compiler( + output_file=output_package, line_length=line_length + ), + ) ) - ) + progress.update(compiling_progress_bar, advance=1) # Make each output directory a package with __init__ file init_files = { - directory.joinpath("__init__.py") - for path in output_paths - for directory in path.parents + directory / "__init__.py" for path in output_paths for directory in path.parents } - output_paths for init_file in init_files: response.file.append(CodeGeneratorResponseFile(name=str(init_file))) - for output_package_name in sorted(output_paths.union(init_files)): - print(f"Writing {output_package_name}", file=sys.stderr) - return response def read_protobuf_type( - item: DescriptorProto, + item: Union[DescriptorProto, EnumDescriptorProto], path: List[int], source_file: "FileDescriptorProto", output_package: OutputTemplate, diff --git a/tests/generate.py b/tests/generate.py index 1d7d3e985..35b81111f 100755 --- a/tests/generate.py +++ b/tests/generate.py @@ 
-1,26 +1,29 @@ #!/usr/bin/env python import asyncio import os -from pathlib import Path -import platform import shutil import sys -from typing import Set +from pathlib import Path +from typing import List, Optional, Set +import rich +import typer + +from betterproto.plugin.cli import compile_protobufs, utils from tests.util import ( get_directories, inputs_path, output_path_betterproto, output_path_reference, - protoc, ) # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" +app = typer.Typer() -def clear_directory(dir_path: Path): +def clear_directory(dir_path: Path) -> None: for file_or_directory in dir_path.glob("*"): if file_or_directory.is_dir(): shutil.rmtree(file_or_directory) @@ -28,7 +31,7 @@ def clear_directory(dir_path: Path): file_or_directory.unlink() -async def generate(whitelist: Set[str], verbose: bool): +async def generate(whitelist: Set[Path], verbose: bool) -> None: test_case_names = set(get_directories(inputs_path)) - {"__pycache__"} path_whitelist = set() @@ -41,7 +44,7 @@ async def generate(whitelist: Set[str], verbose: bool): generation_tasks = [] for test_case_name in sorted(test_case_names): - test_case_input_path = inputs_path.joinpath(test_case_name).resolve() + test_case_input_path = (inputs_path / test_case_name).resolve() if ( whitelist and str(test_case_input_path) not in path_whitelist @@ -53,32 +56,32 @@ async def generate(whitelist: Set[str], verbose: bool): ) failed_test_cases = [] - # Wait for all subprocs and match any failures to names to report - for test_case_name, result in zip( + # Wait for processes before matching any failures to names to report + for test_case_name, exception in zip( sorted(test_case_names), await asyncio.gather(*generation_tasks) ): - if result != 0: + if exception is not None: failed_test_cases.append(test_case_name) if len(failed_test_cases) > 0: -
sys.stderr.write( - "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n" + rich.print( + "[red bold]\nFailed to generate the following test cases:", + *(f"[red]- {failed_test_case}" for failed_test_case in failed_test_cases), + sep="\n", ) - for failed_test_case in failed_test_cases: - sys.stderr.write(f"- {failed_test_case}\n") sys.exit(1) async def generate_test_case_output( test_case_input_path: Path, test_case_name: str, verbose: bool -) -> int: +) -> Optional[Exception]: """ Returns the max of the subprocess return values """ - test_case_output_path_reference = output_path_reference.joinpath(test_case_name) - test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name) + test_case_output_path_reference = output_path_reference / test_case_name + test_case_output_path_betterproto = output_path_betterproto / test_case_name os.makedirs(test_case_output_path_reference, exist_ok=True) os.makedirs(test_case_output_path_betterproto, exist_ok=True) @@ -86,83 +89,46 @@ async def generate_test_case_output( clear_directory(test_case_output_path_reference) clear_directory(test_case_output_path_betterproto) - ( - (ref_out, ref_err, ref_code), - (plg_out, plg_err, plg_code), - ) = await asyncio.gather( - protoc(test_case_input_path, test_case_output_path_reference, True), - protoc(test_case_input_path, test_case_output_path_betterproto, False), + files = list(test_case_input_path.glob("*.proto")) + ref_errs, plg_errs = await asyncio.gather( + compile_protobufs( + *files, output=test_case_output_path_reference, use_betterproto=False + ), + compile_protobufs( + *files, output=test_case_output_path_betterproto, from_cli=True + ), ) - if ref_code == 0: - print(f"\033[31;1;4mGenerated reference output for {test_case_name!r}\033[0m") - else: - print( - f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m" - ) - - if verbose: - if ref_out: - print("Reference stdout:") - sys.stdout.buffer.write(ref_out) - 
sys.stdout.buffer.flush() - - if ref_err: - print("Reference stderr:") - sys.stderr.buffer.write(ref_err) - sys.stderr.buffer.flush() - - if plg_code == 0: - print(f"\033[31;1;4mGenerated plugin output for {test_case_name!r}\033[0m") - else: - print( - f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m" - ) - + rich.print(f"[bold red]Generated output for {test_case_name!r}") if verbose: - if plg_out: - print("Plugin stdout:") - sys.stdout.buffer.write(plg_out) - sys.stdout.buffer.flush() - - if plg_err: - print("Plugin stderr:") - sys.stderr.buffer.write(plg_err) - sys.stderr.buffer.flush() - - return max(ref_code, plg_code) - - -HELP = "\n".join( - ( - "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]", - "Generate python classes for standard tests.", - "", - "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.", - " python generate.py inputs/bool inputs/double inputs/enum", - "", - "NAMES One or more test-case names to generate classes for.", - " python generate.py bool double enums", - ) -) - - -def main(): - if set(sys.argv).intersection({"-h", "--help"}): - print(HELP) - return - if sys.argv[1:2] == ["-v"]: - verbose = True - whitelist = set(sys.argv[2:]) - else: - verbose = False - whitelist = set(sys.argv[1:]) - - if platform.system() == "Windows": - asyncio.set_event_loop(asyncio.ProactorEventLoop()) - - asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose)) + for ref_err in ref_errs: + rich.print(f"[red]{ref_err}", file=sys.stderr) + for plg_err in plg_errs: + rich.print(f"[red]{plg_err}", file=sys.stderr) + sys.stderr.flush() + + return ref_errs or plg_errs or None + + +@app.command(context_settings={"help_option_names": ["-h", "--help"]}) +@utils.run_sync +async def main( + verbose: bool = typer.Option( + False, + "-v", + "--verbose", + help="Whether or not to run the plugin in verbose mode.", + ), + directories: Optional[List[Path]] = typer.Argument( + 
None, + help="One or more relative or absolute directories or test-case names " + "to generate classes for. e.g. ``inputs/bool inputs/double " + "inputs/enum`` or ``bool double enum``", + ), +) -> None: + """Generate python classes for standard tests.""" + await generate(set(directories or ()), verbose) if __name__ == "__main__": - main() + app() diff --git a/tests/test_features.py b/tests/test_features.py index 787520dee..4523d14bd 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional, List, Dict from datetime import datetime from inspect import Parameter, signature from typing import Dict, List, Optional diff --git a/tests/test_get_ref_type.py b/tests/test_get_ref_type.py index cbee4caa5..fa2e77bf1 100644 --- a/tests/test_get_ref_type.py +++ b/tests/test_get_ref_type.py @@ -35,9 +35,7 @@ def test_reference_google_wellknown_types_non_wrappers( name = get_type_reference(package="", imports=imports, source_type=google_type) assert name == expected_name - assert imports.__contains__( - expected_import - ), f"{expected_import} not found in {imports}" + assert expected_import in imports, f"{expected_import} not found in {imports}" @pytest.mark.parametrize( diff --git a/tests/util.py b/tests/util.py index 950cf7af7..20c848385 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,9 +1,10 @@ import asyncio -from dataclasses import dataclass import importlib import os -from pathlib import Path +import pathlib import sys +from dataclasses import dataclass +from pathlib import Path from types import ModuleType from typing import Callable, Dict, Generator, List, Optional, Tuple, Union @@ -15,13 +16,7 @@ output_path_betterproto = root_path.joinpath("output_betterproto") -def get_files(path, suffix: str) -> Generator[str, None, None]: - for r, dirs, files in os.walk(path): - for filename in [f for f in files if f.endswith(suffix)]: - yield os.path.join(r, filename) - - -def 
get_directories(path): +def get_directories(path: Path) -> Generator[str, None, None]: for root, directories, files in os.walk(path): yield from directories @@ -66,10 +61,10 @@ def get_test_case_json_data( f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names """ - test_case_dir = inputs_path.joinpath(test_case_name) + test_case_dir = inputs_path / test_case_name possible_file_paths = [ - *(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names), - test_case_dir.joinpath(f"{test_case_name}.json"), + *(test_case_dir / json_file_name for json_file_name in json_file_names), + test_case_dir / f"{test_case_name}.json", *test_case_dir.glob(f"{test_case_name}_*.json"), ] @@ -77,12 +72,13 @@ def get_test_case_json_data( for test_data_file_path in possible_file_paths: if not test_data_file_path.exists(): continue - with test_data_file_path.open("r") as fh: - result.append( - TestCaseJsonFile( - fh.read(), test_case_name, test_data_file_path.name.split(".")[0] - ) + result.append( + TestCaseJsonFile( + test_data_file_path.read_text(), + test_case_name, + test_data_file_path.name.split(".")[0], ) + ) return result