Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

## Upgrading

<!-- Here goes notes on how to upgrade from previous versions, including deprecations and what they should be replaced with -->
* Updated interface and behavior for HMAC

This introduces a new positional argument to `parse_grpc_uri`.
If calling this function manually and passing `ChannelOptions`, it is recommended
to switch to passing `ChannelOptions` via keyword argument.

## New Features

Expand Down
21 changes: 6 additions & 15 deletions src/frequenz/client/base/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

"""An Interceptor that adds the API key to a gRPC call."""

import dataclasses
from typing import AsyncIterable, Callable

from grpc.aio import (
Expand Down Expand Up @@ -35,25 +34,17 @@ def _add_auth_header(
client_call_details.metadata["key"] = key


@dataclasses.dataclass(frozen=True)
class AuthenticationOptions:
"""Options for authenticating to the endpoint."""

api_key: str
"""The API key to authenticate with."""


# There is an issue in gRPC which means the type can not be specified correctly here.
class AuthenticationInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""

def __init__(self, options: AuthenticationOptions):
def __init__(self, api_key: str):
Copy link

Copilot AI May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the class documentation for AuthenticationInterceptorUnaryUnary to reflect that the API key is now passed directly as a parameter instead of via AuthenticationOptions.

Copilot uses AI. Check for mistakes.
"""Create an instance of the interceptor.

Args:
options: The options for authenticating to the endpoint.
api_key: The API key to send along for the request.
"""
self._key = options.api_key
self._key = api_key

async def intercept_unary_unary(
self,
Expand Down Expand Up @@ -83,13 +74,13 @@ async def intercept_unary_unary(
class AuthenticationInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg]
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""

def __init__(self, options: AuthenticationOptions):
def __init__(self, api_key: str):
"""Create an instance of the interceptor.

Args:
options: The options for authenticating to the endpoint.
api_key: The API key to send along for the request.
"""
self._key = options.api_key
self._key = api_key

async def intercept_unary_stream(
self,
Expand Down
35 changes: 4 additions & 31 deletions src/frequenz/client/base/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import dataclasses
import pathlib
from datetime import timedelta
from typing import assert_never
from typing import Sequence, assert_never
from urllib.parse import parse_qs, urlparse

from grpc import ssl_channel_credentials
Expand All @@ -17,17 +17,6 @@
secure_channel,
)

from .authentication import (
AuthenticationInterceptorUnaryStream,
AuthenticationInterceptorUnaryUnary,
AuthenticationOptions,
)
from .signing import (
SigningInterceptorUnaryStream,
SigningInterceptorUnaryUnary,
SigningOptions,
)


@dataclasses.dataclass(frozen=True)
class SslOptions:
Expand Down Expand Up @@ -85,15 +74,10 @@ class ChannelOptions:
keep_alive: KeepAliveOptions = KeepAliveOptions()
"""HTTP2 keep-alive options for the channel."""

sign: SigningOptions | None = None
"""Signing options for the channel."""

auth: AuthenticationOptions | None = None
"""Authentication options for the channel."""


def parse_grpc_uri(
uri: str,
interceptors: Sequence[ClientInterceptor] = (),
/,
defaults: ChannelOptions = ChannelOptions(),
) -> Channel:
Expand Down Expand Up @@ -131,6 +115,8 @@ def parse_grpc_uri(

Args:
uri: The gRPC URI specifying the connection parameters.
interceptors: A list of interceptors to apply to the channel. They are applied
in the same order as they are passed in (see grpc interceptor docs for details)
defaults: The default options use to create the channel when not specified in
the URI.

Expand Down Expand Up @@ -199,19 +185,6 @@ def parse_grpc_uri(
else None
)

interceptors: list[ClientInterceptor] = []
if defaults.auth is not None:
interceptors += [
AuthenticationInterceptorUnaryUnary(options=defaults.auth), # type: ignore [list-item]
AuthenticationInterceptorUnaryStream(options=defaults.auth), # type: ignore [list-item]
]

if defaults.sign is not None:
interceptors += [
SigningInterceptorUnaryUnary(options=defaults.sign), # type: ignore [list-item]
SigningInterceptorUnaryStream(options=defaults.sign), # type: ignore [list-item]
]

ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
if ssl:
return secure_channel(
Expand Down
64 changes: 59 additions & 5 deletions src/frequenz/client/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,25 @@
import abc
import inspect
from collections.abc import Awaitable, Callable
from types import EllipsisType
from typing import Any, Generic, Self, TypeVar, overload

from grpc.aio import (
AioRpcError,
Channel,
ClientInterceptor,
)

from .authentication import (
AuthenticationInterceptorUnaryStream,
AuthenticationInterceptorUnaryUnary,
)
from .channel import ChannelOptions, parse_grpc_uri
from .exception import ApiClientError, ClientNotConnected
from .signing import (
SigningInterceptorUnaryStream,
SigningInterceptorUnaryUnary,
)

StubT = TypeVar("StubT")
"""The type of the gRPC stub."""
Expand Down Expand Up @@ -153,13 +163,15 @@ async def main():
instances.
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
server_url: str,
create_stub: Callable[[Channel], StubT],
*,
connect: bool = True,
channel_defaults: ChannelOptions = ChannelOptions(),
auth_key: str | None = None,
sign_secret: str | None = None,
) -> None:
"""Create an instance and connect to the server.

Expand All @@ -172,14 +184,21 @@ def __init__(
called.
channel_defaults: The default options for the gRPC channel to create using
the server URL.
auth_key: The API key to use when connecting to the service.
sign_secret: The secret to use when creating message HMAC.

"""
self._server_url: str = server_url
self._create_stub: Callable[[Channel], StubT] = create_stub
self._channel_defaults: ChannelOptions = channel_defaults
self._auth_key = auth_key
self._sign_secret = sign_secret
self._channel: Channel | None = None
self._stub: StubT | None = None
if connect:
self.connect(server_url)
self.connect(
server_url=self._server_url, auth_key=auth_key, sign_secret=sign_secret
)

@property
def server_url(self) -> str:
Expand Down Expand Up @@ -212,7 +231,13 @@ def is_connected(self) -> bool:
"""Whether the client is connected to the server."""
return self._channel is not None

def connect(self, server_url: str | None = None) -> None:
def connect(
self,
server_url: str | None = None,
*,
auth_key: str | None | EllipsisType = ...,
sign_secret: str | None | EllipsisType = ...,
) -> None:
"""Connect to the server, possibly using a new URL.

If the client is already connected and the URL is the same as the previous URL,
Expand All @@ -222,12 +247,41 @@ def connect(self, server_url: str | None = None) -> None:
Args:
server_url: The URL of the server to connect to. If not provided, the
previously used URL is used.
auth_key: The API key to use when connecting to the service. If an Ellipsis
is provided, the previously used auth_key is used.
sign_secret: The secret to use when creating message HMAC. If an Ellipsis is
provided,
"""
reconnect = False
if server_url is not None and server_url != self._server_url: # URL changed
self._server_url = server_url
elif self.is_connected:
reconnect = True
if auth_key is not ... and auth_key != self._auth_key:
self._auth_key = auth_key
reconnect = True
if sign_secret is not ... and sign_secret != self._sign_secret:
self._sign_secret = sign_secret
reconnect = True
if self.is_connected and not reconnect: # Desired connection already exists
return
self._channel = parse_grpc_uri(self._server_url, self._channel_defaults)

interceptors: list[ClientInterceptor] = []
if self._auth_key is not None:
interceptors += [
AuthenticationInterceptorUnaryUnary(self._auth_key), # type: ignore [list-item]
AuthenticationInterceptorUnaryStream(self._auth_key), # type: ignore [list-item]
]
if self._sign_secret is not None:
interceptors += [
SigningInterceptorUnaryUnary(self._sign_secret), # type: ignore [list-item]
SigningInterceptorUnaryStream(self._sign_secret), # type: ignore [list-item]
]

self._channel = parse_grpc_uri(
self._server_url,
interceptors,
defaults=self._channel_defaults,
)
self._stub = self._create_stub(self._channel)

async def disconnect(self) -> None:
Expand Down
21 changes: 6 additions & 15 deletions src/frequenz/client/base/signing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

"""An Interceptor that adds HMAC signature of the metadata fields to a gRPC call."""

import dataclasses
import hmac
import logging
import secrets
Expand Down Expand Up @@ -68,25 +67,17 @@ def _add_hmac(
)


@dataclasses.dataclass(frozen=True)
class SigningOptions:
"""Options for message signing of messages."""

secret: str
"""The secret to sign the message with."""


# There is an issue in gRPC which means the type can not be specified correctly here.
class SigningInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg]
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""

def __init__(self, options: SigningOptions):
def __init__(self, secret: str):
Copy link

Copilot AI May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the class documentation for SigningInterceptorUnaryUnary to reflect that the secret is now passed directly as a string parameter instead of via SigningOptions.

Copilot uses AI. Check for mistakes.
"""Create an instance of the interceptor.

Args:
options: The options for signing the message.
secret: The secret used for signing the message.
"""
self._secret = options.secret.encode()
self._secret = secret.encode()

async def intercept_unary_unary(
self,
Expand Down Expand Up @@ -121,13 +112,13 @@ async def intercept_unary_unary(
class SigningInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg]
"""An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""

def __init__(self, options: SigningOptions):
def __init__(self, secret: str):
"""Create an instance of the interceptor.

Args:
options: The options for signing the message.
secret: The secret used for signing the message.
"""
self._secret = options.secret.encode()
self._secret = secret.encode()

async def intercept_unary_stream(
self,
Expand Down
27 changes: 23 additions & 4 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import pytest
from grpc import ssl_channel_credentials
from grpc.aio import Channel
from grpc.aio import Channel, UnaryStreamClientInterceptor, UnaryUnaryClientInterceptor

from frequenz.client.base.channel import (
ChannelOptions,
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
return_value=b"contents",
) as get_contents_mock,
):
channel = parse_grpc_uri(uri, defaults)
channel = parse_grpc_uri(uri, defaults=defaults)

assert channel == expected_channel
expected_target = f"{expected_host}:{expected_port}"
Expand Down Expand Up @@ -318,11 +318,11 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
expected_target,
expected_credentials,
expected_channel_options,
interceptors=[],
interceptors=(),
)
else:
insecure_channel_mock.assert_called_once_with(
expected_target, expected_channel_options, interceptors=[]
expected_target, expected_channel_options, interceptors=()
)


Expand Down Expand Up @@ -387,3 +387,22 @@ def test_invalid_url_no_default_port() -> None:
match=r"The gRPC URI 'grpc://localhost' doesn't specify a port and there is no default.",
):
parse_grpc_uri(uri)


def test_forward_interceptors() -> None:
"""Test that the interceptors are properly forwarded to channel construction."""
expected_channel = mock.MagicMock(name="mock_channel", spec=Channel)
mock_interceptors = [
mock.MagicMock(name="mock_interceptorUU", spec=UnaryUnaryClientInterceptor),
mock.MagicMock(name="mock_interceptorUS", spec=UnaryStreamClientInterceptor),
]
uri = "grpc://localhost:2355?keep_alive=0"
with mock.patch(
"frequenz.client.base.channel.secure_channel",
return_value=expected_channel,
) as secure_channel_mock:
_ = parse_grpc_uri(uri, mock_interceptors)

secure_channel_mock.assert_called_once_with(
"localhost:2355", mock.ANY, None, interceptors=mock_interceptors
)
Loading
Loading