Skip to content

Commit 76f2bdd

Browse files
authored
Allow passing SSL options via server URL (#73)
2 parents f83c57a + 1b6f65d commit 76f2bdd

File tree

6 files changed

+432
-79
lines changed

6 files changed

+432
-79
lines changed

RELEASE_NOTES.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212

1313
- The `parse_grpc_uri` function (and `BaseApiClient` constructor) now enables SSL by default (`ssl=false` should be passed to disable it).
1414

15-
- The `parse_grpc_uri` function now accepts an optional `default_ssl` parameter to set the default value for the `ssl` parameter when not present in the URI.
15+
- The `parse_grpc_uri` and `BaseApiClient` function now accepts a set of defaults to use when the URI does not specify a value for a given option.
1616

1717
## New Features
1818

19-
<!-- Here goes the main new features and examples or instructions on how to use them -->
19+
- The connection URI can now have a few new SSL options:
20+
21+
* `ssl_root_certificates_path` to specify the path to the root certificates file.
22+
* `ssl_private_key_path` to specify the path to the private key file.
23+
* `ssl_certificate_chain_path` to specify the path to the certificate chain file.
2024

2125
## Bug Fixes
2226

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ dev-mkdocs = [
5656
"mkdocs-macros-plugin == 1.0.5",
5757
"mkdocs-material == 9.5.31",
5858
"mkdocstrings[python] == 0.25.2",
59-
"mkdocstrings-python == 1.9.2",
59+
"mkdocstrings-python == 1.10.8",
6060
"frequenz-repo-config[lib] == 0.10.0",
6161
"frequenz-client-base",
6262
]

src/frequenz/client/base/channel.py

Lines changed: 189 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,50 +3,90 @@
33

44
"""Handling of gRPC channels."""
55

6+
import dataclasses
7+
import pathlib
8+
from typing import assert_never
69
from urllib.parse import parse_qs, urlparse
710

811
from grpc import ssl_channel_credentials
912
from grpc.aio import Channel, insecure_channel, secure_channel
1013

1114

12-
def _to_bool(value: str) -> bool:
13-
value = value.lower()
14-
if value in ("true", "on", "1"):
15-
return True
16-
if value in ("false", "off", "0"):
17-
return False
18-
raise ValueError(f"Invalid boolean value '{value}'")
15+
@dataclasses.dataclass(frozen=True)
16+
class SslOptions:
17+
"""SSL options for a gRPC channel."""
18+
19+
enabled: bool = True
20+
"""Whether SSL should be enabled."""
21+
22+
root_certificates: pathlib.Path | bytes | None = None
23+
"""The PEM-encoded root certificates.
24+
25+
This can be a path to a file containing the certificates, a byte string, or None to
26+
retrieve them from a default location chosen by gRPC runtime.
27+
"""
28+
29+
private_key: pathlib.Path | bytes | None = None
30+
"""The PEM-encoded private key.
31+
32+
This can be a path to a file containing the key, a byte string, or None if no key
33+
should be used.
34+
"""
35+
36+
certificate_chain: pathlib.Path | bytes | None = None
37+
"""The PEM-encoded certificate chain.
38+
39+
This can be a path to a file containing the chain, a byte string, or None if no
40+
chain should be used.
41+
"""
42+
43+
44+
@dataclasses.dataclass(frozen=True)
45+
class ChannelOptions:
46+
"""Options for a gRPC channel."""
47+
48+
port: int = 9090
49+
"""The port number to connect to."""
50+
51+
ssl: SslOptions = SslOptions()
52+
"""SSL options for the channel."""
1953

2054

2155
def parse_grpc_uri(
2256
uri: str,
2357
/,
24-
*,
25-
default_port: int = 9090,
26-
default_ssl: bool = True,
58+
defaults: ChannelOptions = ChannelOptions(),
2759
) -> Channel:
2860
"""Create a client channel from a URI.
2961
3062
The URI must have the following format:
3163
3264
```
33-
grpc://hostname[:port][?ssl=<bool>]
65+
grpc://hostname[:port][?param=value&...]
3466
```
3567
3668
A few things to consider about URI components:
3769
3870
- If any other components are present in the URI, a [`ValueError`][] is raised.
3971
- If the port is omitted, the `default_port` is used.
4072
- If a query parameter is passed many times, the last value is used.
41-
- The only supported query parameter is `ssl`, which must be a boolean value and
42-
defaults to the `default_ssl` argument if not present.
4373
- Boolean query parameters can be specified with the following values
4474
(case-insensitive): `true`, `1`, `on`, `false`, `0`, `off`.
4575
76+
Supported query parameters:
77+
78+
- `ssl` (bool): Enable or disable SSL. Defaults to `default_ssl`.
79+
- `ssl_root_certificates_path` (str): Path to the root certificates file. Only
80+
valid if SSL is enabled. Will raise a `ValueError` if the file cannot be read.
81+
- `ssl_private_key_path` (str): Path to the private key file. Only valid if SSL is
82+
enabled. Will raise a `ValueError` if the file cannot be read.
83+
- `ssl_certificate_chain_path` (str): Path to the certificate chain file. Only
84+
valid if SSL is enabled. Will raise a `ValueError` if the file cannot be read.
85+
4686
Args:
4787
uri: The gRPC URI specifying the connection parameters.
48-
default_port: The default port number to use if the URI does not specify one.
49-
default_ssl: The default SSL setting to use if the URI does not specify one.
88+
defaults: The default options use to create the channel when not specified in
89+
the URI.
5090
5191
Returns:
5292
A client channel object.
@@ -68,18 +108,143 @@ def parse_grpc_uri(
68108
uri,
69109
)
70110

71-
options = {k: v[-1] for k, v in parse_qs(parsed_uri.query).items()}
111+
options = _parse_query_params(uri, parsed_uri.query)
112+
113+
host = parsed_uri.hostname
114+
port = parsed_uri.port or defaults.port
115+
target = f"{host}:{port}"
116+
117+
ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
118+
if ssl:
119+
return secure_channel(
120+
target,
121+
ssl_channel_credentials(
122+
root_certificates=_get_contents(
123+
"root certificates",
124+
options.ssl_root_certificates_path,
125+
defaults.ssl.root_certificates,
126+
),
127+
private_key=_get_contents(
128+
"private key",
129+
options.ssl_private_key_path,
130+
defaults.ssl.private_key,
131+
),
132+
certificate_chain=_get_contents(
133+
"certificate chain",
134+
options.ssl_certificate_chain_path,
135+
defaults.ssl.certificate_chain,
136+
),
137+
),
138+
)
139+
return insecure_channel(target)
140+
141+
142+
def _to_bool(value: str) -> bool:
143+
value = value.lower()
144+
if value in ("true", "on", "1"):
145+
return True
146+
if value in ("false", "off", "0"):
147+
return False
148+
raise ValueError(f"Invalid boolean value '{value}'")
149+
150+
151+
@dataclasses.dataclass(frozen=True)
152+
class _QueryParams:
153+
ssl: bool | None
154+
ssl_root_certificates_path: pathlib.Path | None
155+
ssl_private_key_path: pathlib.Path | None
156+
ssl_certificate_chain_path: pathlib.Path | None
157+
158+
159+
def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
160+
"""Parse query parameters from a URI.
161+
162+
Args:
163+
uri: The URI from which the query parameters were extracted.
164+
query_string: The query string to parse.
165+
166+
Returns:
167+
A `_QueryParams` object with the parsed query parameters.
168+
169+
Raises:
170+
ValueError: If the query string contains unexpected components.
171+
"""
172+
options = {k: v[-1] for k, v in parse_qs(query_string).items()}
72173
ssl_option = options.pop("ssl", None)
73-
ssl = _to_bool(ssl_option) if ssl_option is not None else default_ssl
174+
ssl: bool | None = None
175+
if ssl_option is not None:
176+
ssl = _to_bool(ssl_option)
177+
178+
ssl_opts = {
179+
k: options.pop(k, None)
180+
for k in (
181+
"ssl_root_certificates_path",
182+
"ssl_private_key_path",
183+
"ssl_certificate_chain_path",
184+
)
185+
}
186+
187+
if ssl is False:
188+
erros = []
189+
for opt_name, opt in ssl_opts.items():
190+
if opt is not None:
191+
erros.append(opt_name)
192+
if erros:
193+
raise ValueError(
194+
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but SSL is disabled",
195+
)
196+
74197
if options:
198+
names = ", ".join(options)
75199
raise ValueError(
76-
f"Unexpected query parameters {options!r} in the URI '{uri}'",
200+
f"Unexpected query parameters [{names}] in the URI '{uri}'",
77201
uri,
78202
)
79203

80-
host = parsed_uri.hostname
81-
port = parsed_uri.port or default_port
82-
target = f"{host}:{port}"
83-
if ssl:
84-
return secure_channel(target, ssl_channel_credentials())
85-
return insecure_channel(target)
204+
return _QueryParams(
205+
ssl=ssl,
206+
**{k: pathlib.Path(v) if v is not None else None for k, v in ssl_opts.items()},
207+
)
208+
209+
210+
def _get_contents(
211+
name: str, source: pathlib.Path | None, default: pathlib.Path | bytes | None
212+
) -> bytes | None:
213+
"""Get the contents of a file or use a default value.
214+
215+
If the `source` is `None`, the `default` value is used instead. If the source (or
216+
default) is a path, the contents of the file are returned. If the source is a byte
217+
string (or default) the byte string is returned without doing any reading.
218+
219+
Args:
220+
name: The name of the contents (used for error messages).
221+
source: The source of the contents.
222+
default: The default value to use if the source is None.
223+
224+
Returns:
225+
The contents of the source file or the default value.
226+
"""
227+
file_path: pathlib.Path
228+
match source:
229+
case None:
230+
match default:
231+
case None:
232+
return None
233+
case bytes() as default_bytes:
234+
return default_bytes
235+
case pathlib.Path() as file_path:
236+
return _read_bytes(name, file_path)
237+
case unexpected:
238+
assert_never(unexpected)
239+
case pathlib.Path() as file_path:
240+
return _read_bytes(name, file_path)
241+
case unexpected:
242+
assert_never(unexpected)
243+
244+
245+
def _read_bytes(name: str, source: pathlib.Path) -> bytes:
246+
"""Read the contents of a file as bytes."""
247+
try:
248+
return source.read_bytes()
249+
except OSError as exc:
250+
raise ValueError(f"Failed to read {name} from '{source}': {exc}") from exc

src/frequenz/client/base/client.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from grpc.aio import AioRpcError, Channel
1212

13-
from .channel import parse_grpc_uri
13+
from .channel import ChannelOptions, parse_grpc_uri
1414
from .exception import ApiClientError, ClientNotConnected
1515

1616
StubT = TypeVar("StubT")
@@ -103,7 +103,7 @@ async def main():
103103
break
104104
```
105105
106-
!!! Note
106+
Note:
107107
* In this case a very simple `GrpcStreamBroadcaster` is used, asuming that
108108
each call to `example_stream` will stream the same data. If the request
109109
is more complex, you will probably need to have some kind of map from
@@ -117,6 +117,7 @@ def __init__(
117117
create_stub: Callable[[Channel], StubT],
118118
*,
119119
connect: bool = True,
120+
channel_defaults: ChannelOptions = ChannelOptions(),
120121
) -> None:
121122
"""Create an instance and connect to the server.
122123
@@ -127,9 +128,12 @@ def __init__(
127128
created. If `False`, the client will not connect to the server until
128129
[connect()][frequenz.client.base.client.BaseApiClient.connect] is
129130
called.
131+
channel_defaults: The default options for the gRPC channel to create using
132+
the server URL.
130133
"""
131134
self._server_url: str = server_url
132135
self._create_stub: Callable[[Channel], StubT] = create_stub
136+
self._channel_defaults: ChannelOptions = channel_defaults
133137
self._channel: Channel | None = None
134138
self._stub: StubT | None = None
135139
if connect:
@@ -156,6 +160,11 @@ def channel(self) -> Channel:
156160
raise ClientNotConnected(server_url=self.server_url, operation="channel")
157161
return self._channel
158162

163+
@property
164+
def channel_defaults(self) -> ChannelOptions:
165+
"""The default options for the gRPC channel."""
166+
return self._channel_defaults
167+
159168
@property
160169
def stub(self) -> StubT:
161170
"""The underlying gRPC stub.
@@ -192,7 +201,7 @@ def connect(self, server_url: str | None = None) -> None:
192201
self._server_url = server_url
193202
elif self.is_connected:
194203
return
195-
self._channel = parse_grpc_uri(self._server_url)
204+
self._channel = parse_grpc_uri(self._server_url, self._channel_defaults)
196205
self._stub = self._create_stub(self._channel)
197206

198207
async def disconnect(self) -> None:

0 commit comments

Comments
 (0)