|
6 | 6 | import asyncio |
7 | 7 | import logging |
8 | 8 | from collections.abc import Callable |
9 | | -from typing import AsyncIterable, Generic, TypeVar |
| 9 | +from dataclasses import dataclass |
| 10 | +from datetime import timedelta |
| 11 | +from typing import AsyncIterable, Generic, TypeAlias, TypeVar |
10 | 12 |
|
11 | 13 | import grpc.aio |
12 | 14 |
|
|
24 | 26 | """The output type of the stream.""" |
25 | 27 |
|
26 | 28 |
|
| 29 | +@dataclass(frozen=True, kw_only=True) |
| 30 | +class StreamStartedEvent: |
| 31 | + """Event indicating that the stream has started.""" |
| 32 | + |
| 33 | + |
| 34 | +@dataclass(frozen=True, kw_only=True) |
| 35 | +class StreamStoppedEvent: |
| 36 | + """Event indicating that the stream has stopped.""" |
| 37 | + |
| 38 | + retry_time: timedelta | None = None |
| 39 | + """Time to wait before retrying the stream, if applicable.""" |
| 40 | + |
| 41 | + exception: Exception | None = None |
| 42 | + """The exception that caused the stream to stop, if any.""" |
| 43 | + |
| 44 | + |
| 45 | +StreamEvent: TypeAlias = StreamStartedEvent | StreamStoppedEvent |
| 46 | +"""Type alias for the events that can be sent over the stream.""" |
| 47 | + |
| 48 | + |
| 49 | +# Ignore D412: "No blank lines allowed between a section header and its content" |
| 50 | +# flake8: noqa: D412 |
27 | 51 | class GrpcStreamBroadcaster(Generic[InputT, OutputT]): |
28 | | - """Helper class to handle grpc streaming methods.""" |
| 52 | + """Helper class to handle grpc streaming methods. |
| 53 | +
|
| 54 | + This class handles the grpc streaming methods, automatically reconnecting |
| 55 | + when the connection is lost, and broadcasting the received messages to |
| 56 | + multiple receivers. |
| 57 | +
|
| 58 | + The stream is started when the class is initialized, and can be stopped |
| 59 | + with the `stop` method. New receivers can be created with the |
| 60 | + `new_receiver` method, which will receive the streamed messages. |
| 61 | +
|
| 62 | + Additionally to the transformed messages, the broadcaster will also send |
| 63 | + state change messages indicating whether the stream is connecting, |
| 64 | + connected, or disconnected. These messages can be used to monitor the |
| 65 | + state of the stream. |
| 66 | +
|
| 67 | + Example: |
| 68 | +
|
| 69 | + ```python |
| 70 | + from frequenz.client.base import GrpcStreamBroadcaster |
| 71 | +
|
| 72 | + def async_range() -> AsyncIterable[int]: |
| 73 | + yield from range(10) |
| 74 | +
|
| 75 | + streamer = GrpcStreamBroadcaster( |
| 76 | + stream_name="example_stream", |
| 77 | + stream_method=async_range, |
| 78 | + transform=lambda msg: msg, |
| 79 | + ) |
| 80 | +
|
| 81 | + recv = streamer.new_receiver() |
| 82 | +
|
| 83 | + for msg in recv: |
| 84 | + match msg: |
| 85 | + case StreamStartedEvent(): |
| 86 | + print("Stream started") |
| 87 | + case StreamStoppedEvent() as event: |
| 88 | + print(f"Stream stopped, reason {event.exception}, retry in {event.retry_time}") |
| 89 | + case int() as output: |
| 90 | + print(f"Received message: {output}") |
| 91 | + ``` |
| 92 | + """ |
29 | 93 |
|
30 | 94 | def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments |
31 | 95 | self, |
@@ -55,14 +119,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument |
55 | 119 | ) |
56 | 120 | self._retry_on_exhausted_stream = retry_on_exhausted_stream |
57 | 121 |
|
58 | | - self._channel: channels.Broadcast[OutputT] = channels.Broadcast( |
| 122 | + self._channel: channels.Broadcast[StreamEvent | OutputT] = channels.Broadcast( |
59 | 123 | name=f"GrpcStreamBroadcaster-{stream_name}" |
60 | 124 | ) |
61 | 125 | self._task = asyncio.create_task(self._run()) |
62 | 126 |
|
63 | 127 | def new_receiver( |
64 | 128 | self, maxsize: int = 50, warn_on_overflow: bool = True |
65 | | - ) -> channels.Receiver[OutputT]: |
| 129 | + ) -> channels.Receiver[StreamEvent | OutputT]: |
66 | 130 | """Create a new receiver for the stream. |
67 | 131 |
|
68 | 132 | Args: |
@@ -107,24 +171,28 @@ async def _run(self) -> None: |
107 | 171 | _logger.info("%s: starting to stream", self._stream_name) |
108 | 172 | try: |
109 | 173 | call = self._stream_method() |
| 174 | + await sender.send(StreamStartedEvent()) |
110 | 175 | async for msg in call: |
111 | 176 | await sender.send(self._transform(msg)) |
112 | 177 | except grpc.aio.AioRpcError as err: |
113 | 178 | error = err |
114 | | - except Exception as err: # pylint: disable=broad-except |
115 | | - _logger.exception( |
116 | | - "%s: raise an unexpected exception", |
117 | | - self._stream_name, |
| 179 | + |
| 180 | + interval = self._retry_strategy.next_interval() |
| 181 | + |
| 182 | + await sender.send( |
| 183 | + StreamStoppedEvent( |
| 184 | + retry_time=timedelta(seconds=interval) if interval else None, |
| 185 | + exception=error, |
118 | 186 | ) |
119 | | - error = err |
| 187 | + ) |
| 188 | + |
120 | 189 | if error is None and not self._retry_on_exhausted_stream: |
121 | 190 | _logger.info( |
122 | 191 | "%s: connection closed, stream exhausted", self._stream_name |
123 | 192 | ) |
124 | 193 | await self._channel.close() |
125 | 194 | break |
126 | 195 | error_str = f"Error: {error}" if error else "Stream exhausted" |
127 | | - interval = self._retry_strategy.next_interval() |
128 | 196 | if interval is None: |
129 | 197 | _logger.error( |
130 | 198 | "%s: connection ended, retry limit exceeded (%s), giving up. %s.", |
|
0 commit comments