Skip to content

Commit 434960b

Browse files
authored
Implement key-based authentication (#56)
- **Add key-based authentication support** - **Update REPL client to work better with host/port/key parameters**
2 parents c0cff94 + a2be0e2 commit 434960b

File tree

6 files changed

+195
-34
lines changed

6 files changed

+195
-34
lines changed

RELEASE_NOTES.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66

77
## Upgrading
88

9-
<!-- Here goes notes on how to upgrade from previous versions, including deprecations and what they should be replaced with -->
9+
* An API key for authorization must now be passed to the `DispatchClient`.
10+
* The client constructor now requires named parameters. If you are using the client directly, you will need to update your code to use named parameters.
1011

1112
## New Features
1213

13-
<!-- Here goes the main new features and examples or instructions on how to use them -->
14+
* TLS is now enabled by default for the CLI client.
1415

1516
## Bug Fixes
1617

src/frequenz/client/dispatch/__main__.py

Lines changed: 92 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import asyncio
77
import os
8-
import sys
98
from pprint import pformat
109
from typing import Any, List
1110

@@ -37,22 +36,66 @@
3736
DEFAULT_DISPATCH_API_PORT = 50051
3837

3938

40-
def get_client(host: str, port: int) -> Client:
39+
def ssl_channel_credentials_from_files(
40+
root_cert_path: str | None = None,
41+
client_cert_path: str | None = None,
42+
client_key_path: str | None = None,
43+
) -> grpc.ChannelCredentials:
44+
"""Create credentials for use with an SSL-enabled Channel.
45+
46+
Using the provided certificate and key files.
47+
48+
Args:
49+
root_cert_path: Path to the PEM-encoded root certificates file,
50+
or None to retrieve them from a default location chosen by gRPC runtime.
51+
client_cert_path: Path to the PEM-encoded client certificate file.
52+
client_key_path: Path to the PEM-encoded client private key file.
53+
54+
Returns:
55+
A ChannelCredentials for use with an SSL-enabled Channel.
56+
"""
57+
root_certificates = None
58+
if root_cert_path is not None:
59+
with open(root_cert_path, "rb") as f:
60+
root_certificates = f.read()
61+
62+
certificate_chain = None
63+
if client_cert_path is not None:
64+
with open(client_cert_path, "rb") as f:
65+
certificate_chain = f.read()
66+
67+
private_key = None
68+
if client_key_path is not None:
69+
with open(client_key_path, "rb") as f:
70+
private_key = f.read()
71+
72+
return grpc.ssl_channel_credentials(
73+
root_certificates=root_certificates,
74+
private_key=private_key,
75+
certificate_chain=certificate_chain,
76+
)
77+
78+
79+
def get_client(*, host: str, port: int, key: str) -> Client:
4180
"""Get a new client instance.
4281
4382
Args:
4483
host: The host of the dispatch service.
4584
port: The port of the dispatch service.
85+
key: The API key for authentication.
4686
4787
Returns:
4888
Client: A new client instance.
4989
"""
50-
channel = grpc.aio.insecure_channel(f"{host}:{port}")
51-
return Client(channel, f"{host}:{port}")
90+
channel = grpc.aio.secure_channel(
91+
f"{host}:{port}",
92+
credentials=ssl_channel_credentials_from_files(),
93+
)
94+
return Client(grpc_channel=channel, svc_addr=f"{host}:{port}", key=key)
5295

5396

5497
# Click command groups
55-
@click.group()
98+
@click.group(invoke_without_command=True)
5699
@click.option(
57100
"--host",
58101
default=DEFAULT_DISPATCH_API_HOST,
@@ -69,11 +112,29 @@ def get_client(host: str, port: int) -> Client:
69112
show_envvar=True,
70113
show_default=True,
71114
)
115+
@click.option(
116+
"--key",
117+
help="API key for authentication",
118+
envvar="DISPATCH_API_KEY",
119+
show_envvar=True,
120+
required=True,
121+
)
72122
@click.pass_context
73-
async def cli(ctx: click.Context, host: str, port: int) -> None:
123+
async def cli(ctx: click.Context, host: str, port: int, key: str) -> None:
74124
"""Dispatch Service CLI."""
75-
ctx.ensure_object(dict)
76-
ctx.obj["client"] = get_client(host, port)
125+
if ctx.obj is None:
126+
ctx.obj = {}
127+
128+
ctx.obj["client"] = get_client(host=host, port=port, key=key)
129+
ctx.obj["params"] = {
130+
"host": host,
131+
"port": port,
132+
"key": key,
133+
}
134+
135+
# Check if a subcommand was given
136+
if ctx.invoked_subcommand is None:
137+
await interactive_mode(host, port, key)
77138

78139

79140
@cli.command("list")
@@ -327,6 +388,18 @@ async def get(ctx: click.Context, dispatch_ids: List[int]) -> None:
327388
raise click.ClickException("Some gets failed.")
328389

329390

391+
@cli.command()
392+
@click.pass_obj
393+
async def repl(
394+
obj: dict[str, Any],
395+
) -> None:
396+
"""Start an interactive interface."""
397+
click.echo(f"Parameters: {obj}")
398+
await interactive_mode(
399+
obj["params"]["host"], obj["params"]["port"], obj["params"]["key"]
400+
)
401+
402+
330403
@cli.command()
331404
@click.argument("dispatch_ids", type=FuzzyIntRange(), nargs=-1) # Allow multiple IDs
332405
@click.pass_context
@@ -359,7 +432,7 @@ async def delete(ctx: click.Context, dispatch_ids: list[list[int]]) -> None:
359432
raise click.ClickException("Some deletions failed.")
360433

361434

362-
async def interactive_mode() -> None:
435+
async def interactive_mode(host: str, port: int, key: str) -> None:
363436
"""Interactive mode for the CLI."""
364437
hist_file = os.path.expanduser("~/.dispatch_cli_history.txt")
365438
session: PromptSession[str] = PromptSession(history=FileHistory(filename=hist_file))
@@ -390,7 +463,15 @@ async def display_help() -> None:
390463
break
391464
else:
392465
# Split, but keep quoted strings together
393-
params = click.parser.split_arg_string(user_input)
466+
params = [
467+
"--host",
468+
host,
469+
"--port",
470+
str(port),
471+
"--key",
472+
key,
473+
] + click.parser.split_arg_string(user_input)
474+
394475
try:
395476
await cli.main(args=params, standalone_mode=False)
396477
except click.ClickException as e:
@@ -405,10 +486,7 @@ async def display_help() -> None:
405486

406487
def main() -> None:
407488
"""Entrypoint for the CLI."""
408-
if len(sys.argv) > 1:
409-
asyncio.run(cli.main())
410-
else:
411-
asyncio.run(interactive_mode())
489+
asyncio.run(cli.main())
412490

413491

414492
if __name__ == "__main__":

src/frequenz/client/dispatch/_client.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,19 @@
3636
class Client:
3737
"""Dispatch API client."""
3838

39-
def __init__(self, grpc_channel: grpc.aio.Channel, svc_addr: str) -> None:
39+
def __init__(
40+
self, *, grpc_channel: grpc.aio.Channel, svc_addr: str, key: str
41+
) -> None:
4042
"""Initialize the client.
4143
4244
Args:
4345
grpc_channel: gRPC channel to use for communication with the API.
4446
svc_addr: Address of the service to connect to.
47+
key: API key to use for authentication.
4548
"""
4649
self._svc_addr = svc_addr
4750
self._stub = dispatch_pb2_grpc.MicrogridDispatchServiceStub(grpc_channel)
51+
self._metadata = (("key", key),)
4852

4953
# pylint: disable=too-many-arguments, too-many-locals
5054
async def list(
@@ -64,7 +68,7 @@ async def list(
6468
6569
```python
6670
grpc_channel = grpc.aio.insecure_channel("example")
67-
client = Client(grpc_channel, "localhost:50051")
71+
client = Client(grpc_channel=grpc_channel, svc_addr="localhost:50051", key="key")
6872
async for dispatch in client.list(microgrid_id=1):
6973
print(dispatch)
7074
```
@@ -108,7 +112,9 @@ async def list(
108112
)
109113
request = DispatchListRequest(microgrid_id=microgrid_id, filter=filters)
110114

111-
response = await self._stub.ListMicrogridDispatches(request) # type: ignore
115+
response = await self._stub.ListMicrogridDispatches(
116+
request, metadata=self._metadata
117+
) # type: ignore
112118
for dispatch in response.dispatches:
113119
yield Dispatch.from_protobuf(dispatch)
114120

@@ -166,7 +172,9 @@ async def create(
166172
recurrence=recurrence or RecurrenceRule(),
167173
)
168174

169-
await self._stub.CreateMicrogridDispatch(request.to_protobuf()) # type: ignore
175+
await self._stub.CreateMicrogridDispatch(
176+
request.to_protobuf(), metadata=self._metadata
177+
) # type: ignore
170178

171179
if dispatch := await self._try_fetch_created_dispatch(request):
172180
return dispatch
@@ -246,7 +254,9 @@ async def update(
246254

247255
msg.update_mask.paths.append(key)
248256

249-
await self._stub.UpdateMicrogridDispatch(msg) # type: ignore
257+
await self._stub.UpdateMicrogridDispatch(
258+
msg, metadata=self._metadata
259+
) # type: ignore
250260

251261
async def get(self, dispatch_id: int) -> Dispatch:
252262
"""Get a dispatch.
@@ -258,7 +268,9 @@ async def get(self, dispatch_id: int) -> Dispatch:
258268
Dispatch: The dispatch.
259269
"""
260270
request = DispatchGetRequest(id=dispatch_id)
261-
response = await self._stub.GetMicrogridDispatch(request) # type: ignore
271+
response = await self._stub.GetMicrogridDispatch(
272+
request, metadata=self._metadata
273+
) # type: ignore
262274
return Dispatch.from_protobuf(response)
263275

264276
async def delete(self, dispatch_id: int) -> None:
@@ -268,7 +280,9 @@ async def delete(self, dispatch_id: int) -> None:
268280
dispatch_id: The dispatch_id to delete.
269281
"""
270282
request = DispatchDeleteRequest(id=dispatch_id)
271-
await self._stub.DeleteMicrogridDispatch(request) # type: ignore
283+
await self._stub.DeleteMicrogridDispatch(
284+
request, metadata=self._metadata
285+
) # type: ignore
272286

273287
async def _try_fetch_created_dispatch(
274288
self,

src/frequenz/client/dispatch/test/_service.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
from .._internal_types import DispatchCreateRequest
3636
from ..types import Dispatch
3737

38+
ALL_KEY = "all"
39+
"""Key that has access to all resources in the FakeService."""
40+
41+
NONE_KEY = "none"
42+
"""Key that has no access to any resources in the FakeService."""
43+
3844

3945
@dataclass
4046
class FakeService:
@@ -49,18 +55,59 @@ class FakeService:
4955
_shuffle_after_create: bool = False
5056
"""Whether to shuffle the dispatches after creating them."""
5157

58+
def _check_access(self, metadata: grpc.aio.Metadata) -> None:
59+
"""Check if the access key is valid.
60+
61+
Args:
62+
metadata: The metadata.
63+
64+
Raises:
65+
grpc.RpcError: If the access key is invalid.
66+
"""
67+
# metadata is a weird tuple of tuples, we don't like it
68+
metadata_dict = dict(metadata)
69+
70+
if "key" not in metadata_dict:
71+
raise grpc.RpcError(
72+
grpc.StatusCode.UNAUTHENTICATED,
73+
"No access key provided",
74+
)
75+
76+
key = metadata_dict["key"]
77+
78+
if key is None:
79+
raise grpc.RpcError(
80+
grpc.StatusCode.UNAUTHENTICATED,
81+
"No access key provided",
82+
)
83+
84+
if key == NONE_KEY:
85+
raise grpc.RpcError(
86+
grpc.StatusCode.PERMISSION_DENIED,
87+
"Permission denied",
88+
)
89+
90+
if key != ALL_KEY:
91+
raise grpc.RpcError(
92+
grpc.StatusCode.UNAUTHENTICATED,
93+
"Invalid access key",
94+
)
95+
5296
# pylint: disable=invalid-name
5397
async def ListMicrogridDispatches(
54-
self, request: PBDispatchListRequest
98+
self, request: PBDispatchListRequest, metadata: grpc.aio.Metadata
5599
) -> DispatchList:
56100
"""List microgrid dispatches.
57101
58102
Args:
59103
request: The request.
104+
metadata: The metadata.
60105
61106
Returns:
62107
The dispatch list.
63108
"""
109+
self._check_access(metadata)
110+
64111
return DispatchList(
65112
dispatches=map(
66113
lambda d: d.to_protobuf(),
@@ -108,8 +155,10 @@ def _filter_dispatch(dispatch: Dispatch, request: PBDispatchListRequest) -> bool
108155
async def CreateMicrogridDispatch(
109156
self,
110157
request: PBDispatchCreateRequest,
158+
metadata: grpc.aio.Metadata,
111159
) -> Empty:
112160
"""Create a new dispatch."""
161+
self._check_access(metadata)
113162
self._last_id += 1
114163

115164
self.dispatches.append(
@@ -128,8 +177,10 @@ async def CreateMicrogridDispatch(
128177
async def UpdateMicrogridDispatch(
129178
self,
130179
request: DispatchUpdateRequest,
180+
metadata: grpc.aio.Metadata,
131181
) -> Empty:
132182
"""Update a dispatch."""
183+
self._check_access(metadata)
133184
index = next(
134185
(i for i, d in enumerate(self.dispatches) if d.id == request.id),
135186
None,
@@ -194,8 +245,10 @@ async def UpdateMicrogridDispatch(
194245
async def GetMicrogridDispatch(
195246
self,
196247
request: DispatchGetRequest,
248+
metadata: grpc.aio.Metadata,
197249
) -> PBDispatch:
198250
"""Get a single dispatch."""
251+
self._check_access(metadata)
199252
dispatch = next((d for d in self.dispatches if d.id == request.id), None)
200253

201254
if dispatch is None:
@@ -211,8 +264,10 @@ async def GetMicrogridDispatch(
211264
async def DeleteMicrogridDispatch(
212265
self,
213266
request: DispatchDeleteRequest,
267+
metadata: grpc.aio.Metadata,
214268
) -> Empty:
215269
"""Delete a given dispatch."""
270+
self._check_access(metadata)
216271
num_dispatches = len(self.dispatches)
217272
self.dispatches = [d for d in self.dispatches if d.id != request.id]
218273

0 commit comments

Comments
 (0)