Skip to content

Commit 672b9ea

Browse files
committed
Add key-based authentication support
Signed-off-by: Mathias L. Baumann <[email protected]>
1 parent c0cff94 commit 672b9ea

File tree

6 files changed

+100
-21
lines changed

6 files changed

+100
-21
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
## Upgrading
88

9-
<!-- Here goes notes on how to upgrade from previous versions, including deprecations and what they should be replaced with -->
9+
* An API key for authorization must now be passed to the `DispatchClient`.
1010

1111
## New Features
1212

src/frequenz/client/dispatch/__main__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,19 @@
3737
DEFAULT_DISPATCH_API_PORT = 50051
3838

3939

40-
def get_client(host: str, port: int) -> Client:
40+
def get_client(host: str, port: int, key: str) -> Client:
4141
"""Get a new client instance.
4242
4343
Args:
4444
host: The host of the dispatch service.
4545
port: The port of the dispatch service.
46+
key: The API key for authentication.
4647
4748
Returns:
4849
Client: A new client instance.
4950
"""
5051
channel = grpc.aio.insecure_channel(f"{host}:{port}")
51-
return Client(channel, f"{host}:{port}")
52+
return Client(channel, f"{host}:{port}", key)
5253

5354

5455
# Click command groups
@@ -69,11 +70,17 @@ def get_client(host: str, port: int) -> Client:
6970
show_envvar=True,
7071
show_default=True,
7172
)
73+
@click.option(
74+
"--key",
75+
help="API key for authentication",
76+
envvar="DISPATCH_API_KEY",
77+
show_envvar=True,
78+
)
7279
@click.pass_context
73-
async def cli(ctx: click.Context, host: str, port: int) -> None:
80+
async def cli(ctx: click.Context, host: str, port: int, key: str) -> None:
7481
"""Dispatch Service CLI."""
7582
ctx.ensure_object(dict)
76-
ctx.obj["client"] = get_client(host, port)
83+
ctx.obj["client"] = get_client(host, port, key)
7784

7885

7986
@cli.command("list")

src/frequenz/client/dispatch/_client.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,17 @@
3636
class Client:
3737
"""Dispatch API client."""
3838

39-
def __init__(self, grpc_channel: grpc.aio.Channel, svc_addr: str) -> None:
39+
def __init__(self, grpc_channel: grpc.aio.Channel, svc_addr: str, key: str) -> None:
4040
"""Initialize the client.
4141
4242
Args:
4343
grpc_channel: gRPC channel to use for communication with the API.
4444
svc_addr: Address of the service to connect to.
45+
key: API key to use for authentication.
4546
"""
4647
self._svc_addr = svc_addr
4748
self._stub = dispatch_pb2_grpc.MicrogridDispatchServiceStub(grpc_channel)
49+
self._metadata = (("key", key),)
4850

