Skip to content

Commit 5759e32

Browse files
committed
Add ability to provide metadata, timeout & deadline args to requests
This is an enhancement of the ServiceStub abstract class that makes it more useful by making it possible to pass all arguments supported by the underlying grpclib request function. It extends to the existing high level API by allowing values to be set on the stub instance, and the low level API by allowing values to be set per call.
1 parent c762c9c commit 5759e32

File tree

2 files changed

+149
-11
lines changed

2 files changed

+149
-11
lines changed

betterproto/__init__.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
Any,
1212
AsyncGenerator,
1313
Callable,
14+
Collection,
1415
Dict,
1516
Generator,
1617
Iterable,
1718
List,
19+
Mapping,
1820
Optional,
1921
SupportsBytes,
2022
Tuple,
@@ -1000,32 +1002,80 @@ def _get_wrapper(proto_type: str) -> Type:
10001002
}[proto_type]
10011003

10021004

1005+
_Value = Union[str, bytes]
1006+
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
1007+
1008+
10031009
class ServiceStub(ABC):
10041010
"""
10051011
Base class for async gRPC service stubs.
10061012
"""
10071013

1008-
def __init__(self, channel: grpclib.client.Channel) -> None:
1014+
def __init__(
1015+
self,
1016+
channel: grpclib.client.Channel,
1017+
*,
1018+
timeout: Optional[float] = None,
1019+
deadline: Optional[grpclib.metadata.Deadline] = None,
1020+
metadata: Optional[_MetadataLike] = None,
1021+
) -> None:
10091022
self.channel = channel
1023+
self.timeout = timeout
1024+
self.deadline = deadline
1025+
self.metadata = metadata
1026+
1027+
def __resolve_request_kwargs(
1028+
self,
1029+
timeout: Optional[float],
1030+
deadline: Optional[grpclib.metadata.Deadline],
1031+
metadata: Optional[_MetadataLike],
1032+
):
1033+
return {
1034+
"timeout": self.timeout if timeout is None else timeout,
1035+
"deadline": self.deadline if deadline is None else deadline,
1036+
"metadata": self.metadata if metadata is None else metadata,
1037+
}
10101038

10111039
async def _unary_unary(
1012-
self, route: str, request: "IProtoMessage", response_type: Type[T]
1040+
self,
1041+
route: str,
1042+
request: "IProtoMessage",
1043+
response_type: Type[T],
1044+
*,
1045+
timeout: Optional[float] = None,
1046+
deadline: Optional[grpclib.metadata.Deadline] = None,
1047+
metadata: Optional[_MetadataLike] = None,
10131048
) -> T:
10141049
"""Make a unary request and return the response."""
10151050
async with self.channel.request(
1016-
route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type
1051+
route,
1052+
grpclib.const.Cardinality.UNARY_UNARY,
1053+
type(request),
1054+
response_type,
1055+
**self.__resolve_request_kwargs(timeout, deadline, metadata),
10171056
) as stream:
10181057
await stream.send_message(request, end=True)
10191058
response = await stream.recv_message()
10201059
assert response is not None
10211060
return response
10221061

10231062
async def _unary_stream(
1024-
self, route: str, request: "IProtoMessage", response_type: Type[T]
1063+
self,
1064+
route: str,
1065+
request: "IProtoMessage",
1066+
response_type: Type[T],
1067+
*,
1068+
timeout: Optional[float] = None,
1069+
deadline: Optional[grpclib.metadata.Deadline] = None,
1070+
metadata: Optional[_MetadataLike] = None,
10251071
) -> AsyncGenerator[T, None]:
10261072
"""Make a unary request and return the stream response iterator."""
10271073
async with self.channel.request(
1028-
route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type
1074+
route,
1075+
grpclib.const.Cardinality.UNARY_STREAM,
1076+
type(request),
1077+
response_type,
1078+
**self.__resolve_request_kwargs(timeout, deadline, metadata),
10291079
) as stream:
10301080
await stream.send_message(request, end=True)
10311081
async for message in stream:

betterproto/tests/test_service_stub.py

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,24 @@
77

