Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Features

* Added support for HMAC signing of client messages
* Added support for HMAC signing of UnaryUnary client messages
* Added support for HMAC signing of UnaryStream client messages

## Upgrading

Expand Down
67 changes: 53 additions & 14 deletions src/frequenz/client/base/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,37 @@
"""An Interceptor that adds the API key to a gRPC call."""

import dataclasses
from typing import Callable
from typing import AsyncIterable, Callable

from grpc.aio import (
ClientCallDetails,
Metadata,
UnaryStreamCall,
UnaryStreamClientInterceptor,
UnaryUnaryCall,
UnaryUnaryClientInterceptor,
)


def _add_auth_header(
key: str,
client_call_details: ClientCallDetails,
) -> None:
"""Add the API key as a metadata field to the call.

The API key is used by the later sign interceptor to calculate the HMAC.
In addition it is used as a first layer of authentication by the server.

Args:
key: The API key to use for the service.
client_call_details: The call details.
"""
if client_call_details.metadata is None:
client_call_details.metadata = Metadata()

client_call_details.metadata["key"] = key


@dataclasses.dataclass(frozen=True)
class AuthenticationOptions:
"""Options for authenticating to the endpoint."""
Expand All @@ -22,8 +43,8 @@ class AuthenticationOptions:
"""The API key to authenticate with."""


# There is an issue in gRPC that causes the type to be unspecifieable correctly here.
class AuthenticationInterceptor(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
# There is an issue in gRPC which means the type can not be specified correctly here.
class AuthenticationInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""

def __init__(self, options: AuthenticationOptions):
Expand Down Expand Up @@ -54,24 +75,42 @@ async def intercept_unary_unary(
Returns:
The response object (this implementation does not modify the response).
"""
self.add_auth_header(
client_call_details,
)
_add_auth_header(self._key, client_call_details)
return await continuation(client_call_details, request)

def add_auth_header(

# There is an issue in gRPC which means the type can not be specified correctly here.
class AuthenticationInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg]
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""

def __init__(self, options: AuthenticationOptions):
"""Create an instance of the interceptor.

Args:
options: The options for authenticating to the endpoint.
"""
self._key = options.api_key

async def intercept_unary_stream(
self,
continuation: Callable[
[ClientCallDetails, object], UnaryStreamCall[object, object]
],
client_call_details: ClientCallDetails,
) -> None:
"""Add the API key as a metadata field to the call.
request: object,
) -> AsyncIterable[object] | UnaryStreamCall[object, object]:
"""Intercept the call to add HMAC authentication to the metadata fields.

The API key is used by the later sign interceptor to calculate the HMAC.
In addition it is used as a first layer of authentication by the server.
This is a known method from the base class that is overridden.

Args:
continuation: The next interceptor in the chain.
client_call_details: The call details.
request: The request object.

Returns:
The response object (this implementation does not modify the response).
"""
if client_call_details.metadata is None:
client_call_details.metadata = Metadata()
_add_auth_header(self._key, client_call_details)

client_call_details.metadata["key"] = self._key
return await continuation(client_call_details, request) # type: ignore
42 changes: 26 additions & 16 deletions src/frequenz/client/base/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@
secure_channel,
)

from .authentication import AuthenticationInterceptor, AuthenticationOptions
from .signing import SigningInterceptor, SigningOptions
from .authentication import (
AuthenticationInterceptorUnaryStream,
AuthenticationInterceptorUnaryUnary,
AuthenticationOptions,
)
from .signing import (
SigningInterceptorUnaryStream,
SigningInterceptorUnaryUnary,
SigningOptions,
)


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -193,14 +201,16 @@ def parse_grpc_uri(

interceptors: list[ClientInterceptor] = []
if defaults.auth is not None:
interceptors.append(
AuthenticationInterceptor(options=defaults.auth) # type: ignore[arg-type]
)
interceptors += [
AuthenticationInterceptorUnaryUnary(options=defaults.auth), # type: ignore [list-item]
AuthenticationInterceptorUnaryStream(options=defaults.auth), # type: ignore [list-item]
]

if defaults.sign is not None:
interceptors.append(
SigningInterceptor(options=defaults.sign) # type: ignore[arg-type]
)
interceptors += [
SigningInterceptorUnaryUnary(options=defaults.sign), # type: ignore [list-item]
SigningInterceptorUnaryStream(options=defaults.sign), # type: ignore [list-item]
]

ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
if ssl:
Expand Down Expand Up @@ -278,13 +288,13 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
}

if ssl is False:
erros = []
errors = []
for opt_name, opt in ssl_opts.items():
if opt is not None:
erros.append(opt_name)
if erros:
errors.append(opt_name)
if errors:
raise ValueError(
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but SSL is disabled",
f"Option(s) {', '.join(errors)} found in URI {uri!r}, but SSL is disabled",
)

keep_alive_option = options.pop("keep_alive", None)
Expand All @@ -298,13 +308,13 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
}

if keep_alive is False:
erros = []
errors = []
for opt_name, opt in keep_alive_opts.items():
if opt is not None:
erros.append(opt_name)
if erros:
errors.append(opt_name)
if errors:
raise ValueError(
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but keep_alive is disabled",
f"Option(s) {', '.join(errors)} found in URI {uri!r}, but keep_alive is disabled",
)

if options:
Expand Down
121 changes: 89 additions & 32 deletions src/frequenz/client/base/signing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,65 @@
import secrets
import time
from base64 import urlsafe_b64encode
from typing import Any, Callable
from typing import Any, AsyncIterable, Callable

