Skip to content

Commit b25fd29

Browse files
committed
Streamer: Write out events on the channels as well
Signed-off-by: Mathias L. Baumann <[email protected]>
1 parent 1eafda5 commit b25fd29

File tree

3 files changed

+122
-17
lines changed

3 files changed

+122
-17
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
## New Features
1212

13-
<!-- Here goes the main new features and examples or instructions on how to use them -->
13+
* The streaming client now also sends state change events out.
1414

1515
## Bug Fixes
1616

src/frequenz/client/base/streaming.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import asyncio
77
import logging
88
from collections.abc import Callable
9+
from dataclasses import dataclass
10+
from enum import StrEnum
911
from typing import AsyncIterable, Generic, TypeVar
1012

1113
import grpc.aio
@@ -24,6 +26,33 @@
2426
"""The output type of the stream."""
2527

2628

29+
class State(StrEnum):
30+
"""State of the GrpcStreamBroadcaster."""
31+
32+
CONNECTING = "connecting"
33+
"""The broadcaster is connecting to the stream."""
34+
35+
CONNECTED = "connected"
36+
"""The broadcaster is connected to the stream."""
37+
38+
DISCONNECTED = "disconnected"
39+
"""The broadcaster is disconnected from the stream."""
40+
41+
42+
@dataclass(frozen=True, kw_only=True)
43+
class Message:
44+
"""Message sent by the broadcaster."""
45+
46+
state: State
47+
"""State change of the broadcaster."""
48+
49+
error: Exception | None = None
50+
"""An error that occurred during streaming, if any. Defaults to None."""
51+
52+
53+
MessageT = TypeVar("MessageT", bound=Message)
54+
55+
2756
class GrpcStreamBroadcaster(Generic[InputT, OutputT]):
2857
"""Helper class to handle grpc streaming methods."""
2958

@@ -55,14 +84,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument
5584
)
5685
self._retry_on_exhausted_stream = retry_on_exhausted_stream
5786

