Skip to content

Commit aec339d

Browse files
authored
Implement stream() for FakeService (#87)
2 parents 24a8e3a + eb7e8bc commit aec339d

File tree

4 files changed

+142
-21
lines changed

4 files changed

+142
-21
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
## New Features
1212

1313
* Added support for duration=None when creating a dispatch.
14+
* The `FakeService` now supports the `stream()` method.
1415

1516
## Bug Fixes
1617

src/frequenz/client/dispatch/test/_service.py

Lines changed: 91 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
66
Useful for testing.
77
"""
8-
import dataclasses
8+
import logging
99
from dataclasses import dataclass, replace
1010
from datetime import datetime, timezone
11+
from typing import AsyncIterator
1112

1213
import grpc
1314
import grpc.aio
@@ -22,22 +23,21 @@
2223
DeleteMicrogridDispatchRequest,
2324
GetMicrogridDispatchRequest,
2425
GetMicrogridDispatchResponse,
25-
)
26-
from frequenz.api.dispatch.v1.dispatch_pb2 import (
27-
ListMicrogridDispatchesRequest as PBDispatchListRequest,
28-
)
29-
from frequenz.api.dispatch.v1.dispatch_pb2 import (
26+
ListMicrogridDispatchesRequest,
3027
ListMicrogridDispatchesResponse,
28+
StreamMicrogridDispatchesRequest,
29+
StreamMicrogridDispatchesResponse,
3130
UpdateMicrogridDispatchRequest,
3231
UpdateMicrogridDispatchResponse,
3332
)
33+
from frequenz.channels import Broadcast
3434
from google.protobuf.empty_pb2 import Empty
3535

3636
# pylint: enable=no-name-in-module
3737
from frequenz.client.base.conversion import to_datetime as _to_dt
3838

3939
from .._internal_types import DispatchCreateRequest
40-
from ..types import Dispatch
40+
from ..types import Dispatch, DispatchEvent, Event
4141

4242
ALL_KEY = "all"
4343
"""Key that has access to all resources in the FakeService."""
@@ -46,15 +46,31 @@
4646
"""Key that has no access to any resources in the FakeService."""
4747

4848

49-
@dataclass
5049
class FakeService:
5150
"""Dispatch mock service for testing."""
5251

53-
dispatches: dict[int, list[Dispatch]] = dataclasses.field(default_factory=dict)
54-
"""List of dispatches per microgrid."""
52+
@dataclass(frozen=True)
53+
class StreamEvent:
54+
"""Event for the stream."""
55+
56+
microgrid_id: int
57+
"""The microgrid id."""
58+
59+
event: DispatchEvent
60+
"""The event."""
61+
62+
def __init__(self) -> None:
63+
"""Initialize the stream sender."""
64+
self._stream_channel: Broadcast[FakeService.StreamEvent] = Broadcast(
65+
name="fakeservice-dispatch-stream"
66+
)
67+
self._stream_sender = self._stream_channel.new_sender()
68+
69+
self.dispatches: dict[int, list[Dispatch]] = {}
70+
"""List of dispatches per microgrid."""
5571

56-
_last_id: int = 0
57-
"""Last used dispatch id."""
72+
self._last_id: int = 0
73+
"""Last used dispatch id."""
5874

5975
def _check_access(self, metadata: grpc.aio.Metadata) -> None:
6076
"""Check if the access key is valid.
@@ -96,7 +112,7 @@ def _check_access(self, metadata: grpc.aio.Metadata) -> None:
96112

97113
# pylint: disable=invalid-name
98114
async def ListMicrogridDispatches(
99-
self, request: PBDispatchListRequest, metadata: grpc.aio.Metadata
115+
self, request: ListMicrogridDispatchesRequest, metadata: grpc.aio.Metadata
100116
) -> ListMicrogridDispatchesResponse:
101117
"""List microgrid dispatches.
102118
@@ -124,9 +140,39 @@ async def ListMicrogridDispatches(
124140
),
125141
)
126142

143+
async def StreamMicrogridDispatches(
144+
self, request: StreamMicrogridDispatchesRequest, metadata: grpc.aio.Metadata
145+
) -> AsyncIterator[StreamMicrogridDispatchesResponse]:
146+
"""Stream microgrid dispatches changes.
147+
148+
Args:
149+
request: The request.
150+
metadata: The metadata.
151+
152+
Returns:
153+
An async generator for dispatch changes.
154+
155+
Yields:
156+
An event for each dispatch change.
157+
"""
158+
self._check_access(metadata)
159+
160+
receiver = self._stream_channel.new_receiver()
161+
162+
async for message in receiver:
163+
logging.debug("Received message: %s", message)
164+
if message.microgrid_id == request.microgrid_id:
165+
response = StreamMicrogridDispatchesResponse(
166+
event=message.event.event.value,
167+
dispatch=message.event.dispatch.to_protobuf(),
168+
)
169+
yield response
170+
127171
# pylint: disable=too-many-branches
128172
@staticmethod
129-
def _filter_dispatch(dispatch: Dispatch, request: PBDispatchListRequest) -> bool:
173+
def _filter_dispatch(
174+
dispatch: Dispatch, request: ListMicrogridDispatchesRequest
175+
) -> bool:
130176
"""Filter a dispatch based on the request."""
131177
if request.HasField("filter"):
132178
_filter = request.filter
@@ -181,6 +227,13 @@ async def CreateMicrogridDispatch(
181227
# implicitly create the list if it doesn't exist
182228
self.dispatches.setdefault(request.microgrid_id, []).append(new_dispatch)
183229

230+
await self._stream_sender.send(
231+
self.StreamEvent(
232+
request.microgrid_id,
233+
DispatchEvent(dispatch=new_dispatch, event=Event.CREATED),
234+
)
235+
)
236+
184237
return CreateMicrogridDispatchResponse(dispatch=new_dispatch.to_protobuf())
185238

186239
async def UpdateMicrogridDispatch(
@@ -255,6 +308,13 @@ async def UpdateMicrogridDispatch(
255308

256309
grid_dispatches[index] = dispatch
257310

311+
await self._stream_sender.send(
312+
self.StreamEvent(
313+
request.microgrid_id,
314+
DispatchEvent(dispatch=dispatch, event=Event.UPDATED),
315+
)
316+
)
317+
258318
return UpdateMicrogridDispatchResponse(dispatch=dispatch.to_protobuf())
259319

260320
async def GetMicrogridDispatch(
@@ -287,19 +347,31 @@ async def DeleteMicrogridDispatch(
287347
"""Delete a given dispatch."""
288348
self._check_access(metadata)
289349
grid_dispatches = self.dispatches.get(request.microgrid_id, [])
290-
num_dispatches = len(grid_dispatches)
291-
self.dispatches[request.microgrid_id] = [
292-
d for d in grid_dispatches if d.id != request.dispatch_id
293-
]
294350

295-
if len(self.dispatches[request.microgrid_id]) == num_dispatches:
351+
dispatch_to_delete = next(
352+
(d for d in grid_dispatches if d.id == request.dispatch_id), None
353+
)
354+
355+
if dispatch_to_delete is None:
296356
error = grpc.RpcError()
297357
# pylint: disable=protected-access
298358
error._code = grpc.StatusCode.NOT_FOUND # type: ignore
299359
error._details = "Dispatch not found" # type: ignore
300360
# pylint: enable=protected-access
301361
raise error
302362

363+
grid_dispatches.remove(dispatch_to_delete)
364+
365+
await self._stream_sender.send(
366+
self.StreamEvent(
367+
request.microgrid_id,
368+
DispatchEvent(
369+
dispatch=dispatch_to_delete,
370+
event=Event.DELETED,
371+
),
372+
)
373+
)
374+
303375
return Empty()
304376

305377
# pylint: enable=invalid-name

src/frequenz/client/dispatch/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def to_protobuf(self) -> PBDispatch:
357357
is_active=self.active,
358358
is_dry_run=self.dry_run,
359359
payload=payload,
360-
recurrence=self.recurrence.to_protobuf(),
360+
recurrence=self.recurrence.to_protobuf() if self.recurrence else None,
361361
),
362362
)
363363

tests/test_dispatch_client.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
"""Tests for the frequenz.client.dispatch package."""
55

6+
import asyncio
67
import random
78
from dataclasses import replace
89
from datetime import timedelta
@@ -13,7 +14,7 @@
1314
from frequenz.client.dispatch.test.client import FakeClient, to_create_params
1415
from frequenz.client.dispatch.test.fixtures import client, generator, sample
1516
from frequenz.client.dispatch.test.generator import DispatchGenerator
16-
from frequenz.client.dispatch.types import Dispatch
17+
from frequenz.client.dispatch.types import Dispatch, Event
1718

1819
# Ignore flake8 error in the rest of the file to use the same fixture names
1920
# flake8: noqa[811]
@@ -261,3 +262,50 @@ async def test_delete_dispatch_fail(client: FakeClient) -> None:
261262
"""Test deleting a non-existent dispatch."""
262263
with raises(grpc.RpcError):
263264
await client.delete(microgrid_id=1, dispatch_id=1)
265+
266+
267+
async def test_dispatch_stream(client: FakeClient, sample: Dispatch) -> None:
268+
"""Test dispatching a stream of dispatches."""
269+
microgrid_id = random.randint(1, 100)
270+
dispatches = [sample, sample, sample]
271+
272+
stream = client.stream(microgrid_id)
273+
274+
async def expect(dispatch: Dispatch, event: Event) -> None:
275+
message = await stream.receive()
276+
assert message.dispatch == dispatch
277+
assert message.event == event
278+
279+
# Give stream some time to start
280+
await asyncio.sleep(0.1)
281+
282+
# Add a new dispatch
283+
dispatches[0] = await client.create(**to_create_params(microgrid_id, dispatches[0]))
284+
# Expect the first dispatch event
285+
await expect(dispatches[0], Event.CREATED)
286+
287+
# Add a new dispatch
288+
dispatches[1] = await client.create(**to_create_params(microgrid_id, dispatches[1]))
289+
# Expect the second dispatch
290+
await expect(dispatches[1], Event.CREATED)
291+
292+
# Add a new dispatch
293+
dispatches[2] = await client.create(**to_create_params(microgrid_id, dispatches[2]))
294+
# Expect the third dispatch
295+
await expect(dispatches[2], Event.CREATED)
296+
297+
# Update the first dispatch
298+
dispatches[0] = await client.update(
299+
microgrid_id=microgrid_id,
300+
dispatch_id=dispatches[0].id,
301+
new_fields={"start_time": dispatches[0].start_time + timedelta(minutes=1)},
302+
)
303+
304+
# Expect the first dispatch update
305+
await expect(dispatches[0], Event.UPDATED)
306+
307+
# Delete the first dispatch
308+
await client.delete(microgrid_id=microgrid_id, dispatch_id=dispatches[0].id)
309+
310+
# Expect the first dispatch deletion
311+
await expect(dispatches[0], Event.DELETED)

0 commit comments

Comments
 (0)