1111
1212import grpc .aio
1313import pytest
14+ from frequenz .channels import Receiver
1415
1516from frequenz .client .base import retry , streaming
17+ from frequenz .client .base .streaming import Message , State
1618
1719
1820def _transformer (x : int ) -> str :
@@ -62,6 +64,28 @@ async def asynciter(ready_event: asyncio.Event) -> AsyncIterator[int]:
6264 await helper .stop ()
6365
6466
67+ async def _split_message (
68+ receiver : Receiver [Message | str ],
69+ ) -> tuple [list [str ], list [Message ]]:
70+ """Split the items received from the receiver into items and messages.
71+
72+ Args:
73+ receiver: The receiver to process.
74+
75+ Returns:
76+ A tuple containing a list of transformed items and a list of messages.
77+ """
78+ items : list [str ] = []
79+ events : list [Message ] = []
80+ async for item in receiver :
81+ match item :
82+ case Message ():
83+ events .append (item )
84+ case str ():
85+ items .append (item )
86+ return items , events
87+
88+
6589class _ErroringAsyncIter (AsyncIterator [int ]):
6690 """Async iterator that raises an error after a certain number of successes."""
6791
@@ -93,11 +117,12 @@ async def test_streaming_success_retry_on_exhausted(
93117 """Test streaming success."""
94118 caplog .set_level (logging .INFO )
95119 items : list [str ] = []
120+ events : list [Message ] = []
96121 async with asyncio .timeout (1 ):
97122 receiver = ok_helper .new_receiver ()
98123 receiver_ready_event .set ()
99- async for item in receiver :
100- items . append ( item )
124+ items , events = await _split_message ( receiver )
125+
101126 no_retry .next_interval .assert_called_once_with ()
102127 assert items == [
103128 "transformed_0" ,
@@ -106,6 +131,10 @@ async def test_streaming_success_retry_on_exhausted(
106131 "transformed_3" ,
107132 "transformed_4" ,
108133 ]
134+ assert events == [
135+ Message (state = State .DISCONNECTED , error = None ),
136+ ]
137+
109138 assert caplog .record_tuples == [
110139 (
111140 "frequenz.client.base.streaming" ,
@@ -128,11 +157,13 @@ async def test_streaming_success(
128157 """Test streaming success."""
129158 caplog .set_level (logging .INFO )
130159 items : list [str ] = []
160+ events : list [Message ] = []
161+
131162 async with asyncio .timeout (1 ):
132163 receiver = ok_helper .new_receiver ()
133164 receiver_ready_event .set ()
134- async for item in receiver :
135- items . append ( item )
165+ items , events = await _split_message ( receiver )
166+
136167 assert (
137168 no_retry .next_interval .call_count == 0
138169 ), "next_interval should not be called when streaming is successful"
@@ -144,6 +175,9 @@ async def test_streaming_success(
144175 "transformed_3" ,
145176 "transformed_4" ,
146177 ]
178+ assert events == [
179+ Message (state = State .DISCONNECTED , error = None ),
180+ ]
147181 assert caplog .record_tuples == [
148182 (
149183 "frequenz.client.base.streaming" ,
@@ -191,13 +225,13 @@ async def test_streaming_error( # pylint: disable=too-many-arguments
191225 )
192226
193227 items : list [str ] = []
228+ events : list [Message ] = []
194229 async with AsyncExitStack () as stack :
195230 stack .push_async_callback (helper .stop )
196231
197232 receiver = helper .new_receiver ()
198233 receiver_ready_event .set ()
199- async for item in receiver :
200- items .append (item )
234+ items , events = await _split_message (receiver )
201235
202236 no_retry .next_interval .assert_called_once_with ()
203237 assert items == [f"transformed_{ i } " for i in range (successes )]
@@ -251,13 +285,13 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments
251285 )
252286
253287 items : list [str ] = []
288+ events : list [Message ] = []
254289 async with AsyncExitStack () as stack :
255290 stack .push_async_callback (helper .stop )
256291
257292 receiver = helper .new_receiver ()
258293 receiver_ready_event .set ()
259- async for item in receiver :
260- items .append (item )
294+ items , events = await _split_message (receiver )
261295
262296 assert not items
263297 assert mock_retry .next_interval .mock_calls == [mock .call (), mock .call ()]
@@ -282,3 +316,46 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments
282316 f"giving up. Error: { expected_error_str } ." ,
283317 ),
284318 ]
319+
320+
321+ async def test_messages_on_retry (
322+ receiver_ready_event : asyncio .Event , # pylint: disable=redefined-outer-name
323+ ) -> None :
324+ """Test that messages are sent on retry."""
325+ helper = streaming .GrpcStreamBroadcaster (
326+ stream_name = "test_helper" ,
327+ stream_method = lambda : _ErroringAsyncIter (
328+ grpc .aio .AioRpcError (
329+ code = _NamedMagicMock (name = "mock grpc code" ),
330+ initial_metadata = mock .MagicMock (),
331+ trailing_metadata = mock .MagicMock (),
332+ details = "mock details" ,
333+ debug_error_string = "mock debug_error_string" ,
334+ ),
335+ receiver_ready_event ,
336+ ),
337+ transform = _transformer ,
338+ retry_strategy = retry .LinearBackoff (
339+ limit = 1 ,
340+ interval = 0.01 ,
341+ ),
342+ retry_on_exhausted_stream = True ,
343+ )
344+
345+ items : list [str ] = []
346+ events : list [Message ] = []
347+ async with AsyncExitStack () as stack :
348+ stack .push_async_callback (helper .stop )
349+
350+ receiver = helper .new_receiver ()
351+ receiver_ready_event .set ()
352+ items , events = await _split_message (receiver )
353+
354+ assert items == []
355+ assert [e .state for e in events ] == [
356+ State .CONNECTED ,
357+ State .DISCONNECTED ,
358+ State .CONNECTING ,
359+ State .CONNECTED ,
360+ State .DISCONNECTED ,
361+ ]
0 commit comments