Skip to content

Commit a757da1

Browse files
hoznnat-n
authored andcommitted
Adding basic support (untested) for client streaming
1 parent a46979c commit a757da1

File tree

3 files changed

+82
-5
lines changed

3 files changed

+82
-5
lines changed

betterproto/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
Collection,
1515
Dict,
1616
Generator,
17+
Iterator,
1718
List,
1819
Mapping,
1920
Optional,
2021
Set,
22+
SupportsBytes,
2123
Tuple,
2224
Type,
2325
TypeVar,
@@ -431,6 +433,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
431433

432434
# Bound type variable to allow methods to return `self` of subclasses
433435
T = TypeVar("T", bound="Message")
436+
ST = TypeVar("ST", bound="IProtoMessage")
434437

435438

436439
class ProtoClassMetadata:
@@ -1104,3 +1107,38 @@ async def _unary_stream(
11041107
await stream.send_message(request, end=True)
11051108
async for message in stream:
11061109
yield message
1110+
1111+
async def _stream_unary(
1112+
self,
1113+
route: str,
1114+
request_iterator: Iterator["IProtoMessage"],
1115+
request_type: Type[ST],
1116+
response_type: Type[T],
1117+
) -> T:
1118+
"""Make a stream request and return the response."""
1119+
async with self.channel.request(
1120+
route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type
1121+
) as stream:
1122+
for message in request_iterator:
1123+
await stream.send_message(message)
1124+
await stream.send_request(end=True)
1125+
response = await stream.recv_message()
1126+
assert response is not None
1127+
return response
1128+
1129+
async def _stream_stream(
1130+
self,
1131+
route: str,
1132+
request_iterator: Iterator["IProtoMessage"],
1133+
request_type: Type[ST],
1134+
response_type: Type[T],
1135+
) -> AsyncGenerator[T, None]:
1136+
"""Make a stream request and return the stream response iterator."""
1137+
async with self.channel.request(
1138+
route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type
1139+
) as stream:
1140+
for message in request_iterator:
1141+
await stream.send_message(message)
1142+
await stream.send_request(end=True)
1143+
async for message in stream:
1144+
yield message

betterproto/plugin.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,6 @@ def generate_code(request, response):
311311
}
312312

313313
for j, method in enumerate(service.method):
314-
if method.client_streaming:
315-
raise NotImplementedError("Client streaming not yet supported")
316314

317315
input_message = None
318316
input_type = get_ref_type(
@@ -350,6 +348,9 @@ def generate_code(request, response):
350348
if method.server_streaming:
351349
output["typing_imports"].add("AsyncGenerator")
352350

351+
if method.client_streaming:
352+
output["typing_imports"].add("Iterator")
353+
353354
output["services"].append(data)
354355

355356
output["imports"] = sorted(output["imports"])

betterproto/templates/template.py.j2

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
6363

6464
{% endif %}
6565
{% for method in service.methods %}
66-
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
66+
async def {{ method.py_name }}(self
67+
{%- if not method.client_streaming -%}
68+
{%- if method.input_message and method.input_message.properties -%}, *,
69+
{%- for field in method.input_message.properties -%}
70+
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
71+
Optional[{{ field.type }}]
72+
{%- else -%}
73+
{{ field.type }}
74+
{%- endif -%} = {{ field.zero }}
75+
{%- if not loop.last %}, {% endif -%}
76+
{%- endfor -%}
77+
{%- endif -%}
78+
{%- else -%}
79+
{# Client streaming: need a request iterator instead #}
80+
, request_iterator: Iterator["{{ method.input }}"]
81+
{%- endif -%}
82+
) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
6783
{% if method.comment %}
6884
{{ method.comment }}
6985

7086
{% endif %}
87+
{% if not method.client_streaming %}
7188
request = {{ method.input }}()
7289
{% for field in method.input_message.properties %}
7390
{% if field.field_type == 'message' %}
@@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
7794
request.{{ field.py_name }} = {{ field.py_name }}
7895
{% endif %}
7996
{% endfor %}
97+
{% endif %}
8098

8199
{% if method.server_streaming %}
100+
{% if method.client_streaming %}
101+
async for response in self._stream_stream(
102+
"{{ method.route }}",
103+
request_iterator,
104+
{{ method.input }},
105+
{{ method.output }},
106+
):
107+
yield response
108+
{% else %}{# i.e. not client streaming #}
82109
async for response in self._unary_stream(
83110
"{{ method.route }}",
84111
request,
85112
{{ method.output }},
86113
):
87114
yield response
88-
{% else %}
115+
116+
{% endif %}{# if client streaming #}
117+
{% else %}{# i.e. not server streaming #}
118+
{% if method.client_streaming %}
119+
return await self._stream_unary(
120+
"{{ method.route }}",
121+
request_iterator,
122+
{{ method.input }},
123+
{{ method.output }}
124+
)
125+
{% else %}{# i.e. not client streaming #}
89126
return await self._unary_unary(
90127
"{{ method.route }}",
91128
request,
92-
{{ method.output }},
129+
{{ method.output }}
93130
)
131+
{% endif %}{# client streaming #}
94132
{% endif %}
95133

96134
{% endfor %}

0 commit comments

Comments
 (0)