Skip to content

Commit 29cb342

Browse files
Allow calling a method without providing the request if the request is empty (#17)
1 parent 6e99b60 commit 29cb342

File tree

4 files changed

+62
-3
lines changed

4 files changed

+62
-3
lines changed

src/betterproto/plugin/models.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@
6464
)
6565
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
6666

67-
from .. import which_one_of
68-
from ..compile.importing import get_type_reference
67+
from ..compile.importing import get_type_reference, parse_source_type_name
6968
from ..compile.naming import (
7069
pythonize_class_name,
7170
pythonize_enum_member_name,
@@ -703,6 +702,14 @@ def py_input_message_type(self) -> str:
703702
pydantic=self.output_file.pydantic_dataclasses,
704703
).strip('"')
705704

705+
@property
706+
def is_input_msg_empty(self: "ServiceMethodCompiler") -> bool:
707+
package, name = parse_source_type_name(self.proto_obj.input_type, self.request)
708+
709+
msg = self.request.output_packages[package].messages[name]
710+
711+
return not bool(msg.fields)
712+
706713
@property
707714
def py_input_message_param(self) -> str:
708715
"""Param name corresponding to py_input_message_type.

src/betterproto/templates/template.py.j2

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
8383
{% for method in service.methods %}
8484
async def {{ method.py_name }}(self
8585
{%- if not method.client_streaming -%}
86-
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
86+
, {{ method.py_input_message_param }}:
87+
{%- if method.is_input_msg_empty -%}
88+
"{{ method.py_input_message_type }} | None" = None
89+
{%- else -%}
90+
"{{ method.py_input_message_type }}"
91+
{%- endif -%}
8792
{%- else -%}
8893
{# Client streaming: need a request iterator instead #}
8994
, {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}"
@@ -117,6 +122,11 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
117122
):
118123
yield response
119124
{% else %}{# i.e. not client streaming #}
125+
{% if method.is_input_msg_empty %}
126+
if {{ method.py_input_message_param }} is None:
127+
{{ method.py_input_message_param }} = {{ method.py_input_message_type }}()
128+
129+
{% endif %}
120130
async for response in self._unary_stream(
121131
"{{ method.route }}",
122132
{{ method.py_input_message_param }},
@@ -140,6 +150,11 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
140150
metadata=metadata,
141151
)
142152
{% else %}{# i.e. not client streaming #}
153+
{% if method.is_input_msg_empty %}
154+
if {{ method.py_input_message_param }} is None:
155+
{{ method.py_input_message_param }} = {{ method.py_input_message_type }}()
156+
157+
{% endif %}
143158
return await self._unary_unary(
144159
"{{ method.route }}",
145160
{{ method.py_input_message_param }},
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
package rpc_empty_input_message;
4+
5+
message Test {}
6+
7+
message Response {
8+
int32 v = 1;
9+
}
10+
11+
service Service {
12+
rpc read(Test) returns (Response);
13+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
from grpclib.testing import ChannelFor
3+
4+
5+
@pytest.mark.asyncio
6+
async def test_rpc_input_message():
7+
from tests.output_betterproto.rpc_empty_input_message import (
8+
Response,
9+
ServiceBase,
10+
ServiceStub,
11+
Test,
12+
)
13+
14+
class Service(ServiceBase):
15+
async def read(self, test: "Test") -> "Response":
16+
return Response(v=42)
17+
18+
async with ChannelFor([Service()]) as channel:
19+
client = ServiceStub(channel)
20+
21+
assert (await client.read(Test())).v == 42
22+
23+
# Check that we can call the method without providing the message
24+
assert (await client.read()).v == 42

0 commit comments

Comments
 (0)