Skip to content

Commit c41e592

Browse files
committed
Allow passing a default root certificate
This change allows passing a default root certificate to the `parse_grpc_uri` function. This certificate will be used if the `ssl_root_certificates_path` is not provided in the URI. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 938ac5d commit c41e592

File tree

1 file changed

+58
-12
lines changed

1 file changed

+58
-12
lines changed

src/frequenz/client/base/channel.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import dataclasses
77
import pathlib
8+
from typing import assert_never
89
from urllib.parse import parse_qs, urlparse
910

1011
from grpc import ssl_channel_credentials
@@ -18,6 +19,13 @@ class SslOptions:
1819
enabled: bool = True
1920
"""Whether SSL should be enabled."""
2021

22+
root_certificates: pathlib.Path | bytes | None = None
23+
"""The PEM-encoded root certificates.
24+
25+
This can be a path to a file containing the certificates, a byte string, or None to
26+
retrieve them from a default location chosen by gRPC runtime.
27+
"""
28+
2129

2230
@dataclasses.dataclass(frozen=True)
2331
class ChannelOptions:
@@ -90,19 +98,15 @@ def parse_grpc_uri(
9098

9199
ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
92100
if ssl:
93-
root_cert: bytes | None = None
94-
if options.ssl_root_certificates_path is not None:
95-
try:
96-
with options.ssl_root_certificates_path.open("rb") as file:
97-
root_cert = file.read()
98-
except OSError as exc:
99-
raise ValueError(
100-
"Failed to read root certificates from "
101-
f"'{options.ssl_root_certificates_path}': {exc}",
102-
uri,
103-
) from exc
104101
return secure_channel(
105-
target, ssl_channel_credentials(root_certificates=root_cert)
102+
target,
103+
ssl_channel_credentials(
104+
root_certificates=_get_contents(
105+
"root certificates",
106+
options.ssl_root_certificates_path,
107+
defaults.ssl.root_certificates,
108+
)
109+
),
106110
)
107111
return insecure_channel(target)
108112

@@ -160,3 +164,45 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
160164
pathlib.Path(ssl_root_cert_path) if ssl_root_cert_path else None
161165
),
162166
)
167+
168+
169+
def _get_contents(
170+
name: str, source: pathlib.Path | None, default: pathlib.Path | bytes | None
171+
) -> bytes | None:
172+
"""Get the contents of a file or use a default value.
173+
174+
If the `source` is `None`, the `default` value is used instead. If the source (or
175+
default) is a path, the contents of the file are returned. If the source is a byte
176+
string (or default) the byte string is returned without doing any reading.
177+
178+
Args:
179+
name: The name of the contents (used for error messages).
180+
source: The source of the contents.
181+
default: The default value to use if the source is None.
182+
183+
Returns:
184+
The contents of the source file or the default value.
185+
186+
Raises:
187+
ValueError: If the file cannot be read.
188+
"""
189+
file_path: pathlib.Path
190+
match source:
191+
case None:
192+
match default:
193+
case None:
194+
return None
195+
case bytes() as default_bytes:
196+
return default_bytes
197+
case pathlib.Path() as file_path:
198+
pass
199+
case unexpected:
200+
assert_never(unexpected)
201+
case pathlib.Path() as file_path:
202+
pass
203+
case unexpected:
204+
assert_never(unexpected)
205+
try:
206+
return file_path.read_bytes()
207+
except OSError as exc:
208+
raise ValueError(f"Failed to read {name} from '{file_path}': {exc}") from exc

0 commit comments

Comments
 (0)