Skip to content

Commit 4573a9b

Browse files
Add interceptors for stream requests
This adds authentication and signing interceptors for stream requests. Specifically for those where the client sends a single message and receives a stream. Signed-off-by: Florian Wagner <[email protected]>
1 parent 97946f0 commit 4573a9b

File tree

6 files changed

+180
-82
lines changed

6 files changed

+180
-82
lines changed

RELEASE_NOTES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
## Features
44

5-
* Added support for HMAC signing of client messages
5+
* Added support for HMAC signing of UnaryUnary client messages
6+
* Added support for HMAC signing of UnaryStream client messages
67

78
## Upgrading
89

src/frequenz/client/base/authentication.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,37 @@
44
"""An Interceptor that adds the API key to a gRPC call."""
55

66
import dataclasses
7-
from typing import Callable
7+
from typing import AsyncIterable, Callable
88

99
from grpc.aio import (
1010
ClientCallDetails,
1111
Metadata,
12+
UnaryStreamCall,
13+
UnaryStreamClientInterceptor,
1214
UnaryUnaryCall,
1315
UnaryUnaryClientInterceptor,
1416
)
1517

1618

19+
def _add_auth_header(
20+
key: str,
21+
client_call_details: ClientCallDetails,
22+
) -> None:
23+
"""Add the API key as a metadata field to the call.
24+
25+
The API key is used by the later sign interceptor to calculate the HMAC.
26+
In addition it is used as a first layer of authentication by the server.
27+
28+
Args:
29+
key: The API key to use for the service.
30+
client_call_details: The call details.
31+
"""
32+
if client_call_details.metadata is None:
33+
client_call_details.metadata = Metadata()
34+
35+
client_call_details.metadata["key"] = key
36+
37+
1738
@dataclasses.dataclass(frozen=True)
1839
class AuthenticationOptions:
1940
"""Options for authenticating to the endpoint."""
@@ -22,8 +43,8 @@ class AuthenticationOptions:
2243
"""The API key to authenticate with."""
2344

2445

25-
# There is an issue in gRPC that causes the type to be unspecifieable correctly here.
26-
class AuthenticationInterceptor(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
46+
# There is an issue in gRPC which means the type can not be specified correctly here.
47+
class AuthenticationInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
2748
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
2849

2950
def __init__(self, options: AuthenticationOptions):
@@ -54,24 +75,42 @@ async def intercept_unary_unary(
5475
Returns:
5576
The response object (this implementation does not modify the response).
5677
"""
57-
self.add_auth_header(
58-
client_call_details,
59-
)
78+
_add_auth_header(self._key, client_call_details)
6079
return await continuation(client_call_details, request)
6180

62-
def add_auth_header(
81+
82+
# There is an issue in gRPC which means the type can not be specified correctly here.
83+
class AuthenticationInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg]
84+
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
85+
86+
def __init__(self, options: AuthenticationOptions):
87+
"""Create an instance of the interceptor.
88+
89+
Args:
90+
options: The options for authenticating to the endpoint.
91+
"""
92+
self._key = options.api_key
93+
94+
async def intercept_unary_stream(
6395
self,
96+
continuation: Callable[
97+
[ClientCallDetails, object], UnaryStreamCall[object, object]
98+
],
6499
client_call_details: ClientCallDetails,
65-
) -> None:
66-
"""Add the API key as a metadata field to the call.
100+
request: object,
101+
) -> AsyncIterable[object] | UnaryStreamCall[object, object]:
102+
"""Intercept the call to add HMAC authentication to the metadata fields.
67103
68-
The API key is used by the later sign interceptor to calculate the HMAC.
69-
In addition it is used as a first layer of authentication by the server.
104+
This is a known method from the base class that is overridden.
70105
71106
Args:
107+
continuation: The next interceptor in the chain.
72108
client_call_details: The call details.
109+
request: The request object.
110+
111+
Returns:
112+
The response object (this implementation does not modify the response).
73113
"""
74-
if client_call_details.metadata is None:
75-
client_call_details.metadata = Metadata()
114+
_add_auth_header(self._key, client_call_details)
76115

77-
client_call_details.metadata["key"] = self._key
116+
return await continuation(client_call_details, request) # type: ignore

src/frequenz/client/base/channel.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@
1717
secure_channel,
1818
)
1919

20-
from .authentication import AuthenticationInterceptor, AuthenticationOptions
21-
from .signing import SigningInterceptor, SigningOptions
20+
from .authentication import (
21+
AuthenticationInterceptorUnaryStream,
22+
AuthenticationInterceptorUnaryUnary,
23+
AuthenticationOptions,
24+
)
25+
from .signing import (
26+
SigningInterceptorUnaryStream,
27+
SigningInterceptorUnaryUnary,
28+
SigningOptions,
29+
)
2230