4951
# pylint: disable=too-many-arguments, too-many-locals
5052
async def list(
@@ -64,7 +66,7 @@ async def list(
6466
6567
```python
6668
grpc_channel = grpc.aio.insecure_channel("example")
67-
client = Client(grpc_channel, "localhost:50051")
69+
client = Client(grpc_channel, "localhost:50051", "key")
6870
async for dispatch in client.list(microgrid_id=1):
6971
print(dispatch)
7072
```
@@ -108,7 +110,9 @@ async def list(
108110
)
109111
request = DispatchListRequest(microgrid_id=microgrid_id, filter=filters)
110112

111-
response = await self._stub.ListMicrogridDispatches(request) # type: ignore
113+
response = await self._stub.ListMicrogridDispatches(
114+
request, metadata=self._metadata
115+
) # type: ignore
112116
for dispatch in response.dispatches:
113117
yield Dispatch.from_protobuf(dispatch)
114118

@@ -166,7 +170,9 @@ async def create(
166170
recurrence=recurrence or RecurrenceRule(),
167171
)
168172

169-
await self._stub.CreateMicrogridDispatch(request.to_protobuf()) # type: ignore
173+
await self._stub.CreateMicrogridDispatch(
174+
request.to_protobuf(), metadata=self._metadata
175+
) # type: ignore
170176

171177
if dispatch := await self._try_fetch_created_dispatch(request):
172178
return dispatch
@@ -246,7 +252,9 @@ async def update(
246252

247253
msg.update_mask.paths.append(key)
248254

249-
await self._stub.UpdateMicrogridDispatch(msg) # type: ignore
255+
await self._stub.UpdateMicrogridDispatch(
256+
msg, metadata=self._metadata
257+
) # type: ignore
250258

251259
async def get(self, dispatch_id: int) -> Dispatch:
252260
"""Get a dispatch.
@@ -258,7 +266,9 @@ async def get(self, dispatch_id: int) -> Dispatch:
258266
Dispatch: The dispatch.
259267
"""
260268
request = DispatchGetRequest(id=dispatch_id)
261-
response = await self._stub.GetMicrogridDispatch(request) # type: ignore
269+
response = await self._stub.GetMicrogridDispatch(
270+
request, metadata=self._metadata
271+
) # type: ignore
262272
return Dispatch.from_protobuf(response)
263273

264274
async def delete(self, dispatch_id: int) -> None:
@@ -268,7 +278,9 @@ async def delete(self, dispatch_id: int) -> None:
268278
dispatch_id: The dispatch_id to delete.
269279
"""
270280
request = DispatchDeleteRequest(id=dispatch_id)
271-
await self._stub.DeleteMicrogridDispatch(request) # type: ignore
281+
await self._stub.DeleteMicrogridDispatch(
282+
request, metadata=self._metadata
283+
) # type: ignore
272284

273285
async def _try_fetch_created_dispatch(
274286
self,

src/frequenz/client/dispatch/test/_service.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,59 @@ class FakeService:
4949
_shuffle_after_create: bool = False
5050
"""Whether to shuffle the dispatches after creating them."""
5151

52+
def _check_access(self, metadata: grpc.aio.Metadata) -> None:
53+
"""Check if the access key is valid.
54+
55+
Args:
56+
metadata: The metadata.
57+
58+
Raises:
59+
grpc.RpcError: If the access key is invalid.
60+
"""
61+
# metadata is a weird tuple of tuples, we don't like it
62+
metadata_dict = dict(metadata)
63+
64+
if "key" not in metadata_dict:
65+
raise grpc.RpcError(
66+
grpc.StatusCode.UNAUTHENTICATED,
67+
"No access key provided",
68+
)
69+
70+
key = metadata_dict["key"]
71+
72+
if key is None:
73+
raise grpc.RpcError(
74+
grpc.StatusCode.UNAUTHENTICATED,
75+
"No access key provided",
76+
)
77+
78+
if key == "none":
79+
raise grpc.RpcError(
80+
grpc.StatusCode.PERMISSION_DENIED,
81+
"Permission denied",
82+
)
83+
84+
if key != "all":
85+
raise grpc.RpcError(
86+
grpc.StatusCode.UNAUTHENTICATED,
87+
"Invalid access key",
88+
)
89+
5290
# pylint: disable=invalid-name
5391
async def ListMicrogridDispatches(
54-
self, request: PBDispatchListRequest
92+
self, request: PBDispatchListRequest, metadata: grpc.aio.Metadata
5593
) -> DispatchList:
5694
"""List microgrid dispatches.
5795
5896
Args:
5997
request: The request.
98+
metadata: The metadata.
6099
61100
Returns:
62101
The dispatch list.
63102
"""
103+
self._check_access(metadata)
104+
64105
return DispatchList(
65106
dispatches=map(
66107
lambda d: d.to_protobuf(),
@@ -108,8 +149,10 @@ def _filter_dispatch(dispatch: Dispatch, request: PBDispatchListRequest) -> bool
108149
async def CreateMicrogridDispatch(
109150
self,
110151
request: PBDispatchCreateRequest,
152+
metadata: grpc.aio.Metadata,
111153
) -> Empty:
112154
"""Create a new dispatch."""
155+
self._check_access(metadata)
113156
self._last_id += 1
114157

115158
self.dispatches.append(
@@ -128,8 +171,10 @@ async def CreateMicrogridDispatch(
128171
async def UpdateMicrogridDispatch(
129172
self,
130173
request: DispatchUpdateRequest,
174+
metadata: grpc.aio.Metadata,
131175
) -> Empty:
132176
"""Update a dispatch."""
177+
self._check_access(metadata)
133178
index = next(
134179
(i for i, d in enumerate(self.dispatches) if d.id == request.id),
135180
None,
@@ -194,8 +239,10 @@ async def UpdateMicrogridDispatch(
194239
async def GetMicrogridDispatch(
195240
self,
196241
request: DispatchGetRequest,
242+
metadata: grpc.aio.Metadata,
197243
) -> PBDispatch:
198244
"""Get a single dispatch."""
245+
self._check_access(metadata)
199246
dispatch = next((d for d in self.dispatches if d.id == request.id), None)
200247

201248
if dispatch is None:
@@ -211,8 +258,10 @@ async def GetMicrogridDispatch(
211258
async def DeleteMicrogridDispatch(
212259
self,
213260
request: DispatchDeleteRequest,
261+
metadata: grpc.aio.Metadata,
214262
) -> Empty:
215263
"""Delete a given dispatch."""
264+
self._check_access(metadata)
216265
num_dispatches = len(self.dispatches)
217266
self.dispatches = [d for d in self.dispatches if d.id != request.id]
218267

src/frequenz/client/dispatch/test/client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@ class FakeClient(Client):
1717
This client uses a fake service to simulate the dispatch api.
1818
"""
1919

20-
def __init__(self, shuffle_after_create: bool = False) -> None:
20+
def __init__(
21+
self,
22+
shuffle_after_create: bool = False,
23+
) -> None:
2124
"""Initialize the mock client.
2225
2326
Args:
2427
shuffle_after_create: Whether to shuffle the dispatches after creating them.
2528
"""
26-
super().__init__(MagicMock(), "mock")
29+
super().__init__(MagicMock(), "mock", "all")
2730
self._stub = FakeService() # type: ignore
2831
self._service._shuffle_after_create = shuffle_after_create
2932

tests/test_dispatch_cli.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
TEST_NOW = datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
2727
"""Arbitrary time used as NOW for testing."""
2828

29+
ENVIRONMENT_VARIABLES = {"DISPATCH_API_KEY": "all"}
30+
2931

3032
@pytest.fixture
3133
def runner() -> CliRunner:
@@ -148,9 +150,11 @@ async def test_list_command( # pylint: disable=too-many-arguments
148150
) -> None:
149151
"""Test the list command."""
150152
fake_client.dispatches = dispatches
151-
result = await runner.invoke(cli, ["list", str(microgrid_id)])
152-
assert result.exit_code == expected_return_code
153+
result = await runner.invoke(
154+
cli, ["list", str(microgrid_id)], env=ENVIRONMENT_VARIABLES
155+
)
153156
assert expected_output in result.output
157+
assert result.exit_code == expected_return_code
154158

155159

156160
@pytest.mark.asyncio
@@ -311,7 +315,7 @@ async def test_create_command( # pylint: disable=too-many-arguments,too-many-lo
311315
expected_return_code: int,
312316
) -> None:
313317
"""Test the create command."""
314-
result = await runner.invoke(cli, args)
318+
result = await runner.invoke(cli, args, env=ENVIRONMENT_VARIABLES)
315319
now = datetime.now(get_localzone())
316320

317321
if (
@@ -501,7 +505,7 @@ async def test_update_command( # pylint: disable=too-many-arguments
501505
) -> None:
502506
"""Test the update command."""
503507
fake_client.dispatches = dispatches
504-
result = await runner.invoke(cli, ["update", "1", *args])
508+
result = await runner.invoke(cli, ["update", "1", *args], env=ENVIRONMENT_VARIABLES)
505509
assert expected_output in result.output
506510
assert result.exit_code == expected_return_code
507511
if dispatches:
@@ -551,7 +555,9 @@ async def test_get_command(
551555
) -> None:
552556
"""Test the get command."""
553557
fake_client.dispatches = dispatches
554-
result = await runner.invoke(cli, ["get", str(dispatch_id)])
558+
result = await runner.invoke(
559+
cli, ["get", str(dispatch_id)], env=ENVIRONMENT_VARIABLES
560+
)
555561
assert result.exit_code == 0 if dispatches else 1
556562
assert expected_in_output in result.output
557563

@@ -600,7 +606,9 @@ async def test_delete_command( # pylint: disable=too-many-arguments
600606
) -> None:
601607
"""Test the delete command."""
602608
fake_client.dispatches = dispatches
603-
result = await runner.invoke(cli, ["delete", str(dispatch_id)])
609+
result = await runner.invoke(
610+
cli, ["delete", str(dispatch_id)], env=ENVIRONMENT_VARIABLES
611+
)
604612
assert result.exit_code == expected_return_code
605613
assert expected_output in result.output
606614
if dispatches:

0 commit comments

Comments
 (0)