diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 9a57d26..de8876d 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -6,7 +6,11 @@ ## Upgrading - +* Updated interface and behavior for HMAC + + This introduces a new positional argument to `parse_grpc_uri`. + If calling this function manually and passing `ChannelOptions`, it is recommended + to switch to passing `ChannelOptions` via keyword argument. ## New Features diff --git a/src/frequenz/client/base/authentication.py b/src/frequenz/client/base/authentication.py index 7ce42ea..3267c0f 100644 --- a/src/frequenz/client/base/authentication.py +++ b/src/frequenz/client/base/authentication.py @@ -3,7 +3,6 @@ """An Interceptor that adds the API key to a gRPC call.""" -import dataclasses from typing import AsyncIterable, Callable from grpc.aio import ( @@ -35,25 +34,17 @@ def _add_auth_header( client_call_details.metadata["key"] = key -@dataclasses.dataclass(frozen=True) -class AuthenticationOptions: - """Options for authenticating to the endpoint.""" - - api_key: str - """The API key to authenticate with.""" - - # There is an issue in gRPC which means the type can not be specified correctly here. class AuthenticationInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg] """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" - def __init__(self, options: AuthenticationOptions): + def __init__(self, api_key: str): """Create an instance of the interceptor. Args: - options: The options for authenticating to the endpoint. + api_key: The API key to send along for the request. """ - self._key = options.api_key + self._key = api_key async def intercept_unary_unary( self, @@ -83,13 +74,13 @@ async def intercept_unary_unary( class AuthenticationInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg] """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" - def __init__(self, options: AuthenticationOptions): + def __init__(self, api_key: str): """Create an instance of the interceptor. Args: - options: The options for authenticating to the endpoint. + api_key: The API key to send along for the request. """ - self._key = options.api_key + self._key = api_key async def intercept_unary_stream( self, diff --git a/src/frequenz/client/base/channel.py b/src/frequenz/client/base/channel.py index 9746120..d00324c 100644 --- a/src/frequenz/client/base/channel.py +++ b/src/frequenz/client/base/channel.py @@ -6,7 +6,7 @@ import dataclasses import pathlib from datetime import timedelta -from typing import assert_never +from typing import Sequence, assert_never from urllib.parse import parse_qs, urlparse from grpc import ssl_channel_credentials @@ -17,17 +17,6 @@ secure_channel, ) -from .authentication import ( - AuthenticationInterceptorUnaryStream, - AuthenticationInterceptorUnaryUnary, - AuthenticationOptions, -) -from .signing import ( - SigningInterceptorUnaryStream, - SigningInterceptorUnaryUnary, - SigningOptions, -) - @dataclasses.dataclass(frozen=True) class SslOptions: @@ -85,15 +74,10 @@ class ChannelOptions: keep_alive: KeepAliveOptions = KeepAliveOptions() """HTTP2 keep-alive options for the channel.""" - sign: SigningOptions | None = None - """Signing options for the channel.""" - - auth: AuthenticationOptions | None = None - """Authentication options for the channel.""" - def parse_grpc_uri( uri: str, + interceptors: Sequence[ClientInterceptor] = (), /, defaults: ChannelOptions = ChannelOptions(), ) -> Channel: @@ -131,6 +115,8 @@ def parse_grpc_uri( Args: uri: The gRPC URI specifying the connection parameters. + interceptors: A list of interceptors to apply to the channel. They are applied + in the same order as they are passed in (see grpc interceptor docs for details) defaults: The default options use to create the channel when not specified in the URI. @@ -199,19 +185,6 @@ def parse_grpc_uri( else None ) - interceptors: list[ClientInterceptor] = [] - if defaults.auth is not None: - interceptors += [ - AuthenticationInterceptorUnaryUnary(options=defaults.auth), # type: ignore [list-item] - AuthenticationInterceptorUnaryStream(options=defaults.auth), # type: ignore [list-item] - ] - - if defaults.sign is not None: - interceptors += [ - SigningInterceptorUnaryUnary(options=defaults.sign), # type: ignore [list-item] - SigningInterceptorUnaryStream(options=defaults.sign), # type: ignore [list-item] - ] - ssl = defaults.ssl.enabled if options.ssl is None else options.ssl if ssl: return secure_channel( diff --git a/src/frequenz/client/base/client.py b/src/frequenz/client/base/client.py index 3de4a64..5935c5a 100644 --- a/src/frequenz/client/base/client.py +++ b/src/frequenz/client/base/client.py @@ -6,15 +6,25 @@ import abc import inspect from collections.abc import Awaitable, Callable +from types import EllipsisType from typing import Any, Generic, Self, TypeVar, overload from grpc.aio import ( AioRpcError, Channel, + ClientInterceptor, ) +from .authentication import ( + AuthenticationInterceptorUnaryStream, + AuthenticationInterceptorUnaryUnary, +) from .channel import ChannelOptions, parse_grpc_uri from .exception import ApiClientError, ClientNotConnected +from .signing import ( + SigningInterceptorUnaryStream, + SigningInterceptorUnaryUnary, +) StubT = TypeVar("StubT") """The type of the gRPC stub.""" @@ -153,13 +163,15 @@ async def main(): instances. """ - def __init__( + def __init__( # pylint: disable=too-many-arguments self, server_url: str, create_stub: Callable[[Channel], StubT], *, connect: bool = True, channel_defaults: ChannelOptions = ChannelOptions(), + auth_key: str | None = None, + sign_secret: str | None = None, ) -> None: """Create an instance and connect to the server. @@ -172,14 +184,21 @@ def __init__( called. channel_defaults: The default options for the gRPC channel to create using the server URL. + auth_key: The API key to use when connecting to the service. + sign_secret: The secret to use when creating message HMAC. + """ self._server_url: str = server_url self._create_stub: Callable[[Channel], StubT] = create_stub self._channel_defaults: ChannelOptions = channel_defaults + self._auth_key = auth_key + self._sign_secret = sign_secret self._channel: Channel | None = None self._stub: StubT | None = None if connect: - self.connect(server_url) + self.connect( + server_url=self._server_url, auth_key=auth_key, sign_secret=sign_secret + ) @property def server_url(self) -> str: @@ -212,7 +231,13 @@ def is_connected(self) -> bool: """Whether the client is connected to the server.""" return self._channel is not None - def connect(self, server_url: str | None = None) -> None: + def connect( + self, + server_url: str | None = None, + *, + auth_key: str | None | EllipsisType = ..., + sign_secret: str | None | EllipsisType = ..., + ) -> None: """Connect to the server, possibly using a new URL. If the client is already connected and the URL is the same as the previous URL, @@ -222,12 +247,41 @@ def connect(self, server_url: str | None = None) -> None: Args: server_url: The URL of the server to connect to. If not provided, the previously used URL is used. + auth_key: The API key to use when connecting to the service. If an Ellipsis + is provided, the previously used auth_key is used. + sign_secret: The secret to use when creating message HMAC. If an Ellipsis is + provided, """ + reconnect = False if server_url is not None and server_url != self._server_url: # URL changed self._server_url = server_url - elif self.is_connected: + reconnect = True + if auth_key is not ... and auth_key != self._auth_key: + self._auth_key = auth_key + reconnect = True + if sign_secret is not ... and sign_secret != self._sign_secret: + self._sign_secret = sign_secret + reconnect = True + if self.is_connected and not reconnect: # Desired connection already exists return - self._channel = parse_grpc_uri(self._server_url, self._channel_defaults) + + interceptors: list[ClientInterceptor] = [] + if self._auth_key is not None: + interceptors += [ + AuthenticationInterceptorUnaryUnary(self._auth_key), # type: ignore [list-item] + AuthenticationInterceptorUnaryStream(self._auth_key), # type: ignore [list-item] + ] + if self._sign_secret is not None: + interceptors += [ + SigningInterceptorUnaryUnary(self._sign_secret), # type: ignore [list-item] + SigningInterceptorUnaryStream(self._sign_secret), # type: ignore [list-item] + ] + + self._channel = parse_grpc_uri( + self._server_url, + interceptors, + defaults=self._channel_defaults, + ) self._stub = self._create_stub(self._channel) async def disconnect(self) -> None: diff --git a/src/frequenz/client/base/signing.py b/src/frequenz/client/base/signing.py index b76c45d..29b4373 100644 --- a/src/frequenz/client/base/signing.py +++ b/src/frequenz/client/base/signing.py @@ -3,7 +3,6 @@ """An Interceptor that adds HMAC signature of the metadata fields to a gRPC call.""" -import dataclasses import hmac import logging import secrets @@ -68,25 +67,17 @@ def _add_hmac( ) -@dataclasses.dataclass(frozen=True) -class SigningOptions: - """Options for message signing of messages.""" - - secret: str - """The secret to sign the message with.""" - - # There is an issue in gRPC which means the type can not be specified correctly here. class SigningInterceptorUnaryUnary(UnaryUnaryClientInterceptor): # type: ignore[type-arg] """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" - def __init__(self, options: SigningOptions): + def __init__(self, secret: str): """Create an instance of the interceptor. Args: - options: The options for signing the message. + secret: The secret used for signing the message. """ - self._secret = options.secret.encode() + self._secret = secret.encode() async def intercept_unary_unary( self, @@ -121,13 +112,13 @@ async def intercept_unary_unary( class SigningInterceptorUnaryStream(UnaryStreamClientInterceptor): # type: ignore[type-arg] """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" - def __init__(self, options: SigningOptions): + def __init__(self, secret: str): """Create an instance of the interceptor. Args: - options: The options for signing the message. + secret: The secret used for signing the message. """ - self._secret = options.secret.encode() + self._secret = secret.encode() async def intercept_unary_stream( self, diff --git a/tests/test_channel.py b/tests/test_channel.py index 94ff868..019556d 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -10,7 +10,7 @@ import pytest from grpc import ssl_channel_credentials -from grpc.aio import Channel +from grpc.aio import Channel, UnaryStreamClientInterceptor, UnaryUnaryClientInterceptor from frequenz.client.base.channel import ( ChannelOptions, @@ -257,7 +257,7 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals return_value=b"contents", ) as get_contents_mock, ): - channel = parse_grpc_uri(uri, defaults) + channel = parse_grpc_uri(uri, defaults=defaults) assert channel == expected_channel expected_target = f"{expected_host}:{expected_port}" @@ -318,11 +318,11 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals expected_target, expected_credentials, expected_channel_options, - interceptors=[], + interceptors=(), ) else: insecure_channel_mock.assert_called_once_with( - expected_target, expected_channel_options, interceptors=[] + expected_target, expected_channel_options, interceptors=() ) @@ -387,3 +387,22 @@ def test_invalid_url_no_default_port() -> None: match=r"The gRPC URI 'grpc://localhost' doesn't specify a port and there is no default.", ): parse_grpc_uri(uri) + + +def test_forward_interceptors() -> None: + """Test that the interceptors are properly forwarded to channel construction.""" + expected_channel = mock.MagicMock(name="mock_channel", spec=Channel) + mock_interceptors = [ + mock.MagicMock(name="mock_interceptorUU", spec=UnaryUnaryClientInterceptor), + mock.MagicMock(name="mock_interceptorUS", spec=UnaryStreamClientInterceptor), + ] + uri = "grpc://localhost:2355?keep_alive=0" + with mock.patch( + "frequenz.client.base.channel.secure_channel", + return_value=expected_channel, + ) as secure_channel_mock: + _ = parse_grpc_uri(uri, mock_interceptors) + + secure_channel_mock.assert_called_once_with( + "localhost:2355", mock.ANY, None, interceptors=mock_interceptors + ) diff --git a/tests/test_client.py b/tests/test_client.py index e5dc699..a4aceb8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -52,12 +52,14 @@ class _ClientMocks: _DEFAULT_SERVER_URL = "grpc://localhost" -def create_client_with_mocks( +def create_client_with_mocks( # pylint: disable=too-many-arguments mocker: pytest_mock.MockFixture, *, auto_connect: bool = True, server_url: str = _DEFAULT_SERVER_URL, channel_defaults: ChannelOptions | None = None, + auth_key: str | None = None, + sign_secret: str | None = None, ) -> tuple[BaseApiClient[mock.MagicMock], _ClientMocks]: """Create a BaseApiClient instance with mocks.""" mock_stub = mock.MagicMock(name="stub") @@ -73,6 +75,8 @@ def create_client_with_mocks( server_url=server_url, create_stub=mock_create_stub, connect=auto_connect, + auth_key=auth_key, + sign_secret=sign_secret, **kwargs, ) return client, _ClientMocks( @@ -93,7 +97,9 @@ def test_base_api_client_init( assert client.server_url == _DEFAULT_SERVER_URL if auto_connect: mocks.parse_grpc_uri.assert_called_once_with( - client.server_url, ChannelOptions() + client.server_url, + [], + defaults=ChannelOptions(), ) assert client.channel is mocks.channel assert client._stub is mocks.stub # pylint: disable=protected-access @@ -112,7 +118,11 @@ def test_base_api_client_init_with_channel_defaults( channel_defaults = ChannelOptions(ssl=SslOptions(enabled=False)) client, mocks = create_client_with_mocks(mocker, channel_defaults=channel_defaults) assert client.server_url == _DEFAULT_SERVER_URL - mocks.parse_grpc_uri.assert_called_once_with(client.server_url, channel_defaults) + mocks.parse_grpc_uri.assert_called_once_with( + client.server_url, + [], + defaults=channel_defaults, + ) assert client.channel is mocks.channel assert client._stub is mocks.stub # pylint: disable=protected-access assert client.is_connected @@ -155,7 +165,9 @@ def test_base_api_client_connect( mocks.create_stub.assert_not_called() else: mocks.parse_grpc_uri.assert_called_once_with( - client.server_url, ChannelOptions() + client.server_url, + [], + defaults=ChannelOptions(), ) mocks.create_stub.assert_called_once_with(mocks.channel) @@ -195,7 +207,9 @@ async def test_base_api_client_async_context_manager( mocks.create_stub.assert_not_called() else: mocks.parse_grpc_uri.assert_called_once_with( - client.server_url, ChannelOptions() + client.server_url, + [], + defaults=ChannelOptions(), ) mocks.create_stub.assert_called_once_with(mocks.channel) @@ -330,3 +344,23 @@ async def test_call_stub_method_success( assert response == (2 if transform else 1) if mock_transform: mock_transform.assert_called_once_with(1) + + +async def test_create_interceptors(mocker: pytest_mock.MockFixture) -> None: + """Test that the client constructor creates the interceptors as expected.""" + url = "grpc://localhost:2355?keep_alive=0" + _, mocks = create_client_with_mocks( + mocker, + auto_connect=True, + server_url=url, + auth_key="hunter2", + sign_secret="password1245", + ) + + mocks.parse_grpc_uri.assert_called_once_with( + url, [mock.ANY, mock.ANY, mock.ANY, mock.ANY], defaults=ChannelOptions() + ) + args, _ = mocks.parse_grpc_uri.call_args + interceptors = args[1] + for interceptor in interceptors: + assert isinstance(interceptor, grpc.aio.ClientInterceptor)