Skip to content

Commit f033515

Browse files
committed
Add more unit tests
1 parent 47fdd70 commit f033515

File tree

2 files changed

+89
-42
lines changed

2 files changed

+89
-42
lines changed

google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@
1313
# limitations under the License.
1414

1515
import asyncio
16-
import time
1716
from typing import Any, AsyncIterator, Callable
1817

1918
from google.api_core import exceptions
20-
from google.api_core.retry.retry_base import exponential_sleep_generator
2119
from google.cloud.storage._experimental.asyncio.retry.base_strategy import (
2220
_BaseResumptionStrategy,
2321
)
@@ -45,46 +43,25 @@ async def execute(self, initial_state: Any, retry_policy):
4543
"""
4644
Executes the bidi operation with the configured retry policy.
4745
48-
This method implements a manual retry loop that provides the necessary
49-
control points to manage state between attempts.
50-
5146
Args:
5247
initial_state: An object containing all state for the operation.
5348
retry_policy: The `google.api_core.retry.AsyncRetry` object to
5449
govern the retry behavior for this specific operation.
5550
"""
5651
state = initial_state
5752

58-
deadline = time.monotonic() + retry_policy._deadline if retry_policy._deadline else 0
59-
60-
sleep_generator = exponential_sleep_generator(
61-
retry_policy._initial, retry_policy._maximum, retry_policy._multiplier
62-
)
63-
64-
while True:
53+
async def attempt():
54+
requests = self._strategy.generate_requests(state)
55+
stream = self._stream_opener(requests, state)
6556
try:
66-
requests = self._strategy.generate_requests(state)
67-
stream = self._stream_opener(requests, state)
6857
async for response in stream:
6958
self._strategy.update_state_from_response(response, state)
7059
return
7160
except Exception as e:
72-
# AsyncRetry may expose either 'on_error' (public) or the private
73-
# '_on_error' depending on google.api_core version. Call whichever
74-
# exists so the retry policy can decide to raise (non-retriable /
75-
# deadline exceeded) or allow a retry.
76-
on_error_callable = getattr(retry_policy, "on_error", None)
77-
if on_error_callable is None:
78-
on_error_callable = getattr(retry_policy, "_on_error", None)
79-
80-
if on_error_callable is None:
81-
# No hook available on the policy; re-raise the error.
82-
raise
61+
if retry_policy._predicate(e):
62+
await self._strategy.recover_state_on_failure(e, state)
63+
raise e
8364

84-
# Let the retry policy handle the error (may raise RetryError).
85-
await on_error_callable(e)
65+
wrapped_attempt = retry_policy(attempt)
8666

87-
# If the retry policy did not raise, allow the strategy to recover
88-
# and then sleep per policy before next attempt.
89-
await self._strategy.recover_state_on_failure(e, state)
90-
await retry_policy.sleep()
67+
await wrapped_attempt()

tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py

Lines changed: 81 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import asyncio
16-
import unittest
1716
from unittest import mock
1817

1918
import pytest
@@ -30,9 +29,11 @@ def _is_retriable(exc):
3029
DEFAULT_TEST_RETRY = AsyncRetry(predicate=_is_retriable, deadline=1)
3130

3231

33-
class TestBidiStreamRetryManager(unittest.IsolatedAsyncioTestCase):
32+
class TestBidiStreamRetryManager:
33+
@pytest.mark.asyncio
3434
async def test_execute_success_on_first_try(self):
3535
mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy)
36+
3637
async def mock_stream_opener(*args, **kwargs):
3738
yield "response_1"
3839

@@ -41,12 +42,33 @@ async def mock_stream_opener(*args, **kwargs):
4142
)
4243
await retry_manager.execute(initial_state={}, retry_policy=DEFAULT_TEST_RETRY)
4344
mock_strategy.generate_requests.assert_called_once()
44-
mock_strategy.update_state_from_response.assert_called_once_with("response_1", {})
45+
mock_strategy.update_state_from_response.assert_called_once_with(
46+
"response_1", {}
47+
)
4548
mock_strategy.recover_state_on_failure.assert_not_called()
4649

