Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.

Commit 95293b9

Browse files
Fix snake case (#39)
* Make snake case safer * Rename default parameter name * Add test * Update version
1 parent 435e0f2 commit 95293b9

File tree

6 files changed

+64
-56
lines changed

6 files changed

+64
-56
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "betterproto2_compiler"
3-
version = "0.2.1"
3+
version = "0.2.2"
44
description = "Compiler for betterproto2"
55
authors = ["Adrien Vannson <[email protected]>", "Daniel G. Taylor <[email protected]>"]
66
readme = "README.md"

src/betterproto2_compiler/casing.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,43 +21,25 @@ def safe_snake_case(value: str) -> str:
2121
return value
2222

2323

24-
def snake_case(value: str, strict: bool = True) -> str:
24+
def snake_case(name: str) -> str:
2525
"""
2626
Join words with an underscore into lowercase and remove symbols.
27+
"""
2728

28-
Parameters
29-
-----------
30-
value: :class:`str`
31-
The value to convert.
32-
strict: :class:`bool`
33-
Whether or not to force single underscores.
29+
# If there are already underscores in the name, don't break it
30+
if "_" in name or not any([c.isupper() for c in name]):
31+
return name
3432

35-
Returns
36-
--------
37-
:class:`str`
38-
The value in snake_case.
39-
"""
33+
# Add an underscore before capital letters
34+
name = re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name)
4035

41-
def substitute_word(symbols: str, word: str, is_start: bool) -> str:
42-
if not word:
43-
return ""
44-
if strict:
45-
delimiter_count = 0 if is_start else 1 # Single underscore if strict.
46-
elif is_start:
47-
delimiter_count = len(symbols)
48-
elif word.isupper() or word.islower():
49-
delimiter_count = max(1, len(symbols)) # Preserve all delimiters if not strict.
50-
else:
51-
delimiter_count = len(symbols) + 1 # Extra underscore for leading capital.
36+
# Add an underscore before capital letters following an acronym
37+
name = re.sub(r"(?<=[A-Z])([A-Z])(?=[a-z])", r"_\1", name)
5238

53-
return ("_" * delimiter_count) + word.lower()
39+
# Add an underscore before digits
40+
name = re.sub(r"(?<=[a-zA-Z])([0-9])", r"_\1", name)
5441

55-
snake = re.sub(
56-
f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
57-
lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
58-
value,
59-
)
60-
return snake
42+
return name.lower()
6143

6244

6345
def pascal_case(value: str, strict: bool = True) -> str:

src/betterproto2_compiler/compile/importing.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -
114114
Returns a reference to a python type located in the root, i.e. sys.path.
115115
"""
116116
string_import = ".".join(py_package)
117-
string_alias = safe_snake_case(string_import)
117+
string_alias = "__".join([safe_snake_case(name) for name in py_package])
118118
imports.add(f"import {string_import} as {string_alias}")
119119
return f"{string_alias}.{py_type}"
120120

@@ -175,6 +175,11 @@ def reference_cousin(current_package: list[str], imports: set[str], py_package:
175175
string_from = f".{'.' * distance_up}" + ".".join(py_package[len(shared_ancestry) : -1])
176176
string_import = py_package[-1]
177177
# Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
178-
string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__"
178+
# string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__"
179+
string_alias = (
180+
f"{'_' * distance_up}"
181+
+ "__".join([safe_snake_case(name) for name in py_package[len(shared_ancestry) :]])
182+
+ "__"
183+
)
179184
imports.add(f"from {string_from} import {string_import} as {string_alias}")
180185
return f"{string_alias}.{py_type}"

src/betterproto2_compiler/plugin/models.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -600,17 +600,6 @@ def is_input_msg_empty(self: "ServiceMethodCompiler") -> bool:
600600

601601
return not bool(msg.fields)
602602

603-
@property
604-
def py_input_message_param(self) -> str:
605-
"""Param name corresponding to py_input_message_type.
606-
607-
Returns
608-
-------
609-
str
610-
Param name corresponding to py_input_message_type.
611-
"""
612-
return pythonize_field_name(self.py_input_message_type)
613-
614603
@property
615604
def py_output_message_type(self) -> str:
616605
"""String representation of the Python type corresponding to the

src/betterproto2_compiler/templates/template.py.j2

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,15 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
9898
{% for method in service.methods %}
9999
async def {{ method.py_name }}(self
100100
{%- if not method.client_streaming -%}
101-
, {{ method.py_input_message_param }}:
101+
, message:
102102
{%- if method.is_input_msg_empty -%}
103103
"{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None
104104
{%- else -%}
105105
"{{ method.py_input_message_type }}"
106106
{%- endif -%}
107107
{%- else -%}
108108
{# Client streaming: need a request iterator instead #}
109-
, {{ method.py_input_message_param }}_iterator: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}"
109+
, messages: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}"
110110
{%- endif -%}
111111
,
112112
*
@@ -128,7 +128,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
128128
{% if method.client_streaming %}
129129
async for response in self._stream_stream(
130130
"{{ method.route }}",
131-
{{ method.py_input_message_param }}_iterator,
131+
messages,
132132
{{ method.py_input_message_type }},
133133
{{ method.py_output_message_type }},
134134
timeout=timeout,
@@ -138,13 +138,13 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
138138
yield response
139139
{% else %}{# i.e. not client streaming #}
140140
{% if method.is_input_msg_empty %}
141-
if {{ method.py_input_message_param }} is None:
142-
{{ method.py_input_message_param }} = {{ method.py_input_message_type }}()
141+
if message is None:
142+
message = {{ method.py_input_message_type }}()
143143

144144
{% endif %}
145145
async for response in self._unary_stream(
146146
"{{ method.route }}",
147-
{{ method.py_input_message_param }},
147+
message,
148148
{{ method.py_output_message_type }},
149149
timeout=timeout,
150150
deadline=deadline,
@@ -157,7 +157,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
157157
{% if method.client_streaming %}
158158
return await self._stream_unary(
159159
"{{ method.route }}",
160-
{{ method.py_input_message_param }}_iterator,
160+
messages,
161161
{{ method.py_input_message_type }},
162162
{{ method.py_output_message_type }},
163163
timeout=timeout,
@@ -166,13 +166,13 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
166166
)
167167
{% else %}{# i.e. not client streaming #}
168168
{% if method.is_input_msg_empty %}
169-
if {{ method.py_input_message_param }} is None:
170-
{{ method.py_input_message_param }} = {{ method.py_input_message_type }}()
169+
if message is None:
170+
message = {{ method.py_input_message_type }}()
171171

172172
{% endif %}
173173
return await self._unary_unary(
174174
"{{ method.route }}",
175-
{{ method.py_input_message_param }},
175+
message,
176176
{{ method.py_output_message_type }},
177177
timeout=timeout,
178178
deadline=deadline,
@@ -199,10 +199,10 @@ class {{ service.py_name }}Base(ServiceBase):
199199
{% for method in service.methods %}
200200
async def {{ method.py_name }}(self
201201
{%- if not method.client_streaming -%}
202-
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
202+
, message: "{{ method.py_input_message_type }}"
203203
{%- else -%}
204204
{# Client streaming: need a request iterator instead #}
205-
, {{ method.py_input_message_param }}_iterator: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }}
205+
, messages: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }}
206206
{%- endif -%}
207207
) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
208208
{% if method.comment %}

tests/test_casing.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
def test_snake_case() -> None:
2+
from betterproto2_compiler.casing import snake_case
3+
4+
# Simple renaming
5+
assert snake_case("methodName") == "method_name"
6+
assert snake_case("MethodName") == "method_name"
7+
8+
# Don't break acronyms
9+
assert snake_case("HTTPRequest") == "http_request"
10+
assert snake_case("RequestHTTP") == "request_http"
11+
assert snake_case("HTTPRequest2") == "http_request_2"
12+
assert snake_case("RequestHTTP2") == "request_http_2"
13+
assert snake_case("GetAResponse") == "get_a_response"
14+
15+
# Split digits
16+
assert snake_case("Get2025Results") == "get_2025_results"
17+
assert snake_case("Get10yResults") == "get_10y_results"
18+
19+
# If the name already contains an underscore or is lowercase, don't change it at all.
20+
# There is a risk of breaking names otherwise.
21+
assert snake_case("aaa_123_bbb") == "aaa_123_bbb"
22+
assert snake_case("aaa_123bbb") == "aaa_123bbb"
23+
assert snake_case("aaa123_bbb") == "aaa123_bbb"
24+
assert snake_case("get_HTTP_response") == "get_HTTP_response"
25+
assert snake_case("_methodName") == "_methodName"
26+
assert snake_case("make_gRPC_request") == "make_gRPC_request"
27+
28+
assert snake_case("value1") == "value1"
29+
assert snake_case("value1string") == "value1string"
30+
31+
# It is difficult to cover all the cases with a simple algorithm...
32+
# "GetValueAsUInt32" -> "get_value_as_u_int_32"

0 commit comments

Comments
 (0)