Skip to content

Commit 0b38565

Browse files
Add interceptors for stream requests
This adds authentication and signing interceptors for stream requests. Specifically for those where the client sends a single message and receives a stream. Signed-off-by: Florian Wagner <[email protected]>
1 parent 97946f0 commit 0b38565

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

src/frequenz/client/base/authentication.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
"""An Interceptor that adds the API key to a gRPC call."""
55

66
import dataclasses
7-
from typing import Callable
7+
from typing import AsyncIterable, Callable
88

99
from grpc.aio import (
1010
ClientCallDetails,
1111
Metadata,
1212
UnaryUnaryCall,
1313
UnaryUnaryClientInterceptor,
14+
UnaryStreamCall,
15+
UnaryStreamClientInterceptor,
1416
)
1517

1618

@@ -22,8 +24,8 @@ class AuthenticationOptions:
2224
"""The API key to authenticate with."""
2325

2426

25-
# There is an issue in gRPC that causes the type to be unspecifieable correctly here.
26-
class AuthenticationInterceptor(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
27+
# There is an issue in gRPC which means the type can not be specified correctly here.
28+
class AuthenticationInterceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor): # type: ignore[type-arg]
2729
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
2830

2931
def __init__(self, options: AuthenticationOptions):
@@ -54,11 +56,20 @@ async def intercept_unary_unary(
5456
Returns:
5557
The response object (this implementation does not modify the response).
5658
"""
57-
self.add_auth_header(
58-
client_call_details,
59-
)
59+
self.add_auth_header(client_call_details)
6060
return await continuation(client_call_details, request)
6161

62+
async def intercept_unary_stream(
63+
self,
64+
continuation: Callable[[ClientCallDetails, object], UnaryStreamCall],
65+
client_call_details: ClientCallDetails,
66+
request: object,
67+
) -> AsyncIterable | UnaryStreamCall:
68+
69+
self.add_auth_header(client_call_details)
70+
71+
return continuation(client_call_details, request)
72+
6273
def add_auth_header(
6374
self,
6475
client_call_details: ClientCallDetails,

src/frequenz/client/base/channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
secure_channel,
1818
)
1919

20-
from .authentication import AuthenticationInterceptor, AuthenticationOptions
20+
from .authentication import AuthenticationInterceptorUnaryUnary, AuthenticationOptions
2121
from .signing import SigningInterceptor, SigningOptions
2222

2323

@@ -194,7 +194,7 @@ def parse_grpc_uri(
194194
interceptors: list[ClientInterceptor] = []
195195
if defaults.auth is not None:
196196
interceptors.append(
197-
AuthenticationInterceptor(options=defaults.auth) # type: ignore[arg-type]
197+
AuthenticationInterceptorUnaryUnary(options=defaults.auth) # type: ignore[arg-type]
198198
)
199199

200200
if defaults.sign is not None:

src/frequenz/client/base/signing.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
import secrets
1010
import time
1111
from base64 import urlsafe_b64encode
12-
from typing import Any, Callable
12+
from typing import Any, AsyncIterable, Callable
1313

1414
from grpc.aio import (
1515
ClientCallDetails,
1616
UnaryUnaryCall,
1717
UnaryUnaryClientInterceptor,
18+
UnaryStreamCall,
19+
UnaryStreamClientInterceptor,
1820
)
1921

2022
_logger = logging.getLogger(__name__)
@@ -28,8 +30,8 @@ class SigningOptions:
2830
"""The secret to sign the message with."""
2931

3032

31-
# There is an issue in gRPC that causes the type to be unspecifieable correctly here.
32-
class SigningInterceptor(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
33+
# There is an issue in gRPC which means the type can not be specified correctly here.
34+
class SigningInterceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor): # type: ignore[type-arg]
3335
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
3436

3537
def __init__(self, options: SigningOptions):
@@ -67,6 +69,21 @@ async def intercept_unary_unary(
6769
)
6870
return await continuation(client_call_details, request)
6971

72+
73+
async def intercept_unary_stream(
74+
self,
75+
continuation: Callable[[ClientCallDetails, Any], UnaryStreamCall],
76+
client_call_details: ClientCallDetails,
77+
request: Any,
78+
) -> AsyncIterable | UnaryStreamCall:
79+
self.add_hmac(
80+
client_call_details,
81+
int(time.time()).to_bytes(8, "big"),
82+
secrets.token_bytes(16),
83+
)
84+
return continuation(client_call_details, request)
85+
86+
7087
def add_hmac(
7188
self, client_call_details: ClientCallDetails, ts: bytes, nonce: bytes
7289
) -> None:

0 commit comments

Comments
 (0)