|
5 | 5 |
|
6 | 6 | Useful for testing. |
7 | 7 | """ |
8 | | -import dataclasses |
| 8 | +import logging |
9 | 9 | from dataclasses import dataclass, replace |
10 | 10 | from datetime import datetime, timezone |
| 11 | +from typing import AsyncIterator |
11 | 12 |
|
12 | 13 | import grpc |
13 | 14 | import grpc.aio |
|
22 | 23 | DeleteMicrogridDispatchRequest, |
23 | 24 | GetMicrogridDispatchRequest, |
24 | 25 | GetMicrogridDispatchResponse, |
25 | | -) |
26 | | -from frequenz.api.dispatch.v1.dispatch_pb2 import ( |
27 | | - ListMicrogridDispatchesRequest as PBDispatchListRequest, |
28 | | -) |
29 | | -from frequenz.api.dispatch.v1.dispatch_pb2 import ( |
| 26 | + ListMicrogridDispatchesRequest, |
30 | 27 | ListMicrogridDispatchesResponse, |
| 28 | + StreamMicrogridDispatchesRequest, |
| 29 | + StreamMicrogridDispatchesResponse, |
31 | 30 | UpdateMicrogridDispatchRequest, |
32 | 31 | UpdateMicrogridDispatchResponse, |
33 | 32 | ) |
| 33 | +from frequenz.channels import Broadcast |
34 | 34 | from google.protobuf.empty_pb2 import Empty |
35 | 35 |
|
36 | 36 | # pylint: enable=no-name-in-module |
37 | 37 | from frequenz.client.base.conversion import to_datetime as _to_dt |
38 | 38 |
|
39 | 39 | from .._internal_types import DispatchCreateRequest |
40 | | -from ..types import Dispatch |
| 40 | +from ..types import Dispatch, DispatchEvent, Event |
41 | 41 |
|
42 | 42 | ALL_KEY = "all" |
43 | 43 | """Key that has access to all resources in the FakeService.""" |
|
46 | 46 | """Key that has no access to any resources in the FakeService.""" |
47 | 47 |
|
48 | 48 |
|
49 | | -@dataclass |
50 | 49 | class FakeService: |
51 | 50 | """Dispatch mock service for testing.""" |
52 | 51 |
|
53 | | - dispatches: dict[int, list[Dispatch]] = dataclasses.field(default_factory=dict) |
54 | | - """List of dispatches per microgrid.""" |
| 52 | + @dataclass(frozen=True) |
| 53 | + class StreamEvent: |
| 54 | + """Event for the stream.""" |
| 55 | + |
| 56 | + microgrid_id: int |
| 57 | + """The microgrid id.""" |
| 58 | + |
| 59 | + event: DispatchEvent |
| 60 | + """The event.""" |
| 61 | + |
| 62 | + def __init__(self) -> None: |
| 63 | + """Initialize the stream sender.""" |
| 64 | + self._stream_channel: Broadcast[FakeService.StreamEvent] = Broadcast( |
| 65 | + name="fakeservice-dispatch-stream" |
| 66 | + ) |
| 67 | + self._stream_sender = self._stream_channel.new_sender() |
| 68 | + |
| 69 | + self.dispatches: dict[int, list[Dispatch]] = {} |
| 70 | + """List of dispatches per microgrid.""" |
55 | 71 |
|
56 | | - _last_id: int = 0 |
57 | | - """Last used dispatch id.""" |
| 72 | + self._last_id: int = 0 |
| 73 | + """Last used dispatch id.""" |
58 | 74 |
|
59 | 75 | def _check_access(self, metadata: grpc.aio.Metadata) -> None: |
60 | 76 | """Check if the access key is valid. |
@@ -96,7 +112,7 @@ def _check_access(self, metadata: grpc.aio.Metadata) -> None: |
96 | 112 |
|
97 | 113 | # pylint: disable=invalid-name |
98 | 114 | async def ListMicrogridDispatches( |
99 | | - self, request: PBDispatchListRequest, metadata: grpc.aio.Metadata |
| 115 | + self, request: ListMicrogridDispatchesRequest, metadata: grpc.aio.Metadata |
100 | 116 | ) -> ListMicrogridDispatchesResponse: |
101 | 117 | """List microgrid dispatches. |
102 | 118 |
|
@@ -124,9 +140,39 @@ async def ListMicrogridDispatches( |
124 | 140 | ), |
125 | 141 | ) |
126 | 142 |
|
| 143 | + async def StreamMicrogridDispatches( |
| 144 | + self, request: StreamMicrogridDispatchesRequest, metadata: grpc.aio.Metadata |
| 145 | + ) -> AsyncIterator[StreamMicrogridDispatchesResponse]: |
| 146 | + """Stream microgrid dispatches changes. |
| 147 | +
|
| 148 | + Args: |
| 149 | + request: The request. |
| 150 | + metadata: The metadata. |
| 151 | +
|
| 152 | + Returns: |
| 153 | + An async generator for dispatch changes. |
| 154 | +
|
| 155 | + Yields: |
| 156 | + An event for each dispatch change. |
| 157 | + """ |
| 158 | + self._check_access(metadata) |
| 159 | + |
| 160 | + receiver = self._stream_channel.new_receiver() |
| 161 | + |
| 162 | + async for message in receiver: |
| 163 | + logging.debug("Received message: %s", message) |
| 164 | + if message.microgrid_id == request.microgrid_id: |
| 165 | + response = StreamMicrogridDispatchesResponse( |
| 166 | + event=message.event.event.value, |
| 167 | + dispatch=message.event.dispatch.to_protobuf(), |
| 168 | + ) |
| 169 | + yield response |
| 170 | + |
127 | 171 | # pylint: disable=too-many-branches |
128 | 172 | @staticmethod |
129 | | - def _filter_dispatch(dispatch: Dispatch, request: PBDispatchListRequest) -> bool: |
| 173 | + def _filter_dispatch( |
| 174 | + dispatch: Dispatch, request: ListMicrogridDispatchesRequest |
| 175 | + ) -> bool: |
130 | 176 | """Filter a dispatch based on the request.""" |
131 | 177 | if request.HasField("filter"): |
132 | 178 | _filter = request.filter |
@@ -181,6 +227,13 @@ async def CreateMicrogridDispatch( |
181 | 227 | # implicitly create the list if it doesn't exist |
182 | 228 | self.dispatches.setdefault(request.microgrid_id, []).append(new_dispatch) |
183 | 229 |
|
| 230 | + await self._stream_sender.send( |
| 231 | + self.StreamEvent( |
| 232 | + request.microgrid_id, |
| 233 | + DispatchEvent(dispatch=new_dispatch, event=Event.CREATED), |
| 234 | + ) |
| 235 | + ) |
| 236 | + |
184 | 237 | return CreateMicrogridDispatchResponse(dispatch=new_dispatch.to_protobuf()) |
185 | 238 |
|
186 | 239 | async def UpdateMicrogridDispatch( |
@@ -255,6 +308,13 @@ async def UpdateMicrogridDispatch( |
255 | 308 |
|
256 | 309 | grid_dispatches[index] = dispatch |
257 | 310 |
|
| 311 | + await self._stream_sender.send( |
| 312 | + self.StreamEvent( |
| 313 | + request.microgrid_id, |
| 314 | + DispatchEvent(dispatch=dispatch, event=Event.UPDATED), |
| 315 | + ) |
| 316 | + ) |
| 317 | + |
258 | 318 | return UpdateMicrogridDispatchResponse(dispatch=dispatch.to_protobuf()) |
259 | 319 |
|
260 | 320 | async def GetMicrogridDispatch( |
@@ -287,19 +347,31 @@ async def DeleteMicrogridDispatch( |
287 | 347 | """Delete a given dispatch.""" |
288 | 348 | self._check_access(metadata) |
289 | 349 | grid_dispatches = self.dispatches.get(request.microgrid_id, []) |
290 | | - num_dispatches = len(grid_dispatches) |
291 | | - self.dispatches[request.microgrid_id] = [ |
292 | | - d for d in grid_dispatches if d.id != request.dispatch_id |
293 | | - ] |
294 | 350 |
|
295 | | - if len(self.dispatches[request.microgrid_id]) == num_dispatches: |
| 351 | + dispatch_to_delete = next( |
| 352 | + (d for d in grid_dispatches if d.id == request.dispatch_id), None |
| 353 | + ) |
| 354 | + |
| 355 | + if dispatch_to_delete is None: |
296 | 356 | error = grpc.RpcError() |
297 | 357 | # pylint: disable=protected-access |
298 | 358 | error._code = grpc.StatusCode.NOT_FOUND # type: ignore |
299 | 359 | error._details = "Dispatch not found" # type: ignore |
300 | 360 | # pylint: enable=protected-access |
301 | 361 | raise error |
302 | 362 |
|
| 363 | + grid_dispatches.remove(dispatch_to_delete) |
| 364 | + |
| 365 | + await self._stream_sender.send( |
| 366 | + self.StreamEvent( |
| 367 | + request.microgrid_id, |
| 368 | + DispatchEvent( |
| 369 | + dispatch=dispatch_to_delete, |
| 370 | + event=Event.DELETED, |
| 371 | + ), |
| 372 | + ) |
| 373 | + ) |
| 374 | + |
303 | 375 | return Empty() |
304 | 376 |
|
305 | 377 | # pylint: enable=invalid-name |
|
0 commit comments