58-
self._channel: channels.Broadcast[OutputT] = channels.Broadcast(
87+
self._channel: channels.Broadcast[Message | OutputT] = channels.Broadcast(
5988
name=f"GrpcStreamBroadcaster-{stream_name}"
6089
)
6190
self._task = asyncio.create_task(self._run())
6291

6392
def new_receiver(
6493
self, maxsize: int = 50, warn_on_overflow: bool = True
65-
) -> channels.Receiver[OutputT]:
94+
) -> channels.Receiver[Message | OutputT]:
6695
"""Create a new receiver for the stream.
6796
6897
Args:
@@ -107,20 +136,16 @@ async def _run(self) -> None:
107136
_logger.info("%s: starting to stream", self._stream_name)
108137
try:
109138
call = self._stream_method()
139+
await sender.send(Message(state=State.CONNECTED, error=None))
110140
async for msg in call:
111141
await sender.send(self._transform(msg))
112142
except grpc.aio.AioRpcError as err:
113143
error = err
114-
except Exception as err: # pylint: disable=broad-except
115-
_logger.exception(
116-
"%s: raise an unexpected exception",
117-
self._stream_name,
118-
)
119-
error = err
120144
if error is None and not self._retry_on_exhausted_stream:
121145
_logger.info(
122146
"%s: connection closed, stream exhausted", self._stream_name
123147
)
148+
await sender.send(Message(state=State.DISCONNECTED, error=None))
124149
await self._channel.close()
125150
break
126151
error_str = f"Error: {error}" if error else "Stream exhausted"
@@ -132,6 +157,7 @@ async def _run(self) -> None:
132157
self._retry_strategy.get_progress(),
133158
error_str,
134159
)
160+
await sender.send(Message(state=State.DISCONNECTED, error=error))
135161
await self._channel.close()
136162
break
137163
_logger.warning(
@@ -141,4 +167,6 @@ async def _run(self) -> None:
141167
interval,
142168
error_str,
143169
)
170+
await sender.send(Message(state=State.DISCONNECTED, error=error))
171+
await sender.send(Message(state=State.CONNECTING, error=None))
144172
await asyncio.sleep(interval)

tests/streaming/test_grpc_stream_broadcaster.py

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
import grpc.aio
1313
import pytest
14+
from frequenz.channels import Receiver
1415

1516
from frequenz.client.base import retry, streaming
17+
from frequenz.client.base.streaming import Message, State
1618

1719

1820
def _transformer(x: int) -> str:
@@ -62,6 +64,28 @@ async def asynciter(ready_event: asyncio.Event) -> AsyncIterator[int]:
6264
await helper.stop()
6365

6466

67+
async def _split_message(
68+
receiver: Receiver[Message | str],
69+
) -> tuple[list[str], list[Message]]:
70+
"""Split the items received from the receiver into items and messages.
71+
72+
Args:
73+
receiver: The receiver to process.
74+
75+
Returns:
76+
A tuple containing a list of transformed items and a list of messages.
77+
"""
78+
items: list[str] = []
79+
events: list[Message] = []
80+
async for item in receiver:
81+
match item:
82+
case Message():
83+
events.append(item)
84+
case str():
85+
items.append(item)
86+
return items, events
87+
88+
6589
class _ErroringAsyncIter(AsyncIterator[int]):
6690
"""Async iterator that raises an error after a certain number of successes."""
6791

@@ -93,11 +117,12 @@ async def test_streaming_success_retry_on_exhausted(
93117
"""Test streaming success."""
94118
caplog.set_level(logging.INFO)
95119
items: list[str] = []
120+
events: list[Message] = []
96121
async with asyncio.timeout(1):
97122
receiver = ok_helper.new_receiver()
98123
receiver_ready_event.set()
99-
async for item in receiver:
100-
items.append(item)
124+
items, events = await _split_message(receiver)
125+
101126
no_retry.next_interval.assert_called_once_with()
102127
assert items == [
103128
"transformed_0",
@@ -106,6 +131,10 @@ async def test_streaming_success_retry_on_exhausted(
106131
"transformed_3",
107132
"transformed_4",
108133
]
134+
assert events == [
135+
Message(state=State.DISCONNECTED, error=None),
136+
]
137+
109138
assert caplog.record_tuples == [
110139
(
111140
"frequenz.client.base.streaming",
@@ -128,11 +157,13 @@ async def test_streaming_success(
128157
"""Test streaming success."""
129158
caplog.set_level(logging.INFO)
130159
items: list[str] = []
160+
events: list[Message] = []
161+
131162
async with asyncio.timeout(1):
132163
receiver = ok_helper.new_receiver()
133164
receiver_ready_event.set()
134-
async for item in receiver:
135-
items.append(item)
165+
items, events = await _split_message(receiver)
166+
136167
assert (
137168
no_retry.next_interval.call_count == 0
138169
), "next_interval should not be called when streaming is successful"
@@ -144,6 +175,9 @@ async def test_streaming_success(
144175
"transformed_3",
145176
"transformed_4",
146177
]
178+
assert events == [
179+
Message(state=State.DISCONNECTED, error=None),
180+
]
147181
assert caplog.record_tuples == [
148182
(
149183
"frequenz.client.base.streaming",
@@ -191,13 +225,13 @@ async def test_streaming_error( # pylint: disable=too-many-arguments
191225
)
192226

193227
items: list[str] = []
228+
events: list[Message] = []
194229
async with AsyncExitStack() as stack:
195230
stack.push_async_callback(helper.stop)
196231

197232
receiver = helper.new_receiver()
198233
receiver_ready_event.set()
199-
async for item in receiver:
200-
items.append(item)
234+
items, events = await _split_message(receiver)
201235

202236
no_retry.next_interval.assert_called_once_with()
203237
assert items == [f"transformed_{i}" for i in range(successes)]
@@ -251,13 +285,13 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments
251285
)
252286

253287
items: list[str] = []
288+
events: list[Message] = []
254289
async with AsyncExitStack() as stack:
255290
stack.push_async_callback(helper.stop)
256291

257292
receiver = helper.new_receiver()
258293
receiver_ready_event.set()
259-
async for item in receiver:
260-
items.append(item)
294+
items, events = await _split_message(receiver)
261295

262296
assert not items
263297
assert mock_retry.next_interval.mock_calls == [mock.call(), mock.call()]
@@ -282,3 +316,46 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments
282316
f"giving up. Error: {expected_error_str}.",
283317
),
284318
]
319+
320+
321+
async def test_messages_on_retry(
322+
receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name
323+
) -> None:
324+
"""Test that messages are sent on retry."""
325+
helper = streaming.GrpcStreamBroadcaster(
326+
stream_name="test_helper",
327+
stream_method=lambda: _ErroringAsyncIter(
328+
grpc.aio.AioRpcError(
329+
code=_NamedMagicMock(name="mock grpc code"),
330+
initial_metadata=mock.MagicMock(),
331+
trailing_metadata=mock.MagicMock(),
332+
details="mock details",
333+
debug_error_string="mock debug_error_string",
334+
),
335+
receiver_ready_event,
336+
),
337+
transform=_transformer,
338+
retry_strategy=retry.LinearBackoff(
339+
limit=1,
340+
interval=0.01,
341+
),
342+
retry_on_exhausted_stream=True,
343+
)
344+
345+
items: list[str] = []
346+
events: list[Message] = []
347+
async with AsyncExitStack() as stack:
348+
stack.push_async_callback(helper.stop)
349+
350+
receiver = helper.new_receiver()
351+
receiver_ready_event.set()
352+
items, events = await _split_message(receiver)
353+
354+
assert items == []
355+
assert [e.state for e in events] == [
356+
State.CONNECTED,
357+
State.DISCONNECTED,
358+
State.CONNECTING,
359+
State.CONNECTED,
360+
State.DISCONNECTED,
361+
]

0 commit comments

Comments
 (0)