Skip to content

Commit 18a518e

Browse files
authored
Expose timeout, deadline and metadata parameters from grpclib (#352)
1 parent 62da35b commit 18a518e

File tree

4 files changed

+101
-18
lines changed

4 files changed

+101
-18
lines changed

src/betterproto/grpc/grpclib_client.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from grpclib.metadata import Deadline
2323

2424

25-
_Value = Union[str, bytes]
26-
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
27-
_MessageLike = Union[T, ST]
28-
_MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
25+
Value = Union[str, bytes]
26+
MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
27+
MessageLike = Union[T, ST]
28+
MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
2929

3030

3131
class ServiceStub(ABC):
@@ -39,7 +39,7 @@ def __init__(
3939
*,
4040
timeout: Optional[float] = None,
4141
deadline: Optional["Deadline"] = None,
42-
metadata: Optional[_MetadataLike] = None,
42+
metadata: Optional[MetadataLike] = None,
4343
) -> None:
4444
self.channel = channel
4545
self.timeout = timeout
@@ -50,7 +50,7 @@ def __resolve_request_kwargs(
5050
self,
5151
timeout: Optional[float],
5252
deadline: Optional["Deadline"],
53-
metadata: Optional[_MetadataLike],
53+
metadata: Optional[MetadataLike],
5454
):
5555
return {
5656
"timeout": self.timeout if timeout is None else timeout,
@@ -61,12 +61,12 @@ def __resolve_request_kwargs(
6161
async def _unary_unary(
6262
self,
6363
route: str,
64-
request: _MessageLike,
64+
request: MessageLike,
6565
response_type: Type[T],
6666
*,
6767
timeout: Optional[float] = None,
6868
deadline: Optional["Deadline"] = None,
69-
metadata: Optional[_MetadataLike] = None,
69+
metadata: Optional[MetadataLike] = None,
7070
) -> T:
7171
"""Make a unary request and return the response."""
7272
async with self.channel.request(
@@ -84,12 +84,12 @@ async def _unary_unary(
8484
async def _unary_stream(
8585
self,
8686
route: str,
87-
request: _MessageLike,
87+
request: MessageLike,
8888
response_type: Type[T],
8989
*,
9090
timeout: Optional[float] = None,
9191
deadline: Optional["Deadline"] = None,
92-
metadata: Optional[_MetadataLike] = None,
92+
metadata: Optional[MetadataLike] = None,
9393
) -> AsyncIterator[T]:
9494
"""Make a unary request and return the stream response iterator."""
9595
async with self.channel.request(
@@ -106,13 +106,13 @@ async def _unary_stream(
106106
async def _stream_unary(
107107
self,
108108
route: str,
109-
request_iterator: _MessageSource,
109+
request_iterator: MessageSource,
110110
request_type: Type[ST],
111111
response_type: Type[T],
112112
*,
113113
timeout: Optional[float] = None,
114114
deadline: Optional["Deadline"] = None,
115-
metadata: Optional[_MetadataLike] = None,
115+
metadata: Optional[MetadataLike] = None,
116116
) -> T:
117117
"""Make a stream request and return the response."""
118118
async with self.channel.request(
@@ -130,13 +130,13 @@ async def _stream_unary(
130130
async def _stream_stream(
131131
self,
132132
route: str,
133-
request_iterator: _MessageSource,
133+
request_iterator: MessageSource,
134134
request_type: Type[ST],
135135
response_type: Type[T],
136136
*,
137137
timeout: Optional[float] = None,
138138
deadline: Optional["Deadline"] = None,
139-
metadata: Optional[_MetadataLike] = None,
139+
metadata: Optional[MetadataLike] = None,
140140
) -> AsyncIterator[T]:
141141
"""
142142
Make a stream request and return an AsyncIterator to iterate over response
@@ -161,7 +161,7 @@ async def _stream_stream(
161161
raise
162162

163163
@staticmethod
164-
async def _send_messages(stream, messages: _MessageSource):
164+
async def _send_messages(stream, messages: MessageSource):
165165
if isinstance(messages, AsyncIterable):
166166
async for message in messages:
167167
await stream.send_message(message)

src/betterproto/plugin/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ class OutputTemplate:
232232
messages: List["MessageCompiler"] = field(default_factory=list)
233233
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
234234
services: List["ServiceCompiler"] = field(default_factory=list)
235+
imports_type_checking_only: Set[str] = field(default_factory=set)
235236

236237
@property
237238
def package(self) -> str:
@@ -679,6 +680,15 @@ def __post_init__(self) -> None:
679680
if self.client_streaming or self.server_streaming:
680681
self.output_file.typing_imports.add("AsyncIterator")
681682

683+
# add imports required for request arguments timeout, deadline and metadata
684+
self.output_file.typing_imports.add("Optional")
685+
self.output_file.imports_type_checking_only.add(
686+
"from betterproto.grpc.grpclib_client import MetadataLike"
687+
)
688+
self.output_file.imports_type_checking_only.add(
689+
"from grpclib.metadata import Deadline"
690+
)
691+
682692
super().__post_init__() # check for unset fields
683693

684694
@property

src/betterproto/templates/template.py.j2

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ from betterproto.grpc.grpclib_server import ServiceBase
2020
import grpclib
2121
{% endif %}
2222

23+
{% if output_file.imports_type_checking_only %}
24+
from typing import TYPE_CHECKING
25+
26+
if TYPE_CHECKING:
27+
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
28+
{% endfor %}
29+
{% endif %}
2330

2431
{% if output_file.enums %}{% for enum in output_file.enums %}
2532
class {{ enum.py_name }}(betterproto.Enum):
@@ -86,6 +93,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
8693
{# Client streaming: need a request iterator instead #}
8794
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
8895
{%- endif -%}
96+
, timeout: Optional[float] = None
97+
, deadline: Optional["Deadline"] = None
98+
, metadata: Optional["_MetadataLike"] = None
8999
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
90100
{% if method.comment %}
91101
{{ method.comment }}
@@ -98,13 +108,19 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
98108
{{ method.py_input_message_param }}_iterator,
99109
{{ method.py_input_message_type }},
100110
{{ method.py_output_message_type.strip('"') }},
111+
timeout=timeout,
112+
deadline=deadline,
113+
metadata=metadata,
101114
):
102115
yield response
103116
{% else %}{# i.e. not client streaming #}
104117
async for response in self._unary_stream(
105118
"{{ method.route }}",
106119
{{ method.py_input_message_param }},
107120
{{ method.py_output_message_type.strip('"') }},
121+
timeout=timeout,
122+
deadline=deadline,
123+
metadata=metadata,
108124
):
109125
yield response
110126

@@ -115,13 +131,19 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
115131
"{{ method.route }}",
116132
{{ method.py_input_message_param }}_iterator,
117133
{{ method.py_input_message_type }},
118-
{{ method.py_output_message_type.strip('"') }}
134+
{{ method.py_output_message_type.strip('"') }},
135+
timeout=timeout,
136+
deadline=deadline,
137+
metadata=metadata,
119138
)
120139
{% else %}{# i.e. not client streaming #}
121140
return await self._unary_unary(
122141
"{{ method.route }}",
123142
{{ method.py_input_message_param }},
124-
{{ method.py_output_message_type.strip('"') }}
143+
{{ method.py_output_message_type.strip('"') }},
144+
timeout=timeout,
145+
deadline=deadline,
146+
metadata=metadata,
125147
)
126148
{% endif %}{# client streaming #}
127149
{% endif %}

tests/grpc/test_grpclib_client.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import asyncio
22
import sys
3+
import uuid
34

45
import grpclib
56
import grpclib.metadata
67
import grpclib.server
8+
import grpclib.client
79
import pytest
810
from betterproto.grpc.util.async_channel import AsyncChannel
911
from grpclib.testing import ChannelFor
@@ -18,7 +20,7 @@
1820

1921

2022
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
21-
response = await client.do_thing(DoThingRequest(name=name))
23+
response = await client.do_thing(DoThingRequest(name=name), **kwargs)
2224
assert response.names == [name]
2325

2426

@@ -172,6 +174,55 @@ async def test_service_call_lower_level_with_overrides():
172174
assert response.names == [THING_TO_DO]
173175

174176

177+
@pytest.mark.asyncio
178+
@pytest.mark.parametrize(
179+
("overrides",),
180+
[
181+
(dict(timeout=10),),
182+
(dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
183+
(dict(metadata={"authorization": str(uuid.uuid4())}),),
184+
(dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
185+
],
186+
)
187+
async def test_service_call_high_level_with_overrides(mocker, overrides):
188+
request_spy = mocker.spy(grpclib.client.Channel, "request")
189+
name = str(uuid.uuid4())
190+
defaults = dict(
191+
timeout=99,
192+
deadline=grpclib.metadata.Deadline.from_timeout(99),
193+
metadata={"authorization": name},
194+
)
195+
196+
async with ChannelFor(
197+
[
198+
ThingService(
199+
test_hook=_assert_request_meta_received(
200+
deadline=grpclib.metadata.Deadline.from_timeout(
201+
overrides.get("timeout", 99)
202+
),
203+
metadata=overrides.get("metadata", defaults.get("metadata")),
204+
)
205+
)
206+
]
207+
) as channel:
208+
client = ThingServiceClient(channel, **defaults)
209+
await _test_client(client, name=name, **overrides)
210+
assert request_spy.call_count == 1
211+
212+
# for python <3.8 request_spy.call_args.kwargs do not work
213+
_, request_spy_call_kwargs = request_spy.call_args_list[0]
214+
215+
# ensure all overrides were successful
216+
for key, value in overrides.items():
217+
assert key in request_spy_call_kwargs
218+
assert request_spy_call_kwargs[key] == value
219+
220+
# ensure default values were retained
221+
for key in set(defaults.keys()) - set(overrides.keys()):
222+
assert key in request_spy_call_kwargs
223+
assert request_spy_call_kwargs[key] == defaults[key]
224+
225+
175226
@pytest.mark.asyncio
176227
async def test_async_gen_for_unary_stream_request():
177228
thing_name = "my milkshakes"

0 commit comments

Comments
 (0)