Skip to content

Commit b24c99f

Browse files
Change interface for HMAC
This changes the interface for communicating the API key and the signing secret. Specifically it removes them from "authentication options" and moves them to their own parameters instead. This allows using the default pattern for channel options in places that previously did not touch it. This is desired as soon most clients will require the API key and the signing secret to function properly, and we want to keep the changes needed to them minimal. Signed-off-by: Florian Wagner <[email protected]>
1 parent b2bd71f commit b24c99f

File tree

7 files changed

+142
-76
lines changed

7 files changed

+142
-76
lines changed

RELEASE_NOTES.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
## Upgrading
88

9-
<!-- Here goes notes on how to upgrade from previous versions, including deprecations and what they should be replaced with -->
9+
* Updated interface and behavior for HMAC
10+
11+
This introduces a new positional argument to `parse_grpc_uri`.
12+
If calling this function manually and passing `ChannelOptions`, it is recommended
13+
to switch to passing `ChannelOptions` via keyword argument.
1014

1115
## New Features
1216

src/frequenz/client/base/authentication.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
"""An Interceptor that adds the API key to a gRPC call."""
55

6-
import dataclasses
76
from typing import AsyncIterable, Callable
87

98
from grpc.aio import (
@@ -35,25 +34,17 @@ def _add_auth_header(
3534
client_call_details.metadata["key"] = key
3635

3736

38-
@dataclasses.dataclass(frozen=True)
39-
class AuthenticationOptions:
40-
"""Options for authenticating to the endpoint."""
41-
42-
api_key: str
43-
"""The API key to authenticate with."""
44-
45-
4637
# There is an issue in gRPC which means the type can not be specified correctly here.
4738
class AuthenticationInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
4839
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
4940

50-
def __init__(self, options: AuthenticationOptions):
41+
def __init__(self, api_key: str):
5142
"""Create an instance of the interceptor.
5243
5344
Args:
54-
options: The options for authenticating to the endpoint.
45+
api_key: The API key to send along for the request.
5546
"""
56-
self._key = options.api_key
47+
self._key = api_key
5748

5849
async def intercept_unary_unary(
5950
self,
@@ -83,13 +74,13 @@ async def intercept_unary_unary(
8374
class AuthenticationInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg]
8475
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
8576

86-
def __init__(self, options: AuthenticationOptions):
77+
def __init__(self, api_key: str):
8778
"""Create an instance of the interceptor.
8879
8980
Args:
90-
options: The options for authenticating to the endpoint.
81+
api_key: The API key to send along for the request.
9182
"""
92-
self._key = options.api_key
83+
self._key = api_key
9384

9485
async def intercept_unary_stream(
9586
self,

src/frequenz/client/base/channel.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import dataclasses
77
import pathlib
88
from datetime import timedelta
9-
from typing import assert_never
9+
from typing import Sequence, assert_never
1010
from urllib.parse import parse_qs, urlparse
1111

1212
from grpc import ssl_channel_credentials
@@ -17,17 +17,6 @@
1717
secure_channel,
1818
)
1919

20-
from .authentication import (
21-
AuthenticationInterceptorUnaryStream,
22-
AuthenticationInterceptorUnaryUnary,
23-
AuthenticationOptions,
24-
)
25-
from .signing import (
26-
SigningInterceptorUnaryStream,
27-
SigningInterceptorUnaryUnary,
28-
SigningOptions,
29-
)
30-
3120

3221
@dataclasses.dataclass(frozen=True)
3322
class SslOptions:
@@ -85,15 +74,10 @@ class ChannelOptions:
8574
keep_alive: KeepAliveOptions = KeepAliveOptions()
8675
"""HTTP2 keep-alive options for the channel."""
8776

88-
sign: SigningOptions | None = None
89-
"""Signing options for the channel."""
90-
91-
auth: AuthenticationOptions | None = None
92-
"""Authentication options for the channel."""
93-
9477

9578
def parse_grpc_uri(
9679
uri: str,
80+
interceptors: Sequence[ClientInterceptor] = (),
9781
/,
9882
defaults: ChannelOptions = ChannelOptions(),
9983
) -> Channel:
@@ -131,6 +115,8 @@ def parse_grpc_uri(
131115
132116
Args:
133117
uri: The gRPC URI specifying the connection parameters.
118+
interceptors: A list of interceptors to apply to the channel. They are applied
119+
in the same order as they are passed in (see grpc interceptor docs for details)
134120
defaults: The default options use to create the channel when not specified in
135121
the URI.
136122
@@ -199,19 +185,6 @@ def parse_grpc_uri(
199185
else None
200186
)
201187

202-
interceptors: list[ClientInterceptor] = []
203-
if defaults.auth is not None:
204-
interceptors += [
205-
AuthenticationInterceptorUnaryUnary(options=defaults.auth), # type: ignore [list-item]
206-
AuthenticationInterceptorUnaryStream(options=defaults.auth), # type: ignore [list-item]
207-
]
208-
209-
if defaults.sign is not None:
210-
interceptors += [
211-
SigningInterceptorUnaryUnary(options=defaults.sign), # type: ignore [list-item]
212-
SigningInterceptorUnaryStream(options=defaults.sign), # type: ignore [list-item]
213-
]
214-
215188
ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
216189
if ssl:
217190
return secure_channel(

src/frequenz/client/base/client.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,25 @@
66
import abc
77
import inspect
88
from collections.abc import Awaitable, Callable
9+
from types import EllipsisType
910
from typing import Any, Generic, Self, TypeVar, overload
1011

1112
from grpc.aio import (
1213
AioRpcError,
1314
Channel,
15+
ClientInterceptor,
1416
)
1517

18+
from .authentication import (
19+
AuthenticationInterceptorUnaryStream,
20+
AuthenticationInterceptorUnaryUnary,
21+
)
1622
from .channel import ChannelOptions, parse_grpc_uri
1723
from .exception import ApiClientError, ClientNotConnected
24+
from .signing import (
25+
SigningInterceptorUnaryStream,
26+
SigningInterceptorUnaryUnary,
27+
)
1828

1929
StubT = TypeVar("StubT")
2030
"""The type of the gRPC stub."""
@@ -153,13 +163,15 @@ async def main():
153163
instances.
154164
"""
155165

156-
def __init__(
166+
def __init__( # pylint: disable=too-many-arguments
157167
self,
158168
server_url: str,
159169
create_stub: Callable[[Channel], StubT],
160170
*,
161171
connect: bool = True,
162172
channel_defaults: ChannelOptions = ChannelOptions(),
173+
auth_key: str | None = None,
174+
sign_secret: str | None = None,
163175
) -> None:
164176
"""Create an instance and connect to the server.
165177
@@ -172,14 +184,21 @@ def __init__(
172184
called.
173185
channel_defaults: The default options for the gRPC channel to create using
174186
the server URL.
187+
auth_key: The API key to use when connecting to the service.
188+
sign_secret: The secret to use when creating message HMAC.
189+
175190
"""
176191
self._server_url: str = server_url
177192
self._create_stub: Callable[[Channel], StubT] = create_stub
178193
self._channel_defaults: ChannelOptions = channel_defaults
194+
self._auth_key = auth_key
195+
self._sign_secret = sign_secret
179196
self._channel: Channel | None = None
180197
self._stub: StubT | None = None
181198
if connect:
182-
self.connect(server_url)
199+
self.connect(
200+
server_url=self._server_url, auth_key=auth_key, sign_secret=sign_secret
201+
)
183202

184203
@property
185204
def server_url(self) -> str:
@@ -212,7 +231,13 @@ def is_connected(self) -> bool:
212231
"""Whether the client is connected to the server."""
213232
return self._channel is not None
214233

215-
def connect(self, server_url: str | None = None) -> None:
234+
def connect(
235+
self,
236+
server_url: str | None = None,
237+
*,
238+
auth_key: str | None | EllipsisType = ...,
239+
sign_secret: str | None | EllipsisType = ...,
240+
) -> None:
216241
"""Connect to the server, possibly using a new URL.
217242
218243
If the client is already connected and the URL is the same as the previous URL,
@@ -222,12 +247,41 @@ def connect(self, server_url: str | None = None) -> None:
222247
Args:
223248
server_url: The URL of the server to connect to. If not provided, the
224249
previously used URL is used.
250+
auth_key: The API key to use when connecting to the service. If an Ellipsis
251+
is provided, the previously used auth_key is used.
252+
sign_secret: The secret to use when creating message HMAC. If an Ellipsis is
253+
provided,
225254
"""
255+
reconnect = False
226256
if server_url is not None and server_url != self._server_url: # URL changed
227257
self._server_url = server_url
228-
elif self.is_connected:
258+
reconnect = True
259+
if auth_key is not ... and auth_key != self._auth_key:
260+
self._auth_key = auth_key
261+
reconnect = True
262+
if sign_secret is not ... and sign_secret != self._sign_secret:
263+
self._sign_secret = sign_secret
264+
reconnect = True
265+
if self.is_connected and not reconnect: # Desired connection already exists
229266
return
230-
self._channel = parse_grpc_uri(self._server_url, self._channel_defaults)
267+
268+
interceptors: list[ClientInterceptor] = []
269+
if self._auth_key is not None:
270+
interceptors += [
271+
AuthenticationInterceptorUnaryUnary(self._auth_key), # type: ignore [list-item]
272+
AuthenticationInterceptorUnaryStream(self._auth_key), # type: ignore [list-item]
273+
]
274+
if self._sign_secret is not None:
275+
interceptors += [
276+
SigningInterceptorUnaryUnary(self._sign_secret), # type: ignore [list-item]
277+
SigningInterceptorUnaryStream(self._sign_secret), # type: ignore [list-item]
278+
]
279+
280+
self._channel = parse_grpc_uri(
281+
self._server_url,
282+
interceptors,
283+
defaults=self._channel_defaults,
284+
)
231285
self._stub = self._create_stub(self._channel)
232286

233287
async def disconnect(self) -> None:

src/frequenz/client/base/signing.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
"""An Interceptor that adds HMAC signature of the metadata fields to a gRPC call."""
55

6-
import dataclasses
76
import hmac
87
import logging
98
import secrets
@@ -68,25 +67,17 @@ def _add_hmac(
6867
)
6968

7069

71-
@dataclasses.dataclass(frozen=True)
72-
class SigningOptions:
73-
"""Options for message signing of messages."""
74-
75-
secret: str
76-
"""The secret to sign the message with."""
77-
78-
7970
# There is an issue in gRPC which means the type can not be specified correctly here.
8071
class SigningInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
8172
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
8273

83-
def __init__(self, options: SigningOptions):
74+
def __init__(self, secret: str):
8475
"""Create an instance of the interceptor.
8576
8677
Args:
87-
options: The options for signing the message.
78+
secret: The secret used for signing the message.
8879
"""
89-
self._secret = options.secret.encode()
80+
self._secret = secret.encode()
9081

9182
async def intercept_unary_unary(
9283
self,
@@ -121,13 +112,13 @@ async def intercept_unary_unary(
121112
class SigningInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg]
122113
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
123114

124-
def __init__(self, options: SigningOptions):
115+
def __init__(self, secret: str):
125116
"""Create an instance of the interceptor.
126117
127118
Args:
128-
options: The options for signing the message.
119+
secret: The secret used for signing the message.
129120
"""
130-
self._secret = options.secret.encode()
121+
self._secret = secret.encode()
131122

132123
async def intercept_unary_stream(
133124
self,

tests/test_channel.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import pytest
1212
from grpc import ssl_channel_credentials
13-
from grpc.aio import Channel
13+
from grpc.aio import Channel, UnaryStreamClientInterceptor, UnaryUnaryClientInterceptor
1414

1515
from frequenz.client.base.channel import (
1616
ChannelOptions,
@@ -257,7 +257,7 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
257257
return_value=b"contents",
258258
) as get_contents_mock,
259259
):
260-
channel = parse_grpc_uri(uri, defaults)
260+
channel = parse_grpc_uri(uri, defaults=defaults)
261261

262262
assert channel == expected_channel
263263
expected_target = f"{expected_host}:{expected_port}"
@@ -318,11 +318,11 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
318318
expected_target,
319319
expected_credentials,
320320
expected_channel_options,
321-
interceptors=[],
321+
interceptors=(),
322322
)
323323
else:
324324
insecure_channel_mock.assert_called_once_with(
325-
expected_target, expected_channel_options, interceptors=[]
325+
expected_target, expected_channel_options, interceptors=()
326326
)
327327

328328

@@ -387,3 +387,22 @@ def test_invalid_url_no_default_port() -> None:
387387
match=r"The gRPC URI 'grpc://localhost' doesn't specify a port and there is no default.",
388388
):
389389
parse_grpc_uri(uri)
390+
391+
392+
def test_forward_interceptors() -> None:
393+
"""Test that the interceptors are properly forwarded to channel construction."""
394+
expected_channel = mock.MagicMock(name="mock_channel", spec=Channel)
395+
mock_interceptors = [
396+
mock.MagicMock(name="mock_interceptorUU", spec=UnaryUnaryClientInterceptor),
397+
mock.MagicMock(name="mock_interceptorUS", spec=UnaryStreamClientInterceptor),
398+
]
399+
uri = "grpc://localhost:2355?keep_alive=0"
400+
with mock.patch(
401+
"frequenz.client.base.channel.secure_channel",
402+
return_value=expected_channel,
403+
) as secure_channel_mock:
404+
_ = parse_grpc_uri(uri, mock_interceptors)
405+
406+
secure_channel_mock.assert_called_once_with(
407+
"localhost:2355", mock.ANY, None, interceptors=mock_interceptors
408+
)

0 commit comments

Comments
 (0)