Skip to content

Commit e84c217

Browse files
committed
Support http2 keep-alive
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 689feaa commit e84c217

File tree

2 files changed

+202
-5
lines changed

2 files changed

+202
-5
lines changed

src/frequenz/client/base/channel.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import dataclasses
77
import pathlib
8+
from datetime import timedelta
89
from typing import assert_never
910
from urllib.parse import parse_qs, urlparse
1011

@@ -41,6 +42,20 @@ class SslOptions:
4142
"""
4243

4344

45+
@dataclasses.dataclass(frozen=True)
46+
class KeepAliveOptions:
47+
"""Options for HTTP2 keep-alive pings."""
48+
49+
enabled: bool = True
50+
"""Whether keep-alive should be enabled."""
51+
52+
interval: timedelta = timedelta(seconds=60)
53+
"""The interval between pings."""
54+
55+
timeout: timedelta = timedelta(seconds=20)
56+
"""The time in milliseconds to wait for a keep-alive response."""
57+
58+
4459
@dataclasses.dataclass(frozen=True)
4560
class ChannelOptions:
4661
"""Options for a gRPC channel."""
@@ -51,6 +66,9 @@ class ChannelOptions:
5166
ssl: SslOptions = SslOptions()
5267
"""SSL options for the channel."""
5368

69+
keep_alive: KeepAliveOptions = KeepAliveOptions()
70+
"""HTTP2 keep-alive options for the channel."""
71+
5472

5573
def parse_grpc_uri(
5674
uri: str,
@@ -120,6 +138,40 @@ def parse_grpc_uri(
120138
parsed_uri.netloc if parsed_uri.port else f"{parsed_uri.netloc}:{defaults.port}"
121139
)
122140

141+
keep_alive = (
142+
defaults.keep_alive.enabled
143+
if options.keep_alive is None
144+
else options.keep_alive
145+
)
146+
channel_options = (
147+
[
148+
("grpc.http2.max_pings_without_data", 0),
149+
("grpc.keepalive_permit_without_calls", 1),
150+
(
151+
"grpc.keepalive_time_ms",
152+
(
153+
(
154+
options.keep_alive_interval
155+
if options.keep_alive_interval is not None
156+
else defaults.keep_alive.interval
157+
).total_seconds()
158+
* 1000
159+
),
160+
),
161+
(
162+
"grpc.keepalive_timeout_ms",
163+
(
164+
options.keep_alive_timeout
165+
if options.keep_alive_timeout is not None
166+
else defaults.keep_alive.timeout
167+
).total_seconds()
168+
* 1000,
169+
),
170+
]
171+
if keep_alive
172+
else None
173+
)
174+
123175
ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
124176
if ssl:
125177
return secure_channel(
@@ -141,8 +193,9 @@ def parse_grpc_uri(
141193
defaults.ssl.certificate_chain,
142194
),
143195
),
196+
channel_options,
144197
)
145-
return insecure_channel(target)
198+
return insecure_channel(target, channel_options)
146199

147200

148201
def _to_bool(value: str) -> bool:
@@ -160,6 +213,9 @@ class _QueryParams:
160213
ssl_root_certificates_path: pathlib.Path | None
161214
ssl_private_key_path: pathlib.Path | None
162215
ssl_certificate_chain_path: pathlib.Path | None
216+
keep_alive: bool | None
217+
keep_alive_interval: timedelta | None
218+
keep_alive_timeout: timedelta | None
163219

164220

165221
def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
@@ -200,6 +256,26 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
200256
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but SSL is disabled",
201257
)
202258

259+
keep_alive_option = options.pop("keep_alive", None)
260+
keep_alive: bool | None = None
261+
if keep_alive_option is not None:
262+
keep_alive = _to_bool(keep_alive_option)
263+
264+
keep_alive_opts = {
265+
k: options.pop(k, None)
266+
for k in ("keep_alive_interval_s", "keep_alive_timeout_s")
267+
}
268+
269+
if keep_alive is False:
270+
erros = []
271+
for opt_name, opt in keep_alive_opts.items():
272+
if opt is not None:
273+
erros.append(opt_name)
274+
if erros:
275+
raise ValueError(
276+
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but keep_alive is disabled",
277+
)
278+
203279
if options:
204280
names = ", ".join(options)
205281
raise ValueError(
@@ -209,7 +285,32 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
209285

210286
return _QueryParams(
211287
ssl=ssl,
212-
**{k: pathlib.Path(v) if v is not None else None for k, v in ssl_opts.items()},
288+
ssl_root_certificates_path=(
289+
pathlib.Path(ssl_opts["ssl_root_certificates_path"])
290+
if ssl_opts["ssl_root_certificates_path"] is not None
291+
else None
292+
),
293+
ssl_private_key_path=(
294+
pathlib.Path(ssl_opts["ssl_private_key_path"])
295+
if ssl_opts["ssl_private_key_path"] is not None
296+
else None
297+
),
298+
ssl_certificate_chain_path=(
299+
pathlib.Path(ssl_opts["ssl_certificate_chain_path"])
300+
if ssl_opts["ssl_certificate_chain_path"] is not None
301+
else None
302+
),
303+
keep_alive=keep_alive,
304+
keep_alive_interval=(
305+
timedelta(seconds=int(keep_alive_opts["keep_alive_interval_s"]))
306+
if keep_alive_opts["keep_alive_interval_s"] is not None
307+
else None
308+
),
309+
keep_alive_timeout=(
310+
timedelta(seconds=int(keep_alive_opts["keep_alive_timeout_s"]))
311+
if keep_alive_opts["keep_alive_timeout_s"] is not None
312+
else None
313+
),
213314
)
214315

215316

tests/test_channel.py

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import dataclasses
77
import pathlib
8+
from datetime import timedelta
89
from unittest import mock
910

1011
import pytest
@@ -13,6 +14,7 @@
1314

1415
from frequenz.client.base.channel import (
1516
ChannelOptions,
17+
KeepAliveOptions,
1618
SslOptions,
1719
_to_bool,
1820
parse_grpc_uri,
@@ -136,6 +138,67 @@ class _ValidUrlTestCase:
136138
),
137139
),
138140
),
141+
_ValidUrlTestCase(
142+
title="Keep-alive no defaults",
143+
uri="grpc://localhost:1234?keep_alive=1&keep_alive_interval_s=300"
144+
+ "&keep_alive_timeout_s=60",
145+
expected_host="localhost",
146+
expected_port=1234,
147+
expected_options=ChannelOptions(
148+
keep_alive=KeepAliveOptions(
149+
enabled=True,
150+
interval=timedelta(minutes=5),
151+
timeout=timedelta(minutes=1),
152+
),
153+
),
154+
),
155+
_ValidUrlTestCase(
156+
title="Keep-alive default timeout",
157+
uri="grpc://localhost:1234?keep_alive=1&keep_alive_interval_s=300",
158+
defaults=ChannelOptions(
159+
keep_alive=KeepAliveOptions(
160+
enabled=True,
161+
interval=timedelta(seconds=10),
162+
timeout=timedelta(seconds=2),
163+
),
164+
),
165+
expected_host="localhost",
166+
expected_port=1234,
167+
expected_options=ChannelOptions(
168+
keep_alive=KeepAliveOptions(
169+
enabled=True,
170+
interval=timedelta(seconds=300),
171+
timeout=timedelta(seconds=2),
172+
),
173+
),
174+
),
175+
_ValidUrlTestCase(
176+
title="Keep-alive default interval",
177+
uri="grpc://localhost:1234?keep_alive=1&keep_alive_timeout_s=60",
178+
defaults=ChannelOptions(
179+
keep_alive=KeepAliveOptions(
180+
enabled=True, interval=timedelta(minutes=30)
181+
),
182+
),
183+
expected_host="localhost",
184+
expected_port=1234,
185+
expected_options=ChannelOptions(
186+
keep_alive=KeepAliveOptions(
187+
enabled=True,
188+
timeout=timedelta(minutes=1),
189+
interval=timedelta(minutes=30),
190+
),
191+
),
192+
),
193+
_ValidUrlTestCase(
194+
title="keep-alive disabled",
195+
uri="grpc://localhost:1234?keep_alive=0",
196+
expected_host="localhost",
197+
expected_port=1234,
198+
expected_options=ChannelOptions(
199+
keep_alive=KeepAliveOptions(enabled=False),
200+
),
201+
),
139202
],
140203
ids=lambda case: case.title,
141204
)
@@ -154,7 +217,9 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
154217
)
155218
expected_port = case.expected_port
156219
expected_ssl = (
157-
expected_options.ssl.enabled if "ssl=" in uri else defaults.ssl.enabled
220+
expected_options.ssl.enabled
221+
if "ssl=" in uri or defaults.ssl.enabled is None
222+
else defaults.ssl.enabled
158223
)
159224
expected_root_certificates = (
160225
expected_options.ssl.root_certificates
@@ -196,6 +261,35 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
196261

197262
assert channel == expected_channel
198263
expected_target = f"{expected_host}:{expected_port}"
264+
expected_keep_alive = (
265+
expected_options.keep_alive if "keep_alive=" in uri else defaults.keep_alive
266+
)
267+
expected_keep_alive_interval = (
268+
expected_keep_alive.interval
269+
if "keep_alive_interval_s=" in uri
270+
else defaults.keep_alive.interval
271+
)
272+
expected_keep_alive_timeout = (
273+
expected_keep_alive.timeout
274+
if "keep_alive_timeout_s=" in uri
275+
else defaults.keep_alive.timeout
276+
)
277+
expected_channel_options = (
278+
[
279+
("grpc.http2.max_pings_without_data", 0),
280+
("grpc.keepalive_permit_without_calls", 1),
281+
(
282+
"grpc.keepalive_time_ms",
283+
(expected_keep_alive_interval.total_seconds() * 1000),
284+
),
285+
(
286+
"grpc.keepalive_timeout_ms",
287+
expected_keep_alive_timeout.total_seconds() * 1000,
288+
),
289+
]
290+
if expected_keep_alive.enabled
291+
else None
292+
)
199293
if expected_ssl:
200294
if isinstance(expected_root_certificates, pathlib.Path):
201295
get_contents_mock.assert_any_call(
@@ -221,10 +315,12 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
221315
certificate_chain=expected_certificate_chain,
222316
)
223317
secure_channel_mock.assert_called_once_with(
224-
expected_target, expected_credentials
318+
expected_target, expected_credentials, expected_channel_options
225319
)
226320
else:
227-
insecure_channel_mock.assert_called_once_with(expected_target)
321+
insecure_channel_mock.assert_called_once_with(
322+
expected_target, expected_channel_options
323+
)
228324

229325

230326
@pytest.mark.parametrize("value", ["true", "on", "1", "TrUe", "On", "ON", "TRUE"])

0 commit comments

Comments
 (0)