diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index d80eb2ff..676fcdfc 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -11,6 +11,7 @@ ## New Features * Added support for duration=None when creating a dispatch. +* The `FakeService` now supports the `stream()` method. ## Bug Fixes diff --git a/src/frequenz/client/dispatch/test/_service.py b/src/frequenz/client/dispatch/test/_service.py index e7af1c4b..2a129bf5 100644 --- a/src/frequenz/client/dispatch/test/_service.py +++ b/src/frequenz/client/dispatch/test/_service.py @@ -5,9 +5,10 @@ Useful for testing. """ -import dataclasses +import logging from dataclasses import dataclass, replace from datetime import datetime, timezone +from typing import AsyncIterator import grpc import grpc.aio @@ -22,22 +23,21 @@ DeleteMicrogridDispatchRequest, GetMicrogridDispatchRequest, GetMicrogridDispatchResponse, -) -from frequenz.api.dispatch.v1.dispatch_pb2 import ( - ListMicrogridDispatchesRequest as PBDispatchListRequest, -) -from frequenz.api.dispatch.v1.dispatch_pb2 import ( + ListMicrogridDispatchesRequest, ListMicrogridDispatchesResponse, + StreamMicrogridDispatchesRequest, + StreamMicrogridDispatchesResponse, UpdateMicrogridDispatchRequest, UpdateMicrogridDispatchResponse, ) +from frequenz.channels import Broadcast from google.protobuf.empty_pb2 import Empty # pylint: enable=no-name-in-module from frequenz.client.base.conversion import to_datetime as _to_dt from .._internal_types import DispatchCreateRequest -from ..types import Dispatch +from ..types import Dispatch, DispatchEvent, Event ALL_KEY = "all" """Key that has access to all resources in the FakeService.""" @@ -46,15 +46,31 @@ """Key that has no access to any resources in the FakeService.""" -@dataclass class FakeService: """Dispatch mock service for testing.""" - dispatches: dict[int, list[Dispatch]] = dataclasses.field(default_factory=dict) - """List of dispatches per microgrid.""" + @dataclass(frozen=True) + class StreamEvent: + """Event for the stream.""" + + microgrid_id: int + """The microgrid id.""" + + event: DispatchEvent + """The event.""" + + def __init__(self) -> None: + """Initialize the stream sender.""" + self._stream_channel: Broadcast[FakeService.StreamEvent] = Broadcast( + name="fakeservice-dispatch-stream" + ) + self._stream_sender = self._stream_channel.new_sender() + + self.dispatches: dict[int, list[Dispatch]] = {} + """List of dispatches per microgrid.""" - _last_id: int = 0 - """Last used dispatch id.""" + self._last_id: int = 0 + """Last used dispatch id.""" def _check_access(self, metadata: grpc.aio.Metadata) -> None: """Check if the access key is valid. @@ -96,7 +112,7 @@ def _check_access(self, metadata: grpc.aio.Metadata) -> None: # pylint: disable=invalid-name async def ListMicrogridDispatches( - self, request: PBDispatchListRequest, metadata: grpc.aio.Metadata + self, request: ListMicrogridDispatchesRequest, metadata: grpc.aio.Metadata ) -> ListMicrogridDispatchesResponse: """List microgrid dispatches. @@ -124,9 +140,39 @@ async def ListMicrogridDispatches( ), ) + async def StreamMicrogridDispatches( + self, request: StreamMicrogridDispatchesRequest, metadata: grpc.aio.Metadata + ) -> AsyncIterator[StreamMicrogridDispatchesResponse]: + """Stream microgrid dispatches changes. + + Args: + request: The request. + metadata: The metadata. + + Returns: + An async generator for dispatch changes. + + Yields: + An event for each dispatch change. + """ + self._check_access(metadata) + + receiver = self._stream_channel.new_receiver() + + async for message in receiver: + logging.debug("Received message: %s", message) + if message.microgrid_id == request.microgrid_id: + response = StreamMicrogridDispatchesResponse( + event=message.event.event.value, + dispatch=message.event.dispatch.to_protobuf(), + ) + yield response + # pylint: disable=too-many-branches @staticmethod - def _filter_dispatch(dispatch: Dispatch, request: PBDispatchListRequest) -> bool: + def _filter_dispatch( + dispatch: Dispatch, request: ListMicrogridDispatchesRequest + ) -> bool: """Filter a dispatch based on the request.""" if request.HasField("filter"): _filter = request.filter @@ -181,6 +227,13 @@ async def CreateMicrogridDispatch( # implicitly create the list if it doesn't exist self.dispatches.setdefault(request.microgrid_id, []).append(new_dispatch) + await self._stream_sender.send( + self.StreamEvent( + request.microgrid_id, + DispatchEvent(dispatch=new_dispatch, event=Event.CREATED), + ) + ) + return CreateMicrogridDispatchResponse(dispatch=new_dispatch.to_protobuf()) async def UpdateMicrogridDispatch( @@ -255,6 +308,13 @@ async def UpdateMicrogridDispatch( grid_dispatches[index] = dispatch + await self._stream_sender.send( + self.StreamEvent( + request.microgrid_id, + DispatchEvent(dispatch=dispatch, event=Event.UPDATED), + ) + ) + return UpdateMicrogridDispatchResponse(dispatch=dispatch.to_protobuf()) async def GetMicrogridDispatch( @@ -287,12 +347,12 @@ async def DeleteMicrogridDispatch( """Delete a given dispatch.""" self._check_access(metadata) grid_dispatches = self.dispatches.get(request.microgrid_id, []) - num_dispatches = len(grid_dispatches) - self.dispatches[request.microgrid_id] = [ - d for d in grid_dispatches if d.id != request.dispatch_id - ] - if len(self.dispatches[request.microgrid_id]) == num_dispatches: + dispatch_to_delete = next( + (d for d in grid_dispatches if d.id == request.dispatch_id), None + ) + + if dispatch_to_delete is None: error = grpc.RpcError() # pylint: disable=protected-access error._code = grpc.StatusCode.NOT_FOUND # type: ignore @@ -300,6 +360,18 @@ async def DeleteMicrogridDispatch( # pylint: enable=protected-access raise error + grid_dispatches.remove(dispatch_to_delete) + + await self._stream_sender.send( + self.StreamEvent( + request.microgrid_id, + DispatchEvent( + dispatch=dispatch_to_delete, + event=Event.DELETED, + ), + ) + ) + return Empty() # pylint: enable=invalid-name diff --git a/src/frequenz/client/dispatch/types.py b/src/frequenz/client/dispatch/types.py index 031a27d4..c9ab30d2 100644 --- a/src/frequenz/client/dispatch/types.py +++ b/src/frequenz/client/dispatch/types.py @@ -357,7 +357,7 @@ def to_protobuf(self) -> PBDispatch: is_active=self.active, is_dry_run=self.dry_run, payload=payload, - recurrence=self.recurrence.to_protobuf(), + recurrence=self.recurrence.to_protobuf() if self.recurrence else None, ), ) diff --git a/tests/test_dispatch_client.py b/tests/test_dispatch_client.py index 13370b21..57c82c86 100644 --- a/tests/test_dispatch_client.py +++ b/tests/test_dispatch_client.py @@ -3,6 +3,7 @@ """Tests for the frequenz.client.dispatch package.""" +import asyncio import random from dataclasses import replace from datetime import timedelta @@ -13,7 +14,7 @@ from frequenz.client.dispatch.test.client import FakeClient, to_create_params from frequenz.client.dispatch.test.fixtures import client, generator, sample from frequenz.client.dispatch.test.generator import DispatchGenerator -from frequenz.client.dispatch.types import Dispatch +from frequenz.client.dispatch.types import Dispatch, Event # Ignore flake8 error in the rest of the file to use the same fixture names # flake8: noqa[811] @@ -261,3 +262,50 @@ async def test_delete_dispatch_fail(client: FakeClient) -> None: """Test deleting a non-existent dispatch.""" with raises(grpc.RpcError): await client.delete(microgrid_id=1, dispatch_id=1) + + +async def test_dispatch_stream(client: FakeClient, sample: Dispatch) -> None: + """Test dispatching a stream of dispatches.""" + microgrid_id = random.randint(1, 100) + dispatches = [sample, sample, sample] + + stream = client.stream(microgrid_id) + + async def expect(dispatch: Dispatch, event: Event) -> None: + message = await stream.receive() + assert message.dispatch == dispatch + assert message.event == event + + # Give stream some time to start + await asyncio.sleep(0.1) + + # Add a new dispatch + dispatches[0] = await client.create(**to_create_params(microgrid_id, dispatches[0])) + # Expect the first dispatch event + await expect(dispatches[0], Event.CREATED) + + # Add a new dispatch + dispatches[1] = await client.create(**to_create_params(microgrid_id, dispatches[1])) + # Expect the second dispatch + await expect(dispatches[1], Event.CREATED) + + # Add a new dispatch + dispatches[2] = await client.create(**to_create_params(microgrid_id, dispatches[2])) + # Expect the third dispatch + await expect(dispatches[2], Event.CREATED) + + # Update the first dispatch + dispatches[0] = await client.update( + microgrid_id=microgrid_id, + dispatch_id=dispatches[0].id, + new_fields={"start_time": dispatches[0].start_time + timedelta(minutes=1)}, + ) + + # Expect the first dispatch update + await expect(dispatches[0], Event.UPDATED) + + # Delete the first dispatch + await client.delete(microgrid_id=microgrid_id, dispatch_id=dispatches[0].id) + + # Expect the first dispatch deletion + await expect(dispatches[0], Event.DELETED)