1313# limitations under the License.
1414
1515import asyncio
16- import unittest
1716from unittest import mock
1817
1918import pytest
@@ -30,9 +29,11 @@ def _is_retriable(exc):
3029DEFAULT_TEST_RETRY = AsyncRetry (predicate = _is_retriable , deadline = 1 )
3130
3231
33- class TestBidiStreamRetryManager (unittest .IsolatedAsyncioTestCase ):
32+ class TestBidiStreamRetryManager :
33+ @pytest .mark .asyncio
3434 async def test_execute_success_on_first_try (self ):
3535 mock_strategy = mock .AsyncMock (spec = base_strategy ._BaseResumptionStrategy )
36+
3637 async def mock_stream_opener (* args , ** kwargs ):
3738 yield "response_1"
3839
@@ -41,12 +42,33 @@ async def mock_stream_opener(*args, **kwargs):
4142 )
4243 await retry_manager .execute (initial_state = {}, retry_policy = DEFAULT_TEST_RETRY )
4344 mock_strategy .generate_requests .assert_called_once ()
44- mock_strategy .update_state_from_response .assert_called_once_with ("response_1" , {})
45+ mock_strategy .update_state_from_response .assert_called_once_with (
46+ "response_1" , {}
47+ )
4548 mock_strategy .recover_state_on_failure .assert_not_called ()
4649
47- async def test_execute_retries_and_succeeds (self ):
50+ @pytest .mark .asyncio
51+ async def test_execute_success_on_empty_stream (self ):
52+ mock_strategy = mock .AsyncMock (spec = base_strategy ._BaseResumptionStrategy )
53+
54+ async def mock_stream_opener (* args , ** kwargs ):
55+ if False :
56+ yield
57+
58+ retry_manager = manager ._BidiStreamRetryManager (
59+ strategy = mock_strategy , stream_opener = mock_stream_opener
60+ )
61+ await retry_manager .execute (initial_state = {}, retry_policy = DEFAULT_TEST_RETRY )
62+
63+ mock_strategy .generate_requests .assert_called_once ()
64+ mock_strategy .update_state_from_response .assert_not_called ()
65+ mock_strategy .recover_state_on_failure .assert_not_called ()
66+
67+ @pytest .mark .asyncio
68+ async def test_execute_retries_on_initial_failure_and_succeeds (self ):
4869 mock_strategy = mock .AsyncMock (spec = base_strategy ._BaseResumptionStrategy )
4970 attempt_count = 0
71+
5072 async def mock_stream_opener (* args , ** kwargs ):
5173 nonlocal attempt_count
5274 attempt_count += 1
@@ -59,17 +81,63 @@ async def mock_stream_opener(*args, **kwargs):
5981 strategy = mock_strategy , stream_opener = mock_stream_opener
6082 )
6183 retry_policy = AsyncRetry (predicate = _is_retriable , initial = 0.01 )
62- retry_policy .sleep = mock .AsyncMock ()
6384
64- await retry_manager .execute (initial_state = {}, retry_policy = retry_policy )
85+ with mock .patch ("asyncio.sleep" , new_callable = mock .AsyncMock ):
86+ await retry_manager .execute (initial_state = {}, retry_policy = retry_policy )
87+
88+ assert attempt_count == 2
89+ assert mock_strategy .generate_requests .call_count == 2
90+ mock_strategy .recover_state_on_failure .assert_called_once ()
91+ mock_strategy .update_state_from_response .assert_called_once_with (
92+ "response_2" , {}
93+ )
94+
95+ @pytest .mark .asyncio
96+ async def test_execute_retries_and_succeeds_mid_stream (self ):
97+ """Test retry logic for a stream that fails after yielding some data."""
98+ mock_strategy = mock .AsyncMock (spec = base_strategy ._BaseResumptionStrategy )
99+ attempt_count = 0
100+ # Use a list to simulate stream content for each attempt
101+ stream_content = [
102+ ["response_1" , exceptions .ServiceUnavailable ("Service is down" )],
103+ ["response_2" ],
104+ ]
105+
106+ async def mock_stream_opener (* args , ** kwargs ):
107+ nonlocal attempt_count
108+ content = stream_content [attempt_count ]
109+ attempt_count += 1
110+ for item in content :
111+ if isinstance (item , Exception ):
112+ raise item
113+ else :
114+ yield item
115+
116+ retry_manager = manager ._BidiStreamRetryManager (
117+ strategy = mock_strategy , stream_opener = mock_stream_opener
118+ )
119+ retry_policy = AsyncRetry (predicate = _is_retriable , initial = 0.01 )
120+
121+ with mock .patch ("asyncio.sleep" , new_callable = mock .AsyncMock ) as mock_sleep :
122+ await retry_manager .execute (initial_state = {}, retry_policy = retry_policy )
123+
124+ assert attempt_count == 2
125+ mock_sleep .assert_called_once ()
65126
66- self .assertEqual (attempt_count , 2 )
67- self .assertEqual (mock_strategy .generate_requests .call_count , 2 )
127+ assert mock_strategy .generate_requests .call_count == 2
68128 mock_strategy .recover_state_on_failure .assert_called_once ()
69- mock_strategy .update_state_from_response .assert_called_once_with ("response_2" , {})
129+ assert mock_strategy .update_state_from_response .call_count == 2
130+ mock_strategy .update_state_from_response .assert_has_calls (
131+ [
132+ mock .call ("response_1" , {}),
133+ mock .call ("response_2" , {}),
134+ ]
135+ )
70136
137+ @pytest .mark .asyncio
71138 async def test_execute_fails_after_deadline_exceeded (self ):
72139 mock_strategy = mock .AsyncMock (spec = base_strategy ._BaseResumptionStrategy )
140+
73141 async def mock_stream_opener (* args , ** kwargs ):
74142 if False :
75143 yield
@@ -79,13 +147,15 @@ async def mock_stream_opener(*args, **kwargs):
79147 retry_manager = manager ._BidiStreamRetryManager (
80148 strategy = mock_strategy , stream_opener = mock_stream_opener
81149 )
82- with pytest .raises (exceptions .RetryError , match = "Deadline of 0.01s exceeded" ):
150+ with pytest .raises (exceptions .RetryError , match = "Timeout of 0.0s exceeded" ):
83151 await retry_manager .execute (initial_state = {}, retry_policy = fast_retry )
84152
85- self . assertGreater ( mock_strategy .recover_state_on_failure .call_count , 0 )
153+ mock_strategy .recover_state_on_failure .assert_called_once ( )
86154
155+ @pytest .mark .asyncio
87156 async def test_execute_fails_immediately_on_non_retriable_error (self ):
88157 mock_strategy = mock .AsyncMock (spec = base_strategy ._BaseResumptionStrategy )
158+
89159 async def mock_stream_opener (* args , ** kwargs ):
90160 if False :
91161 yield
0 commit comments