47-
async def test_execute_retries_and_succeeds(self):
50+
@pytest.mark.asyncio
51+
async def test_execute_success_on_empty_stream(self):
52+
mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy)
53+
54+
async def mock_stream_opener(*args, **kwargs):
55+
if False:
56+
yield
57+
58+
retry_manager = manager._BidiStreamRetryManager(
59+
strategy=mock_strategy, stream_opener=mock_stream_opener
60+
)
61+
await retry_manager.execute(initial_state={}, retry_policy=DEFAULT_TEST_RETRY)
62+
63+
mock_strategy.generate_requests.assert_called_once()
64+
mock_strategy.update_state_from_response.assert_not_called()
65+
mock_strategy.recover_state_on_failure.assert_not_called()
66+
67+
@pytest.mark.asyncio
68+
async def test_execute_retries_on_initial_failure_and_succeeds(self):
4869
mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy)
4970
attempt_count = 0
71+
5072
async def mock_stream_opener(*args, **kwargs):
5173
nonlocal attempt_count
5274
attempt_count += 1
@@ -59,17 +81,63 @@ async def mock_stream_opener(*args, **kwargs):
5981
strategy=mock_strategy, stream_opener=mock_stream_opener
6082
)
6183
retry_policy = AsyncRetry(predicate=_is_retriable, initial=0.01)
62-
retry_policy.sleep = mock.AsyncMock()
6384

64-
await retry_manager.execute(initial_state={}, retry_policy=retry_policy)
85+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock):
86+
await retry_manager.execute(initial_state={}, retry_policy=retry_policy)
87+
88+
assert attempt_count == 2
89+
assert mock_strategy.generate_requests.call_count == 2
90+
mock_strategy.recover_state_on_failure.assert_called_once()
91+
mock_strategy.update_state_from_response.assert_called_once_with(
92+
"response_2", {}
93+
)
94+
95+
@pytest.mark.asyncio
96+
async def test_execute_retries_and_succeeds_mid_stream(self):
97+
"""Test retry logic for a stream that fails after yielding some data."""
98+
mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy)
99+
attempt_count = 0
100+
# Use a list to simulate stream content for each attempt
101+
stream_content = [
102+
["response_1", exceptions.ServiceUnavailable("Service is down")],
103+
["response_2"],
104+
]
105+
106+
async def mock_stream_opener(*args, **kwargs):
107+
nonlocal attempt_count
108+
content = stream_content[attempt_count]
109+
attempt_count += 1
110+
for item in content:
111+
if isinstance(item, Exception):
112+
raise item
113+
else:
114+
yield item
115+
116+
retry_manager = manager._BidiStreamRetryManager(
117+
strategy=mock_strategy, stream_opener=mock_stream_opener
118+
)
119+
retry_policy = AsyncRetry(predicate=_is_retriable, initial=0.01)
120+
121+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as mock_sleep:
122+
await retry_manager.execute(initial_state={}, retry_policy=retry_policy)
123+
124+
assert attempt_count == 2
125+
mock_sleep.assert_called_once()
65126

66-
self.assertEqual(attempt_count, 2)
67-
self.assertEqual(mock_strategy.generate_requests.call_count, 2)
127+
assert mock_strategy.generate_requests.call_count == 2
68128
mock_strategy.recover_state_on_failure.assert_called_once()
69-
mock_strategy.update_state_from_response.assert_called_once_with("response_2", {})
129+
assert mock_strategy.update_state_from_response.call_count == 2
130+
mock_strategy.update_state_from_response.assert_has_calls(
131+
[
132+
mock.call("response_1", {}),
133+
mock.call("response_2", {}),
134+
]
135+
)
70136

137+
@pytest.mark.asyncio
71138
async def test_execute_fails_after_deadline_exceeded(self):
72139
mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy)
140+
73141
async def mock_stream_opener(*args, **kwargs):
74142
if False:
75143
yield
@@ -79,13 +147,15 @@ async def mock_stream_opener(*args, **kwargs):
79147
retry_manager = manager._BidiStreamRetryManager(
80148
strategy=mock_strategy, stream_opener=mock_stream_opener
81149
)
82-
with pytest.raises(exceptions.RetryError, match="Deadline of 0.01s exceeded"):
150+
with pytest.raises(exceptions.RetryError, match="Timeout of 0.0s exceeded"):
83151
await retry_manager.execute(initial_state={}, retry_policy=fast_retry)
84152

85-
self.assertGreater(mock_strategy.recover_state_on_failure.call_count, 0)
153+
mock_strategy.recover_state_on_failure.assert_called_once()
86154

155+
@pytest.mark.asyncio
87156
async def test_execute_fails_immediately_on_non_retriable_error(self):
88157
mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy)
158+
89159
async def mock_stream_opener(*args, **kwargs):
90160
if False:
91161
yield

0 commit comments

Comments
 (0)