2331

2432
@dataclasses.dataclass(frozen=True)
@@ -193,14 +201,16 @@ def parse_grpc_uri(
193201

194202
interceptors: list[ClientInterceptor] = []
195203
if defaults.auth is not None:
196-
interceptors.append(
197-
AuthenticationInterceptor(options=defaults.auth) # type: ignore[arg-type]
198-
)
204+
interceptors += [
205+
AuthenticationInterceptorUnaryUnary(options=defaults.auth), # type: ignore [list-item]
206+
AuthenticationInterceptorUnaryStream(options=defaults.auth), # type: ignore [list-item]
207+
]
199208

200209
if defaults.sign is not None:
201-
interceptors.append(
202-
SigningInterceptor(options=defaults.sign) # type: ignore[arg-type]
203-
)
210+
interceptors += [
211+
SigningInterceptorUnaryUnary(options=defaults.sign), # type: ignore [list-item]
212+
SigningInterceptorUnaryStream(options=defaults.sign), # type: ignore [list-item]
213+
]
204214

205215
ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
206216
if ssl:
@@ -278,13 +288,13 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
278288
}
279289

280290
if ssl is False:
281-
erros = []
291+
errors = []
282292
for opt_name, opt in ssl_opts.items():
283293
if opt is not None:
284-
erros.append(opt_name)
285-
if erros:
294+
errors.append(opt_name)
295+
if errors:
286296
raise ValueError(
287-
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but SSL is disabled",
297+
f"Option(s) {', '.join(errors)} found in URI {uri!r}, but SSL is disabled",
288298
)
289299

290300
keep_alive_option = options.pop("keep_alive", None)
@@ -298,13 +308,13 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
298308
}
299309

300310
if keep_alive is False:
301-
erros = []
311+
errors = []
302312
for opt_name, opt in keep_alive_opts.items():
303313
if opt is not None:
304-
erros.append(opt_name)
305-
if erros:
314+
errors.append(opt_name)
315+
if errors:
306316
raise ValueError(
307-
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but keep_alive is disabled",
317+
f"Option(s) {', '.join(errors)} found in URI {uri!r}, but keep_alive is disabled",
308318
)
309319

310320
if options:

src/frequenz/client/base/signing.py

Lines changed: 89 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,65 @@
99
import secrets
1010
import time
1111
from base64 import urlsafe_b64encode
12-
from typing import Any, Callable
12+
from typing import Any, AsyncIterable, Callable
1313

1414
from grpc.aio import (
1515
ClientCallDetails,
16+
UnaryStreamCall,
17+
UnaryStreamClientInterceptor,
1618
UnaryUnaryCall,
1719
UnaryUnaryClientInterceptor,
1820
)
1921

2022
_logger = logging.getLogger(__name__)
2123

2224

25+
def _add_hmac(
26+
secret: bytes, client_call_details: ClientCallDetails, ts: int, nonce: bytes
27+
) -> None:
28+
"""Add the HMAC authentication to the metadata fields of the call details.
29+
30+
The extra headers are directly added to the client_call details.
31+
32+
Args:
33+
secret: The symmetric secret shared with the service.
34+
client_call_details: The call details.
35+
ts: The timestamp to use for the HMAC.
36+
nonce: The nonce to use for the HMAC.
37+
"""
38+
if client_call_details.metadata is None:
39+
_logger.error(
40+
"No metadata found, cannot extract an api key. Therefore, cannot sign the request."
41+
)
42+
return
43+
44+
key: Any = client_call_details.metadata.get("key")
45+
if key is None:
46+
_logger.error("No key found in metadata, cannot sign the request.")
47+
return
48+
49+
# Make into a base10 integer string and then encode to bytes
50+
# We can not use a raw bytes timestamp as the underlying network library
51+
# really hates zero bytes in the metadata values
52+
ts_bytes = str(ts).encode()
53+
nonce_bytes = urlsafe_b64encode(nonce)
54+
55+
hmac_obj = hmac.new(secret, digestmod="sha256")
56+
hmac_obj.update(key.encode())
57+
hmac_obj.update(ts_bytes)
58+
hmac_obj.update(nonce_bytes)
59+
60+
# Once again, gRPC is mistyped
61+
hmac_obj.update(client_call_details.method.split(b"/")[-1]) # type: ignore[arg-type]
62+
63+
client_call_details.metadata["ts"] = ts_bytes
64+
client_call_details.metadata["nonce"] = nonce_bytes
65+
# By definition the signature is base64 encoded _without_ the padding, so we strip that
66+
client_call_details.metadata["sig"] = urlsafe_b64encode(hmac_obj.digest()).strip(
67+
b"="
68+
)
69+
70+
2371
@dataclasses.dataclass(frozen=True)
2472
class SigningOptions:
2573
"""Options for message signing of messages."""
@@ -28,8 +76,8 @@ class SigningOptions:
2876
"""The secret to sign the message with."""
2977

