|
5 | 5 |
|
6 | 6 | Useful for testing. |
7 | 7 | """ |
8 | | -import dataclasses |
| 8 | +import logging |
9 | 9 | from dataclasses import dataclass, replace |
10 | 10 | from datetime import datetime, timezone |
| 11 | +from typing import AsyncIterator |
11 | 12 |
|
12 | 13 | import grpc |
13 | 14 | import grpc.aio |
|
24 | 25 | GetMicrogridDispatchResponse, |
25 | 26 | ListMicrogridDispatchesRequest, |
26 | 27 | ListMicrogridDispatchesResponse, |
| 28 | + StreamMicrogridDispatchesRequest, |
| 29 | + StreamMicrogridDispatchesResponse, |
27 | 30 | UpdateMicrogridDispatchRequest, |
28 | 31 | UpdateMicrogridDispatchResponse, |
29 | 32 | ) |
| 33 | +from frequenz.channels import Broadcast |
30 | 34 | from google.protobuf.empty_pb2 import Empty |
31 | 35 |
|
32 | 36 | # pylint: enable=no-name-in-module |
33 | 37 | from frequenz.client.base.conversion import to_datetime as _to_dt |
34 | 38 |
|
35 | 39 | from .._internal_types import DispatchCreateRequest |
36 | | -from ..types import Dispatch |
| 40 | +from ..types import Dispatch, DispatchEvent, Event |
37 | 41 |
|
38 | 42 | ALL_KEY = "all" |
39 | 43 | """Key that has access to all resources in the FakeService.""" |
|
42 | 46 | """Key that has no access to any resources in the FakeService.""" |
43 | 47 |
|
44 | 48 |
|
45 | | -@dataclass |
46 | 49 | class FakeService: |
47 | 50 | """Dispatch mock service for testing.""" |
48 | 51 |
|
49 | | - dispatches: dict[int, list[Dispatch]] = dataclasses.field(default_factory=dict) |
50 | | - """List of dispatches per microgrid.""" |
| 52 | + @dataclass(frozen=True) |
| 53 | + class StreamEvent: |
| 54 | + """Event for the stream.""" |
51 | 55 |
|
52 | | - _last_id: int = 0 |
53 | | - """Last used dispatch id.""" |
| 56 | + microgrid_id: int |
| 57 | + """The microgrid id.""" |
| 58 | + |
| 59 | + event: DispatchEvent |
| 60 | + """The event.""" |
| 61 | + |
| 62 | + def __init__(self) -> None: |
| 63 | + """Initialize the stream sender.""" |
| 64 | + self._stream_channel: Broadcast[FakeService.StreamEvent] = Broadcast( |
| 65 | + name="fakeservice-dispatch-stream" |
| 66 | + ) |
| 67 | + self._stream_sender = self._stream_channel.new_sender() |
| 68 | + |
| 69 | + self.dispatches: dict[int, list[Dispatch]] = {} |
| 70 | + """List of dispatches per microgrid.""" |
| 71 | + |
| 72 | + self._last_id: int = 0 |
| 73 | + """Last used dispatch id.""" |
54 | 74 |
|
55 | 75 | def _check_access(self, metadata: grpc.aio.Metadata) -> None: |
56 | 76 | """Check if the access key is valid. |
@@ -120,6 +140,34 @@ async def ListMicrogridDispatches( |
120 | 140 | ), |
121 | 141 | ) |
122 | 142 |
|
| 143 | + async def StreamMicrogridDispatches( |
| 144 | + self, request: StreamMicrogridDispatchesRequest, metadata: grpc.aio.Metadata |
| 145 | + ) -> AsyncIterator[StreamMicrogridDispatchesResponse]: |
| 146 | + """Stream microgrid dispatches changes. |
| 147 | +
|
| 148 | + Args: |
| 149 | + request: The request. |
| 150 | + metadata: The metadata. |
| 151 | +
|
| 152 | + Returns: |
| 153 | + An async generator for dispatch changes. |
| 154 | +
|
| 155 | + Yields: |
| 156 | + An event for each dispatch change. |
| 157 | + """ |
| 158 | + self._check_access(metadata) |
| 159 | + |
| 160 | + receiver = self._stream_channel.new_receiver() |
| 161 | + |
| 162 | + async for message in receiver: |
| 163 | + logging.debug("Received message: %s", message) |
| 164 | + if message.microgrid_id == request.microgrid_id: |
| 165 | + response = StreamMicrogridDispatchesResponse( |
| 166 | + event=message.event.event.value, |
| 167 | + dispatch=message.event.dispatch.to_protobuf(), |
| 168 | + ) |
| 169 | + yield response |
| 170 | + |
123 | 171 | # pylint: disable=too-many-branches |
124 | 172 | @staticmethod |
125 | 173 | def _filter_dispatch( |
@@ -179,6 +227,13 @@ async def CreateMicrogridDispatch( |
179 | 227 | # implicitly create the list if it doesn't exist |
180 | 228 | self.dispatches.setdefault(request.microgrid_id, []).append(new_dispatch) |
181 | 229 |
|
| 230 | + await self._stream_sender.send( |
| 231 | + self.StreamEvent( |
| 232 | + request.microgrid_id, |
| 233 | + DispatchEvent(dispatch=new_dispatch, event=Event.CREATED), |
| 234 | + ) |
| 235 | + ) |
| 236 | + |
182 | 237 | return CreateMicrogridDispatchResponse(dispatch=new_dispatch.to_protobuf()) |
183 | 238 |
|
184 | 239 | async def UpdateMicrogridDispatch( |
@@ -253,6 +308,13 @@ async def UpdateMicrogridDispatch( |
253 | 308 |
|
254 | 309 | grid_dispatches[index] = dispatch |
255 | 310 |
|
| 311 | + await self._stream_sender.send( |
| 312 | + self.StreamEvent( |
| 313 | + request.microgrid_id, |
| 314 | + DispatchEvent(dispatch=dispatch, event=Event.UPDATED), |
| 315 | + ) |
| 316 | + ) |
| 317 | + |
256 | 318 | return UpdateMicrogridDispatchResponse(dispatch=dispatch.to_protobuf()) |
257 | 319 |
|
258 | 320 | async def GetMicrogridDispatch( |
@@ -285,19 +347,31 @@ async def DeleteMicrogridDispatch( |
285 | 347 | """Delete a given dispatch.""" |
286 | 348 | self._check_access(metadata) |
287 | 349 | grid_dispatches = self.dispatches.get(request.microgrid_id, []) |
288 | | - num_dispatches = len(grid_dispatches) |
289 | | - self.dispatches[request.microgrid_id] = [ |
290 | | - d for d in grid_dispatches if d.id != request.dispatch_id |
291 | | - ] |
292 | 350 |
|
293 | | - if len(self.dispatches[request.microgrid_id]) == num_dispatches: |
| 351 | + dispatch_to_delete = next( |
| 352 | + (d for d in grid_dispatches if d.id == request.dispatch_id), None |
| 353 | + ) |
| 354 | + |
| 355 | + if dispatch_to_delete is None: |
294 | 356 | error = grpc.RpcError() |
295 | 357 | # pylint: disable=protected-access |
296 | 358 | error._code = grpc.StatusCode.NOT_FOUND # type: ignore |
297 | 359 | error._details = "Dispatch not found" # type: ignore |
298 | 360 | # pylint: enable=protected-access |
299 | 361 | raise error |
300 | 362 |
|
| 363 | + grid_dispatches.remove(dispatch_to_delete) |
| 364 | + |
| 365 | + await self._stream_sender.send( |
| 366 | + self.StreamEvent( |
| 367 | + request.microgrid_id, |
| 368 | + DispatchEvent( |
| 369 | + dispatch=dispatch_to_delete, |
| 370 | + event=Event.DELETED, |
| 371 | + ), |
| 372 | + ) |
| 373 | + ) |
| 374 | + |
301 | 375 | return Empty() |
302 | 376 |
|
303 | 377 | # pylint: enable=invalid-name |
|
0 commit comments