from grpc.aio import (
ClientCallDetails,
UnaryStreamCall,
UnaryStreamClientInterceptor,
UnaryUnaryCall,
UnaryUnaryClientInterceptor,
)

_logger = logging.getLogger(__name__)


def _add_hmac(
secret: bytes, client_call_details: ClientCallDetails, ts: int, nonce: bytes
) -> None:
"""Add the HMAC authentication to the metadata fields of the call details.

The extra headers are directly added to the client_call details.

Args:
secret: The symmetric secret shared with the service.
client_call_details: The call details.
ts: The timestamp to use for the HMAC.
nonce: The nonce to use for the HMAC.
"""
if client_call_details.metadata is None:
_logger.error(
"No metadata found, cannot extract an api key. Therefore, cannot sign the request."
)
return

key: Any = client_call_details.metadata.get("key")
if key is None:
_logger.error("No key found in metadata, cannot sign the request.")
return

# Make into a base10 integer string and then encode to bytes
# We can not use a raw bytes timestamp as the underlying network library
# really hates zero bytes in the metadata values
ts_bytes = str(ts).encode()
nonce_bytes = urlsafe_b64encode(nonce)

hmac_obj = hmac.new(secret, digestmod="sha256")
hmac_obj.update(key.encode())
hmac_obj.update(ts_bytes)
hmac_obj.update(nonce_bytes)

# Once again, gRPC is mistyped
hmac_obj.update(client_call_details.method.split(b"/")[-1]) # type: ignore[arg-type]

client_call_details.metadata["ts"] = ts_bytes
client_call_details.metadata["nonce"] = nonce_bytes
# By definition the signature is base64 encoded _without_ the padding, so we strip that
client_call_details.metadata["sig"] = urlsafe_b64encode(hmac_obj.digest()).strip(
b"="
)


@dataclasses.dataclass(frozen=True)
class SigningOptions:
"""Options for message signing of messages."""
Expand All @@ -28,8 +76,8 @@ class SigningOptions:
"""The secret to sign the message with."""


# There is an issue in gRPC that causes the type to be unspecifieable correctly here.
class SigningInterceptor(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
# There is an issue in gRPC which means the type can not be specified correctly here.
class SigningInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""

def __init__(self, options: SigningOptions):
Expand Down Expand Up @@ -60,42 +108,51 @@ async def intercept_unary_unary(
Returns:
The response object (this implementation does not modify the response).
"""
self.add_hmac(
_add_hmac(
self._secret,
client_call_details,
int(time.time()).to_bytes(8, "big"),
int(time.time()),
secrets.token_bytes(16),
)
return await continuation(client_call_details, request)

def add_hmac(
self, client_call_details: ClientCallDetails, ts: bytes, nonce: bytes
) -> None:
"""Add the HMAC authentication to the metadata fields of the call details.

The extra headers are directly added to the client_call details.
# There is an issue in gRPC which means the type can not be specified correctly here.
class SigningInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg]
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""

def __init__(self, options: SigningOptions):
"""Create an instance of the interceptor.

Args:
options: The options for signing the message.
"""
self._secret = options.secret.encode()

async def intercept_unary_stream(
self,
continuation: Callable[
[ClientCallDetails, Any], UnaryStreamCall[object, object]
],
client_call_details: ClientCallDetails,
request: Any,
) -> AsyncIterable[object] | UnaryStreamCall[object, object]:
"""Intercept the call to add HMAC authentication to the metadata fields.

This is a known method from the base class that is overridden.

Args:
continuation: The next interceptor in the chain.
client_call_details: The call details.
ts: The timestamp to use for the HMAC.
nonce: The nonce to use for the HMAC.
request: The request object.

Returns:
The response object (this implementation does not modify the response).
"""
if client_call_details.metadata is None:
_logger.error(
"No metadata found, cannot extract an api key. Therefore, cannot sign the request."
)
return

key: Any = client_call_details.metadata.get("key")
if key is None:
_logger.error("No key found in metadata, cannot sign the request.")
return
hmac_obj = hmac.new(self._secret, digestmod="sha256")
hmac_obj.update(key.encode())
hmac_obj.update(ts)
hmac_obj.update(nonce)

hmac_obj.update(client_call_details.method.encode())

client_call_details.metadata["ts"] = ts
client_call_details.metadata["nonce"] = nonce
client_call_details.metadata["sig"] = urlsafe_b64encode(hmac_obj.digest())
_add_hmac(
self._secret,
client_call_details,
int(time.time()),
secrets.token_bytes(16),
)
return await continuation(client_call_details, request) # type: ignore
14 changes: 4 additions & 10 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,18 @@

from unittest import mock

from frequenz.client.base.authentication import (
AuthenticationInterceptor,
AuthenticationOptions,
)
from frequenz.client.base.authentication import _add_auth_header


async def test_auth_interceptor() -> None:
"""Test that the Auth Interceptor adds the correct header."""
auth: AuthenticationOptions = AuthenticationOptions(api_key="my_key")
auth_interceptor: AuthenticationInterceptor = AuthenticationInterceptor(
options=auth
)

metadata: dict[str, str] = {}

client_call_details = mock.MagicMock(method="my_rpc")
client_call_details.metadata = metadata

auth_interceptor.add_auth_header(client_call_details)
key = "my_key"

_add_auth_header(key, client_call_details)

assert metadata["key"] == "my_key"
Loading
Loading