Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "betterproto2_compiler"
version = "0.2.1"
version = "0.2.2"
description = "Compiler for betterproto2"
authors = ["Adrien Vannson <[email protected]>", "Daniel G. Taylor <[email protected]>"]
readme = "README.md"
Expand Down
42 changes: 12 additions & 30 deletions src/betterproto2_compiler/casing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,25 @@ def safe_snake_case(value: str) -> str:
return value


def snake_case(value: str, strict: bool = True) -> str:
def snake_case(name: str) -> str:
"""
Join words with an underscore into lowercase and remove symbols.
"""

Parameters
-----------
value: :class:`str`
The value to convert.
strict: :class:`bool`
Whether or not to force single underscores.
# If there are already underscores in the name, don't break it
if "_" in name or not any([c.isupper() for c in name]):
return name

Returns
--------
:class:`str`
The value in snake_case.
"""
# Add an underscore before capital letters
name = re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name)

def substitute_word(symbols: str, word: str, is_start: bool) -> str:
if not word:
return ""
if strict:
delimiter_count = 0 if is_start else 1 # Single underscore if strict.
elif is_start:
delimiter_count = len(symbols)
elif word.isupper() or word.islower():
delimiter_count = max(1, len(symbols)) # Preserve all delimiters if not strict.
else:
delimiter_count = len(symbols) + 1 # Extra underscore for leading capital.
# Add an underscore before capital letters following an acronym
name = re.sub(r"(?<=[A-Z])([A-Z])(?=[a-z])", r"_\1", name)

return ("_" * delimiter_count) + word.lower()
# Add an underscore before digits
name = re.sub(r"(?<=[a-zA-Z])([0-9])", r"_\1", name)

snake = re.sub(
f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
value,
)
return snake
return name.lower()


def pascal_case(value: str, strict: bool = True) -> str:
Expand Down
9 changes: 7 additions & 2 deletions src/betterproto2_compiler/compile/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -
Returns a reference to a python type located in the root, i.e. sys.path.
"""
string_import = ".".join(py_package)
string_alias = safe_snake_case(string_import)
string_alias = "__".join([safe_snake_case(name) for name in py_package])
imports.add(f"import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"

Expand Down Expand Up @@ -175,6 +175,11 @@ def reference_cousin(current_package: list[str], imports: set[str], py_package:
string_from = f".{'.' * distance_up}" + ".".join(py_package[len(shared_ancestry) : -1])
string_import = py_package[-1]
# Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__"
# string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__"
string_alias = (
f"{'_' * distance_up}"
+ "__".join([safe_snake_case(name) for name in py_package[len(shared_ancestry) :]])
+ "__"
)
imports.add(f"from {string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"
11 changes: 0 additions & 11 deletions src/betterproto2_compiler/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,17 +600,6 @@ def is_input_msg_empty(self: "ServiceMethodCompiler") -> bool:

return not bool(msg.fields)

@property
def py_input_message_param(self) -> str:
"""Param name corresponding to py_input_message_type.

Returns
-------
str
Param name corresponding to py_input_message_type.
"""
return pythonize_field_name(self.py_input_message_type)

@property
def py_output_message_type(self) -> str:
"""String representation of the Python type corresponding to the
Expand Down
24 changes: 12 additions & 12 deletions src/betterproto2_compiler/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
, {{ method.py_input_message_param }}:
, message:
{%- if method.is_input_msg_empty -%}
"{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None
{%- else -%}
"{{ method.py_input_message_type }}"
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}"
, messages: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}"
{%- endif -%}
,
*
Expand All @@ -128,7 +128,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
{% if method.client_streaming %}
async for response in self._stream_stream(
"{{ method.route }}",
{{ method.py_input_message_param }}_iterator,
messages,
{{ method.py_input_message_type }},
{{ method.py_output_message_type }},
timeout=timeout,
Expand All @@ -138,13 +138,13 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
yield response
{% else %}{# i.e. not client streaming #}
{% if method.is_input_msg_empty %}
if {{ method.py_input_message_param }} is None:
{{ method.py_input_message_param }} = {{ method.py_input_message_type }}()
if message is None:
message = {{ method.py_input_message_type }}()

{% endif %}
async for response in self._unary_stream(
"{{ method.route }}",
{{ method.py_input_message_param }},
message,
{{ method.py_output_message_type }},
timeout=timeout,
deadline=deadline,
Expand All @@ -157,7 +157,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
{% if method.client_streaming %}
return await self._stream_unary(
"{{ method.route }}",
{{ method.py_input_message_param }}_iterator,
messages,
{{ method.py_input_message_type }},
{{ method.py_output_message_type }},
timeout=timeout,
Expand All @@ -166,13 +166,13 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
)
{% else %}{# i.e. not client streaming #}
{% if method.is_input_msg_empty %}
if {{ method.py_input_message_param }} is None:
{{ method.py_input_message_param }} = {{ method.py_input_message_type }}()
if message is None:
message = {{ method.py_input_message_type }}()

{% endif %}
return await self._unary_unary(
"{{ method.route }}",
{{ method.py_input_message_param }},
message,
{{ method.py_output_message_type }},
timeout=timeout,
deadline=deadline,
Expand All @@ -199,10 +199,10 @@ class {{ service.py_name }}Base(ServiceBase):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
, message: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }}
, messages: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }}
{%- endif -%}
) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
Expand Down
32 changes: 32 additions & 0 deletions tests/test_casing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
def test_snake_case() -> None:
from betterproto2_compiler.casing import snake_case

# Simple renaming
assert snake_case("methodName") == "method_name"
assert snake_case("MethodName") == "method_name"

# Don't break acronyms
assert snake_case("HTTPRequest") == "http_request"
assert snake_case("RequestHTTP") == "request_http"
assert snake_case("HTTPRequest2") == "http_request_2"
assert snake_case("RequestHTTP2") == "request_http_2"
assert snake_case("GetAResponse") == "get_a_response"

# Split digits
assert snake_case("Get2025Results") == "get_2025_results"
assert snake_case("Get10yResults") == "get_10y_results"

# If the name already contains an underscore or is lowercase, don't change it at all.
# There is a risk of breaking names otherwise.
assert snake_case("aaa_123_bbb") == "aaa_123_bbb"
assert snake_case("aaa_123bbb") == "aaa_123bbb"
assert snake_case("aaa123_bbb") == "aaa123_bbb"
assert snake_case("get_HTTP_response") == "get_HTTP_response"
assert snake_case("_methodName") == "_methodName"
assert snake_case("make_gRPC_request") == "make_gRPC_request"

assert snake_case("value1") == "value1"
assert snake_case("value1string") == "value1string"

# It is difficult to cover all the cases with a simple algorithm...
# "GetValueAsUInt32" -> "get_value_as_u_int_32"
Loading