3078

31-
# There is an issue in gRPC that causes the type to be unspecifieable correctly here.
32-
class SigningInterceptor(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
79+
# There is an issue in gRPC which means the type can not be specified correctly here.
80+
class SigningInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
3381
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
3482

3583
def __init__(self, options: SigningOptions):
@@ -60,42 +108,51 @@ async def intercept_unary_unary(
60108
Returns:
61109
The response object (this implementation does not modify the response).
62110
"""
63-
self.add_hmac(
111+
_add_hmac(
112+
self._secret,
64113
client_call_details,
65-
int(time.time()).to_bytes(8, "big"),
114+
int(time.time()),
66115
secrets.token_bytes(16),
67116
)
68117
return await continuation(client_call_details, request)
69118

70-
def add_hmac(
71-
self, client_call_details: ClientCallDetails, ts: bytes, nonce: bytes
72-
) -> None:
73-
"""Add the HMAC authentication to the metadata fields of the call details.
74119

75-
The extra headers are directly added to the client_call details.
120+
# There is an issue in gRPC which means the type can not be specified correctly here.
121+
class SigningInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg]
122+
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
123+
124+
def __init__(self, options: SigningOptions):
125+
"""Create an instance of the interceptor.
76126
77127
Args:
128+
options: The options for signing the message.
129+
"""
130+
self._secret = options.secret.encode()
131+
132+
async def intercept_unary_stream(
133+
self,
134+
continuation: Callable[
135+
[ClientCallDetails, Any], UnaryStreamCall[object, object]
136+
],
137+
client_call_details: ClientCallDetails,
138+
request: Any,
139+
) -> AsyncIterable[object] | UnaryStreamCall[object, object]:
140+
"""Intercept the call to add HMAC authentication to the metadata fields.
141+
142+
This is a known method from the base class that is overridden.
143+
144+
Args:
145+
continuation: The next interceptor in the chain.
78146
client_call_details: The call details.
79-
ts: The timestamp to use for the HMAC.
80-
nonce: The nonce to use for the HMAC.
147+
request: The request object.
148+
149+
Returns:
150+
The response object (this implementation does not modify the response).
81151
"""
82-
if client_call_details.metadata is None:
83-
_logger.error(
84-
"No metadata found, cannot extract an api key. Therefore, cannot sign the request."
85-
)
86-
return
87-
88-
key: Any = client_call_details.metadata.get("key")
89-
if key is None:
90-
_logger.error("No key found in metadata, cannot sign the request.")
91-
return
92-
hmac_obj = hmac.new(self._secret, digestmod="sha256")
93-
hmac_obj.update(key.encode())
94-
hmac_obj.update(ts)
95-
hmac_obj.update(nonce)
96-
97-
hmac_obj.update(client_call_details.method.encode())
98-
99-
client_call_details.metadata["ts"] = ts
100-
client_call_details.metadata["nonce"] = nonce
101-
client_call_details.metadata["sig"] = urlsafe_b64encode(hmac_obj.digest())
152+
_add_hmac(
153+
self._secret,
154+
client_call_details,
155+
int(time.time()),
156+
secrets.token_bytes(16),
157+
)
158+
return await continuation(client_call_details, request) # type: ignore

tests/test_authentication.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,18 @@
55

66
from unittest import mock
77

8-
from frequenz.client.base.authentication import (
9-
AuthenticationInterceptor,
10-
AuthenticationOptions,
11-
)
8+
from frequenz.client.base.authentication import _add_auth_header
129

1310

1411
async def test_auth_interceptor() -> None:
1512
"""Test that the Auth Interceptor adds the correct header."""
16-
auth: AuthenticationOptions = AuthenticationOptions(api_key="my_key")
17-
auth_interceptor: AuthenticationInterceptor = AuthenticationInterceptor(
18-
options=auth
19-
)
20-
2113
metadata: dict[str, str] = {}
2214

2315
client_call_details = mock.MagicMock(method="my_rpc")
2416
client_call_details.metadata = metadata
2517

26-
auth_interceptor.add_auth_header(client_call_details)
18+
key = "my_key"
19+
20+
_add_auth_header(key, client_call_details)
2721

2822
assert metadata["key"] == "my_key"

0 commit comments

Comments
 (0)