From 92eccf8a76cde737ccb6df3467f6d646db4b0fb3 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Mon, 20 Jan 2025 17:14:05 +0100 Subject: [PATCH 1/4] Make snake case safer --- src/betterproto2_compiler/casing.py | 42 ++++++------------- .../compile/importing.py | 9 +++- src/betterproto2_compiler/plugin/models.py | 3 +- 3 files changed, 21 insertions(+), 33 deletions(-) diff --git a/src/betterproto2_compiler/casing.py b/src/betterproto2_compiler/casing.py index 741adf76..c5597cb2 100644 --- a/src/betterproto2_compiler/casing.py +++ b/src/betterproto2_compiler/casing.py @@ -21,43 +21,25 @@ def safe_snake_case(value: str) -> str: return value -def snake_case(value: str, strict: bool = True) -> str: +def snake_case(name: str) -> str: """ Join words with an underscore into lowercase and remove symbols. + """ - Parameters - ----------- - value: :class:`str` - The value to convert. - strict: :class:`bool` - Whether or not to force single underscores. + # If there are already underscores in the name, don't break it + if "_" in name or not any([c.isupper() for c in name]): + return name - Returns - -------- - :class:`str` - The value in snake_case. - """ + # Add an underscore before capital letters + name = re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name) - def substitute_word(symbols: str, word: str, is_start: bool) -> str: - if not word: - return "" - if strict: - delimiter_count = 0 if is_start else 1 # Single underscore if strict. - elif is_start: - delimiter_count = len(symbols) - elif word.isupper() or word.islower(): - delimiter_count = max(1, len(symbols)) # Preserve all delimiters if not strict. - else: - delimiter_count = len(symbols) + 1 # Extra underscore for leading capital. + # Add an underscore before capital letters following an acronym + name = re.sub(r"(?<=[A-Z])([A-Z])(?=[a-z])", r"_\1", name) - return ("_" * delimiter_count) + word.lower() + # Add an underscore before digits + name = re.sub(r"(?<=[a-zA-Z])([0-9])", r"_\1", name) - snake = re.sub( - f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})", - lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None), - value, - ) - return snake + return name.lower() def pascal_case(value: str, strict: bool = True) -> str: diff --git a/src/betterproto2_compiler/compile/importing.py b/src/betterproto2_compiler/compile/importing.py index c954aa95..62d2b9c8 100644 --- a/src/betterproto2_compiler/compile/importing.py +++ b/src/betterproto2_compiler/compile/importing.py @@ -114,7 +114,7 @@ def reference_absolute(imports: set[str], py_package: list[str], py_type: str) - Returns a reference to a python type located in the root, i.e. sys.path. """ string_import = ".".join(py_package) - string_alias = safe_snake_case(string_import) + string_alias = "__".join([safe_snake_case(name) for name in py_package]) imports.add(f"import {string_import} as {string_alias}") return f"{string_alias}.{py_type}" @@ -175,6 +175,11 @@ def reference_cousin(current_package: list[str], imports: set[str], py_package: string_from = f".{'.' * distance_up}" + ".".join(py_package[len(shared_ancestry) : -1]) string_import = py_package[-1] # Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34) - string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__" + # string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__" + string_alias = ( + f"{'_' * distance_up}" + + "__".join([safe_snake_case(name) for name in py_package[len(shared_ancestry) :]]) + + "__" + ) imports.add(f"from {string_from} import {string_import} as {string_alias}") return f"{string_alias}.{py_type}" diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index 71f91a72..34f3a26b 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -36,6 +36,7 @@ import betterproto2 from betterproto2 import unwrap +from betterproto2_compiler.casing import safe_snake_case from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name from betterproto2_compiler.compile.naming import ( pythonize_class_name, @@ -609,7 +610,7 @@ def py_input_message_param(self) -> str: str Param name corresponding to py_input_message_type. """ - return pythonize_field_name(self.py_input_message_type) + return safe_snake_case(self.py_input_message_type.split(".")[-1]) @property def py_output_message_type(self) -> str: From 12850f8efa34a3d66aa12286d0bef3eac27a6cf3 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Mon, 20 Jan 2025 17:22:41 +0100 Subject: [PATCH 2/4] Rename default parameter name --- src/betterproto2_compiler/plugin/models.py | 12 ---------- .../templates/template.py.j2 | 24 +++++++++---------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index 34f3a26b..931e2226 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -36,7 +36,6 @@ import betterproto2 from betterproto2 import unwrap -from betterproto2_compiler.casing import safe_snake_case from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name from betterproto2_compiler.compile.naming import ( pythonize_class_name, @@ -601,17 +600,6 @@ def is_input_msg_empty(self: "ServiceMethodCompiler") -> bool: return not bool(msg.fields) - @property - def py_input_message_param(self) -> str: - """Param name corresponding to py_input_message_type. - - Returns - ------- - str - Param name corresponding to py_input_message_type. - """ - return safe_snake_case(self.py_input_message_type.split(".")[-1]) - @property def py_output_message_type(self) -> str: """String representation of the Python type corresponding to the diff --git a/src/betterproto2_compiler/templates/template.py.j2 b/src/betterproto2_compiler/templates/template.py.j2 index d5632a72..4a99a4dd 100644 --- a/src/betterproto2_compiler/templates/template.py.j2 +++ b/src/betterproto2_compiler/templates/template.py.j2 @@ -98,7 +98,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub): {% for method in service.methods %} async def {{ method.py_name }}(self {%- if not method.client_streaming -%} - , {{ method.py_input_message_param }}: + , message: {%- if method.is_input_msg_empty -%} "{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None {%- else -%} @@ -106,7 +106,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub): {%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}" + , messages: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}" {%- endif -%} , * @@ -128,7 +128,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub): {% if method.client_streaming %} async for response in self._stream_stream( "{{ method.route }}", - {{ method.py_input_message_param }}_iterator, + messages, {{ method.py_input_message_type }}, {{ method.py_output_message_type }}, timeout=timeout, @@ -138,13 +138,13 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub): yield response {% else %}{# i.e. not client streaming #} {% if method.is_input_msg_empty %} - if {{ method.py_input_message_param }} is None: - {{ method.py_input_message_param }} = {{ method.py_input_message_type }}() + if message is None: + message = {{ method.py_input_message_type }}() {% endif %} async for response in self._unary_stream( "{{ method.route }}", - {{ method.py_input_message_param }}, + message, {{ method.py_output_message_type }}, timeout=timeout, deadline=deadline, @@ -157,7 +157,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub): {% if method.client_streaming %} return await self._stream_unary( "{{ method.route }}", - {{ method.py_input_message_param }}_iterator, + messages, {{ method.py_input_message_type }}, {{ method.py_output_message_type }}, timeout=timeout, @@ -166,13 +166,13 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub): ) {% else %}{# i.e. not client streaming #} {% if method.is_input_msg_empty %} - if {{ method.py_input_message_param }} is None: - {{ method.py_input_message_param }} = {{ method.py_input_message_type }}() + if message is None: + message = {{ method.py_input_message_type }}() {% endif %} return await self._unary_unary( "{{ method.route }}", - {{ method.py_input_message_param }}, + message, {{ method.py_output_message_type }}, timeout=timeout, deadline=deadline, @@ -199,10 +199,10 @@ class {{ service.py_name }}Base(ServiceBase): {% for method in service.methods %} async def {{ method.py_name }}(self {%- if not method.client_streaming -%} - , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}" + , message: "{{ method.py_input_message_type }}" {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }} + , messages: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }} {%- endif -%} ) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} From e51a542cf86a06cbc96628dc4838041d5e09750e Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Mon, 20 Jan 2025 18:01:12 +0100 Subject: [PATCH 3/4] Add test --- tests/test_casing.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/test_casing.py diff --git a/tests/test_casing.py b/tests/test_casing.py new file mode 100644 index 00000000..01debe50 --- /dev/null +++ b/tests/test_casing.py @@ -0,0 +1,32 @@ +def test_snake_case() -> None: + from betterproto2_compiler.casing import snake_case + + # Simple renaming + assert snake_case("methodName") == "method_name" + assert snake_case("MethodName") == "method_name" + + # Don't break acronyms + assert snake_case("HTTPRequest") == "http_request" + assert snake_case("RequestHTTP") == "request_http" + assert snake_case("HTTPRequest2") == "http_request_2" + assert snake_case("RequestHTTP2") == "request_http_2" + assert snake_case("GetAResponse") == "get_a_response" + + # Split digits + assert snake_case("Get2025Results") == "get_2025_results" + assert snake_case("Get10yResults") == "get_10y_results" + + # If the name already contains an underscore or is lowercase, don't change it at all. + # There is a risk of breaking names otherwise. + assert snake_case("aaa_123_bbb") == "aaa_123_bbb" + assert snake_case("aaa_123bbb") == "aaa_123bbb" + assert snake_case("aaa123_bbb") == "aaa123_bbb" + assert snake_case("get_HTTP_response") == "get_HTTP_response" + assert snake_case("_methodName") == "_methodName" + assert snake_case("make_gRPC_request") == "make_gRPC_request" + + assert snake_case("value1") == "value1" + assert snake_case("value1string") == "value1string" + + # It is difficult to cover all the cases with a simple algorithm... + # "GetValueAsUInt32" -> "get_value_as_u_int_32" From 62c1faef6805b9845de92db7f1d53ac405958dfc Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Mon, 20 Jan 2025 18:03:06 +0100 Subject: [PATCH 4/4] Update version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e93ccaa3..878eb62c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "betterproto2_compiler" -version = "0.2.1" +version = "0.2.2" description = "Compiler for betterproto2" authors = ["Adrien Vannson ", "Daniel G. Taylor "] readme = "README.md"