88

99
class ExampleService:
10+
def __init__(self, test_hook=None):
11+
# This lets us pass assertions to the servicer ;)
12+
self.test_hook = test_hook
1013

11-
async def DoThing(self, stream: 'grpclib.server.Stream[DoThingRequest, DoThingResponse]'):
14+
async def DoThing(
15+
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
16+
):
1217
request = await stream.recv_message()
18+
print("self.test_hook", self.test_hook)
19+
if self.test_hook is not None:
20+
self.test_hook(stream)
1321
for iteration in range(request.iterations):
1422
pass
1523
await stream.send_message(DoThingResponse(request.iterations))
1624

17-
1825
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
1926
return {
20-
'/service.ExampleService/DoThing': grpclib.const.Handler(
27+
"/service.ExampleService/DoThing": grpclib.const.Handler(
2128
self.DoThing,
2229
grpclib.const.Cardinality.UNARY_UNARY,
2330
DoThingRequest,
@@ -26,10 +33,91 @@ def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
2633
}
2734

2835

36+
async def _test_stub(stub, iterations=42, **kwargs):
37+
response = await stub.do_thing(iterations=iterations)
38+
assert response.successful_iterations == iterations
39+
40+
41+
def _get_server_side_test(deadline, metadata):
42+
def server_side_test(stream):
43+
assert stream.deadline._timestamp == pytest.approx(
44+
deadline._timestamp, 1
45+
), "The provided deadline should be recieved serverside"
46+
assert (
47+
stream.metadata["authorization"] == metadata["authorization"]
48+
), "The provided authorization metadata should be recieved serverside"
49+
50+
return server_side_test
51+
52+
2953
@pytest.mark.asyncio
3054
async def test_simple_service_call():
31-
ITERATIONS = 42
3255
async with ChannelFor([ExampleService()]) as channel:
33-
stub = ExampleServiceStub(channel)
34-
response = await stub.do_thing(iterations=ITERATIONS)
56+
await _test_stub(ExampleServiceStub(channel))
57+
58+
59+
@pytest.mark.asyncio
60+
async def test_service_call_with_upfront_request_params():
61+
# Setting deadline
62+
deadline = grpclib.metadata.Deadline.from_timeout(22)
63+
metadata = {"authorization": "12345"}
64+
async with ChannelFor(
65+
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
66+
) as channel:
67+
await _test_stub(
68+
ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
69+
)
70+
71+
# Setting timeout
72+
timeout = 99
73+
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
74+
metadata = {"authorization": "12345"}
75+
async with ChannelFor(
76+
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
77+
) as channel:
78+
await _test_stub(
79+
ExampleServiceStub(channel, timeout=timeout, metadata=metadata)
80+
)
81+
82+
83+
@pytest.mark.asyncio
84+
async def test_service_call_lower_level_with_overrides():
85+
ITERATIONS = 99
86+
87+
# Setting deadline
88+
deadline = grpclib.metadata.Deadline.from_timeout(22)
89+
metadata = {"authorization": "12345"}
90+
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
91+
kwarg_metadata = {"authorization": "12345"}
92+
async with ChannelFor(
93+
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
94+
) as channel:
95+
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
96+
response = await stub._unary_unary(
97+
"/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse,
98+
deadline=kwarg_deadline,
99+
metadata=kwarg_metadata,
100+
)
101+
assert response.successful_iterations == ITERATIONS
102+
103+
# Setting timeout
104+
timeout = 99
105+
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
106+
metadata = {"authorization": "12345"}
107+
kwarg_timeout = 9000
108+
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
109+
kwarg_metadata = {"authorization": "09876"}
110+
async with ChannelFor(
111+
[
112+
ExampleService(
113+
test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata)
114+
)
115+
]
116+
) as channel:
117+
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
118+
response = await stub._unary_unary(
119+
"/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse,
120+
timeout=kwarg_timeout,
121+
metadata=kwarg_metadata,
122+
)
35123
assert response.successful_iterations == ITERATIONS

0 commit comments

Comments
 (0)