From e3f8e5fc1f2e047d30aec002ba2ba99f9251f83b Mon Sep 17 00:00:00 2001 From: Florian Wagner Date: Fri, 23 May 2025 11:21:06 +0200 Subject: [PATCH 1/2] Fix Typo Signed-off-by: Florian Wagner --- src/frequenz/client/base/channel.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/frequenz/client/base/channel.py b/src/frequenz/client/base/channel.py index a145af1..b6b24fe 100644 --- a/src/frequenz/client/base/channel.py +++ b/src/frequenz/client/base/channel.py @@ -278,13 +278,13 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams: } if ssl is False: - erros = [] + errors = [] for opt_name, opt in ssl_opts.items(): if opt is not None: - erros.append(opt_name) - if erros: + errors.append(opt_name) + if errors: raise ValueError( - f"Option(s) {', '.join(erros)} found in URI {uri!r}, but SSL is disabled", + f"Option(s) {', '.join(errors)} found in URI {uri!r}, but SSL is disabled", ) keep_alive_option = options.pop("keep_alive", None) @@ -298,13 +298,13 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams: } if keep_alive is False: - erros = [] + errors = [] for opt_name, opt in keep_alive_opts.items(): if opt is not None: - erros.append(opt_name) - if erros: + errors.append(opt_name) + if errors: raise ValueError( - f"Option(s) {', '.join(erros)} found in URI {uri!r}, but keep_alive is disabled", + f"Option(s) {', '.join(errors)} found in URI {uri!r}, but keep_alive is disabled", ) if options: From 591354f6077e12520fb85f8dc41a5134d3df1912 Mon Sep 17 00:00:00 2001 From: Florian Wagner Date: Fri, 23 May 2025 11:21:49 +0200 Subject: [PATCH 2/2] Add interceptors for stream requests This adds authentication and signing interceptors for stream requests. Specifically for those where the client sends a single message and receives a stream. Signed-off-by: Florian Wagner --- RELEASE_NOTES.md | 3 +- src/frequenz/client/base/authentication.py | 67 +++++++++--- src/frequenz/client/base/channel.py | 26 +++-- src/frequenz/client/base/signing.py | 121 +++++++++++++++------ tests/test_authentication.py | 14 +-- tests/test_signing.py | 15 +-- 6 files changed, 172 insertions(+), 74 deletions(-) diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 57d8b38..40aaf6b 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -2,7 +2,8 @@ ## Features -* Added support for HMAC signing of client messages +* Added support for HMAC signing of UnaryUnary client messages +* Added support for HMAC signing of UnaryStream client messages ## Upgrading diff --git a/src/frequenz/client/base/authentication.py b/src/frequenz/client/base/authentication.py index 614a0eb..7ce42ea 100644 --- a/src/frequenz/client/base/authentication.py +++ b/src/frequenz/client/base/authentication.py @@ -4,16 +4,37 @@ """An Interceptor that adds the API key to a gRPC call.""" import dataclasses -from typing import Callable +from typing import AsyncIterable, Callable from grpc.aio import ( ClientCallDetails, Metadata, + UnaryStreamCall, + UnaryStreamClientInterceptor, UnaryUnaryCall, UnaryUnaryClientInterceptor, ) +def _add_auth_header( + key: str, + client_call_details: ClientCallDetails, +) -> None: + """Add the API key as a metadata field to the call. + + The API key is used by the later sign interceptor to calculate the HMAC. + In addition it is used as a first layer of authentication by the server. + + Args: + key: The API key to use for the service. + client_call_details: The call details. + """ + if client_call_details.metadata is None: + client_call_details.metadata = Metadata() + + client_call_details.metadata["key"] = key + + @dataclasses.dataclass(frozen=True) class AuthenticationOptions: """Options for authenticating to the endpoint.""" @@ -22,8 +43,8 @@ class AuthenticationOptions: """The API key to authenticate with.""" -# There is an issue in gRPC that causes the type to be unspecifieable correctly here. -class AuthenticationInterceptor(UnaryUnaryClientInterceptor): # type: ignore[type-arg] +# There is an issue in gRPC which means the type can not be specified correctly here. +class AuthenticationInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg] """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" def __init__(self, options: AuthenticationOptions): @@ -54,24 +75,42 @@ async def intercept_unary_unary( Returns: The response object (this implementation does not modify the response). """ - self.add_auth_header( - client_call_details, - ) + _add_auth_header(self._key, client_call_details) return await continuation(client_call_details, request) - def add_auth_header( + +# There is an issue in gRPC which means the type can not be specified correctly here. +class AuthenticationInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg] + """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" + + def __init__(self, options: AuthenticationOptions): + """Create an instance of the interceptor. + + Args: + options: The options for authenticating to the endpoint. + """ + self._key = options.api_key + + async def intercept_unary_stream( self, + continuation: Callable[ + [ClientCallDetails, object], UnaryStreamCall[object, object] + ], client_call_details: ClientCallDetails, - ) -> None: - """Add the API key as a metadata field to the call. + request: object, + ) -> AsyncIterable[object] | UnaryStreamCall[object, object]: + """Intercept the call to add HMAC authentication to the metadata fields. - The API key is used by the later sign interceptor to calculate the HMAC. - In addition it is used as a first layer of authentication by the server. + This is a known method from the base class that is overridden. Args: + continuation: The next interceptor in the chain. client_call_details: The call details. + request: The request object. + + Returns: + The response object (this implementation does not modify the response). """ - if client_call_details.metadata is None: - client_call_details.metadata = Metadata() + _add_auth_header(self._key, client_call_details) - client_call_details.metadata["key"] = self._key + return await continuation(client_call_details, request) # type: ignore diff --git a/src/frequenz/client/base/channel.py b/src/frequenz/client/base/channel.py index b6b24fe..9746120 100644 --- a/src/frequenz/client/base/channel.py +++ b/src/frequenz/client/base/channel.py @@ -17,8 +17,16 @@ secure_channel, ) -from .authentication import AuthenticationInterceptor, AuthenticationOptions -from .signing import SigningInterceptor, SigningOptions +from .authentication import ( + AuthenticationInterceptorUnaryStream, + AuthenticationInterceptorUnaryUnary, + AuthenticationOptions, +) +from .signing import ( + SigningInterceptorUnaryStream, + SigningInterceptorUnaryUnary, + SigningOptions, +) @dataclasses.dataclass(frozen=True) @@ -193,14 +201,16 @@ def parse_grpc_uri( interceptors: list[ClientInterceptor] = [] if defaults.auth is not None: - interceptors.append( - AuthenticationInterceptor(options=defaults.auth) # type: ignore[arg-type] - ) + interceptors += [ + AuthenticationInterceptorUnaryUnary(options=defaults.auth), # type: ignore [list-item] + AuthenticationInterceptorUnaryStream(options=defaults.auth), # type: ignore [list-item] + ] if defaults.sign is not None: - interceptors.append( - SigningInterceptor(options=defaults.sign) # type: ignore[arg-type] - ) + interceptors += [ + SigningInterceptorUnaryUnary(options=defaults.sign), # type: ignore [list-item] + SigningInterceptorUnaryStream(options=defaults.sign), # type: ignore [list-item] + ] ssl = defaults.ssl.enabled if options.ssl is None else options.ssl if ssl: diff --git a/src/frequenz/client/base/signing.py b/src/frequenz/client/base/signing.py index 3a1ea1d..b76c45d 100644 --- a/src/frequenz/client/base/signing.py +++ b/src/frequenz/client/base/signing.py @@ -9,10 +9,12 @@ import secrets import time from base64 import urlsafe_b64encode -from typing import Any, Callable +from typing import Any, AsyncIterable, Callable from grpc.aio import ( ClientCallDetails, + UnaryStreamCall, + UnaryStreamClientInterceptor, UnaryUnaryCall, UnaryUnaryClientInterceptor, ) @@ -20,6 +22,52 @@ _logger = logging.getLogger(__name__) +def _add_hmac( + secret: bytes, client_call_details: ClientCallDetails, ts: int, nonce: bytes +) -> None: + """Add the HMAC authentication to the metadata fields of the call details. + + The extra headers are directly added to the client_call details. + + Args: + secret: The symmetric secret shared with the service. + client_call_details: The call details. + ts: The timestamp to use for the HMAC. + nonce: The nonce to use for the HMAC. + """ + if client_call_details.metadata is None: + _logger.error( + "No metadata found, cannot extract an api key. Therefore, cannot sign the request." + ) + return + + key: Any = client_call_details.metadata.get("key") + if key is None: + _logger.error("No key found in metadata, cannot sign the request.") + return + + # Make into a base10 integer string and then encode to bytes + # We can not use a raw bytes timestamp as the underlying network library + # really hates zero bytes in the metadata values + ts_bytes = str(ts).encode() + nonce_bytes = urlsafe_b64encode(nonce) + + hmac_obj = hmac.new(secret, digestmod="sha256") + hmac_obj.update(key.encode()) + hmac_obj.update(ts_bytes) + hmac_obj.update(nonce_bytes) + + # Once again, gRPC is mistyped + hmac_obj.update(client_call_details.method.split(b"/")[-1]) # type: ignore[arg-type] + + client_call_details.metadata["ts"] = ts_bytes + client_call_details.metadata["nonce"] = nonce_bytes + # By definition the signature is base64 encoded _without_ the padding, so we strip that + client_call_details.metadata["sig"] = urlsafe_b64encode(hmac_obj.digest()).strip( + b"=" + ) + + @dataclasses.dataclass(frozen=True) class SigningOptions: """Options for message signing of messages.""" @@ -28,8 +76,8 @@ class SigningOptions: """The secret to sign the message with.""" -# There is an issue in gRPC that causes the type to be unspecifieable correctly here. -class SigningInterceptor(UnaryUnaryClientInterceptor): # type: ignore[type-arg] +# There is an issue in gRPC which means the type can not be specified correctly here. +class SigningInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg] """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" def __init__(self, options: SigningOptions): @@ -60,42 +108,51 @@ async def intercept_unary_unary( Returns: The response object (this implementation does not modify the response). """ - self.add_hmac( + _add_hmac( + self._secret, client_call_details, - int(time.time()).to_bytes(8, "big"), + int(time.time()), secrets.token_bytes(16), ) return await continuation(client_call_details, request) - def add_hmac( - self, client_call_details: ClientCallDetails, ts: bytes, nonce: bytes - ) -> None: - """Add the HMAC authentication to the metadata fields of the call details. - The extra headers are directly added to the client_call details. +# There is an issue in gRPC which means the type can not be specified correctly here. +class SigningInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg] + """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" + + def __init__(self, options: SigningOptions): + """Create an instance of the interceptor. Args: + options: The options for signing the message. + """ + self._secret = options.secret.encode() + + async def intercept_unary_stream( + self, + continuation: Callable[ + [ClientCallDetails, Any], UnaryStreamCall[object, object] + ], + client_call_details: ClientCallDetails, + request: Any, + ) -> AsyncIterable[object] | UnaryStreamCall[object, object]: + """Intercept the call to add HMAC authentication to the metadata fields. + + This is a known method from the base class that is overridden. + + Args: + continuation: The next interceptor in the chain. client_call_details: The call details. - ts: The timestamp to use for the HMAC. - nonce: The nonce to use for the HMAC. + request: The request object. + + Returns: + The response object (this implementation does not modify the response). """ - if client_call_details.metadata is None: - _logger.error( - "No metadata found, cannot extract an api key. Therefore, cannot sign the request." - ) - return - - key: Any = client_call_details.metadata.get("key") - if key is None: - _logger.error("No key found in metadata, cannot sign the request.") - return - hmac_obj = hmac.new(self._secret, digestmod="sha256") - hmac_obj.update(key.encode()) - hmac_obj.update(ts) - hmac_obj.update(nonce) - - hmac_obj.update(client_call_details.method.encode()) - - client_call_details.metadata["ts"] = ts - client_call_details.metadata["nonce"] = nonce - client_call_details.metadata["sig"] = urlsafe_b64encode(hmac_obj.digest()) + _add_hmac( + self._secret, + client_call_details, + int(time.time()), + secrets.token_bytes(16), + ) + return await continuation(client_call_details, request) # type: ignore diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 92e8cfa..76a28b1 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -5,24 +5,18 @@ from unittest import mock -from frequenz.client.base.authentication import ( - AuthenticationInterceptor, - AuthenticationOptions, -) +from frequenz.client.base.authentication import _add_auth_header async def test_auth_interceptor() -> None: """Test that the Auth Interceptor adds the correct header.""" - auth: AuthenticationOptions = AuthenticationOptions(api_key="my_key") - auth_interceptor: AuthenticationInterceptor = AuthenticationInterceptor( - options=auth - ) - metadata: dict[str, str] = {} client_call_details = mock.MagicMock(method="my_rpc") client_call_details.metadata = metadata - auth_interceptor.add_auth_header(client_call_details) + key = "my_key" + + _add_auth_header(key, client_call_details) assert metadata["key"] == "my_key" diff --git a/tests/test_signing.py b/tests/test_signing.py index 3c6b2e6..990a0ca 100644 --- a/tests/test_signing.py +++ b/tests/test_signing.py @@ -5,24 +5,21 @@ from unittest import mock -from frequenz.client.base.signing import ( - SigningInterceptor, - SigningOptions, -) +from frequenz.client.base.signing import _add_hmac async def test_sign_interceptor() -> None: """Test that the HMAC is calculated correctly so that it will match the value of the server.""" - sign: SigningOptions = SigningOptions(secret="my_secret") - sign_interceptor: SigningInterceptor = SigningInterceptor(options=sign) - metadata: dict[str, str | bytes] = {"key": "my_key"} client_call_details = mock.MagicMock(method="my_rpc") client_call_details.metadata = metadata + client_call_details.method = ( + b"/frequenz.api.wishlist.v1.Wishlist/ElectrifyTheFutureRequest" + ) - sign_interceptor.add_hmac(client_call_details, b"1634567890", b"123456789") + _add_hmac(b"hunter2", client_call_details, 1634567890, b"123456789") - assert metadata["sig"] == "NJDvrkRZhOPekn5AvPiaJsYTJYCgnLzA-LQFC2D7GNE=".encode( + assert metadata["sig"] == "yNCJYXjac-waeqLhlYJE2cql9rUGIq-7Flz4MAOZefQ".encode( "utf-8" )