diff --git a/pyproject.toml b/pyproject.toml index 4ac6216b..e204b2b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,8 +103,8 @@ dev-pytest = [ "pytest-mock == 3.14.1", "pytest-asyncio == 1.1.0", "async-solipsism == 0.8", - "time-machine == 2.16.0", - "hypothesis == 6.136.8", + "time-machine == 2.19.0", + "hypothesis == 6.138.13", "frequenz-client-dispatch[cli]", ] dev = [ diff --git a/src/frequenz/client/dispatch/_client.py b/src/frequenz/client/dispatch/_client.py index 04e59f96..e05f9fc0 100644 --- a/src/frequenz/client/dispatch/_client.py +++ b/src/frequenz/client/dispatch/_client.py @@ -270,12 +270,9 @@ def _get_stream( request = StreamMicrogridDispatchesRequest(microgrid_id=int(microgrid_id)) broadcaster = GrpcStreamBroadcaster( stream_name="StreamMicrogridDispatches", - stream_method=lambda: cast( - AsyncIterator[StreamMicrogridDispatchesResponse], - self.stub.StreamMicrogridDispatches( - request, - timeout=self._stream_timeout_seconds, - ), + stream_method=lambda: self.stub.StreamMicrogridDispatches( + request, + timeout=self._stream_timeout_seconds, ), transform=DispatchEvent.from_protobuf, retry_strategy=LinearBackoff(interval=1, limit=None), diff --git a/src/frequenz/client/dispatch/test/_service.py b/src/frequenz/client/dispatch/test/_service.py index 83a0a519..a7390725 100644 --- a/src/frequenz/client/dispatch/test/_service.py +++ b/src/frequenz/client/dispatch/test/_service.py @@ -8,7 +8,7 @@ import logging from dataclasses import dataclass, replace from datetime import datetime, timezone -from typing import AsyncIterator +from typing import AsyncIterator, TypeVar import grpc import grpc.aio @@ -44,6 +44,33 @@ _logger = logging.getLogger(__name__) +T = TypeVar("T") + + +class _MockStream(AsyncIterator[T]): + """A mock stream that wraps an async iterator and adds initial_metadata.""" + + def __init__(self, stream: AsyncIterator[T]) -> None: + """Initialize the mock stream. + + Args: + stream: The stream to wrap. + """ + self._iterator = stream.__aiter__() + + async def initial_metadata(self) -> None: + """Do nothing, just to mock the grpc call.""" + _logger.debug("Called initial_metadata()") + + def __aiter__(self) -> AsyncIterator[T]: + """Return the async iterator.""" + return self + + async def __anext__(self) -> T: + """Return the next item from the stream.""" + return await self._iterator.__anext__() + + class FakeService: """Dispatch mock service for testing.""" @@ -109,11 +136,11 @@ async def ListMicrogridDispatches( ), ) - async def StreamMicrogridDispatches( + def StreamMicrogridDispatches( self, request: StreamMicrogridDispatchesRequest, timeout: int = 5, # pylint: disable=unused-argument - ) -> AsyncIterator[StreamMicrogridDispatchesResponse]: + ) -> _MockStream[StreamMicrogridDispatchesResponse]: """Stream microgrid dispatches changes. Args: @@ -122,20 +149,28 @@ async def StreamMicrogridDispatches( Returns: An async generator for dispatch changes. - - Yields: - An event for each dispatch change. """ - receiver = self._stream_channel.new_receiver() - - async for message in receiver: - _logger.debug("Received message: %s", message) - if message.microgrid_id == MicrogridId(request.microgrid_id): - response = StreamMicrogridDispatchesResponse( - event=message.event.event.value, - dispatch=message.event.dispatch.to_protobuf(), - ) - yield response + + async def stream() -> AsyncIterator[StreamMicrogridDispatchesResponse]: + """Stream microgrid dispatches changes.""" + _logger.debug("Starting stream for microgrid %s", request.microgrid_id) + receiver = self._stream_channel.new_receiver() + + async for message in receiver: + _logger.debug("Received message: %s", message) + if message.microgrid_id == MicrogridId(request.microgrid_id): + response = StreamMicrogridDispatchesResponse( + event=message.event.event.value, + dispatch=message.event.dispatch.to_protobuf(), + ) + yield response + else: + _logger.debug( + "Skipping message for microgrid %s", + message.microgrid_id, + ) + + return _MockStream(stream()) # pylint: disable=too-many-branches @staticmethod