Skip to content

Commit 006c5e2

Browse files
committed
Implement feedback and fix type errors
1 parent 989a343 commit 006c5e2

File tree

3 files changed

+128
-104
lines changed

3 files changed

+128
-104
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "breaking",
3+
"description": "Added standard retry mode as the default retry strategy for AWS clients."
4+
}

packages/smithy-core/src/smithy_core/retries.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,14 @@ def __init__(self, *, max_attempts: int = 3):
268268
:param max_attempts: Upper limit on total number of attempts made, including
269269
initial attempt and retries.
270270
"""
271+
if max_attempts < 1:
272+
raise ValueError(
273+
f"max_attempts must be a positive integer, got {max_attempts}"
274+
)
275+
271276
self.backoff_strategy = ExponentialRetryBackoffStrategy(
272277
backoff_scale_value=1,
278+
max_backoff=20,
273279
jitter_type=ExponentialBackoffJitterType.FULL,
274280
)
275281
self.max_attempts = max_attempts
@@ -288,7 +294,7 @@ async def acquire_initial_retry_token(
288294
async def refresh_retry_token_for_retry(
289295
self,
290296
*,
291-
token_to_renew: StandardRetryToken,
297+
token_to_renew: retries_interface.RetryToken,
292298
error: Exception,
293299
) -> StandardRetryToken:
294300
"""Replace an existing retry token from a failed attempt with a new token.
@@ -300,6 +306,11 @@ async def refresh_retry_token_for_retry(
300306
:param error: The error that triggered the need for a retry.
301307
:raises RetryError: If no further retry attempts are allowed.
302308
"""
309+
if not isinstance(token_to_renew, StandardRetryToken):
310+
raise TypeError(
311+
f"StandardRetryStrategy requires StandardRetryToken, got {type(token_to_renew).__name__}"
312+
)
313+
303314
if isinstance(error, retries_interface.ErrorRetryInfo) and error.is_retry_safe:
304315
retry_count = token_to_renew.retry_count + 1
305316
if retry_count >= self.max_attempts:
@@ -310,7 +321,7 @@ async def refresh_retry_token_for_retry(
310321
# Acquire additional quota for this retry attempt
311322
# (may raise a RetryError if none is available)
312323
quota_acquired = await self._retry_quota.acquire(error=error)
313-
total_quota = token_to_renew.quota_consumed + quota_acquired
324+
total_quota: int = token_to_renew.quota_consumed + quota_acquired
314325

315326
if error.retry_after is not None:
316327
retry_delay = error.retry_after
@@ -328,24 +339,28 @@ async def refresh_retry_token_for_retry(
328339
else:
329340
raise RetryError(f"Error is not retryable: {error}") from error
330341

331-
async def record_success(self, *, token: StandardRetryToken) -> None:
342+
async def record_success(self, *, token: retries_interface.RetryToken) -> None:
332343
"""Return token after successful completion of an operation.
333344
334345
Releases retry tokens back to the retry quota based on the previous amount
335346
consumed.
336347
337348
:param token: The token used for the previous successful attempt.
338349
"""
350+
if not isinstance(token, StandardRetryToken):
351+
raise TypeError(
352+
f"StandardRetryStrategy requires StandardRetryToken, got {type(token).__name__}"
353+
)
339354
await self._retry_quota.release(release_amount=token.last_quota_acquired)
340355

341356

342357
class StandardRetryQuota:
343358
"""Retry quota used by :py:class:`StandardRetryStrategy`."""
344359

345-
INITIAL_RETRY_TOKENS = 500
346-
RETRY_COST = 5
347-
NO_RETRY_INCREMENT = 1
348-
TIMEOUT_RETRY_COST = 10
360+
INITIAL_RETRY_TOKENS: int = 500
361+
RETRY_COST: int = 5
362+
NO_RETRY_INCREMENT: int = 1
363+
TIMEOUT_RETRY_COST: int = 10
349364

350365
def __init__(self):
351366
self._max_capacity = self.INITIAL_RETRY_TOKENS
@@ -384,6 +399,11 @@ async def release(self, *, release_amount: int) -> None:
384399
self._available_capacity + increment, self._max_capacity
385400
)
386401

402+
@property
403+
def available_capacity(self) -> int:
404+
"""Return the amount of capacity available."""
405+
return self._available_capacity
406+
387407

388408
class RetryStrategyMode(Enum):
389409
"""Enumeration of available retry strategies."""

packages/smithy-core/tests/unit/test_retries.py

Lines changed: 97 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def test_exponential_backoff_strategy(
5959
assert delay_actual == pytest.approx(delay_expected) # type: ignore
6060

6161

62-
@pytest.mark.asyncio
6362
@pytest.mark.parametrize("max_attempts", [2, 3, 10])
6463
async def test_simple_retry_strategy(max_attempts: int) -> None:
6564
strategy = SimpleRetryStrategy(
@@ -76,7 +75,6 @@ async def test_simple_retry_strategy(max_attempts: int) -> None:
7675
await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error)
7776

7877

79-
@pytest.mark.asyncio
8078
async def test_simple_retry_does_not_retry_unclassified() -> None:
8179
strategy = SimpleRetryStrategy(
8280
backoff_strategy=ExponentialRetryBackoffStrategy(backoff_scale_value=5),
@@ -89,7 +87,6 @@ async def test_simple_retry_does_not_retry_unclassified() -> None:
8987
)
9088

9189

92-
@pytest.mark.asyncio
9390
async def test_simple_retry_does_not_retry_when_safety_unknown() -> None:
9491
strategy = SimpleRetryStrategy(
9592
backoff_strategy=ExponentialRetryBackoffStrategy(backoff_scale_value=5),
@@ -101,7 +98,6 @@ async def test_simple_retry_does_not_retry_when_safety_unknown() -> None:
10198
await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error)
10299

103100

104-
@pytest.mark.asyncio
105101
async def test_simple_retry_does_not_retry_unsafe() -> None:
106102
strategy = SimpleRetryStrategy(
107103
backoff_strategy=ExponentialRetryBackoffStrategy(backoff_scale_value=5),
@@ -113,7 +109,6 @@ async def test_simple_retry_does_not_retry_unsafe() -> None:
113109
await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error)
114110

115111

116-
@pytest.mark.asyncio
117112
@pytest.mark.parametrize("max_attempts", [2, 3, 10])
118113
async def test_standard_retry_strategy(max_attempts: int) -> None:
119114
strategy = StandardRetryStrategy(max_attempts=max_attempts)
@@ -127,7 +122,6 @@ async def test_standard_retry_strategy(max_attempts: int) -> None:
127122
await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error)
128123

129124

130-
@pytest.mark.asyncio
131125
async def test_standard_retry_does_not_retry_unclassified() -> None:
132126
strategy = StandardRetryStrategy()
133127
token = await strategy.acquire_initial_retry_token()
@@ -137,7 +131,6 @@ async def test_standard_retry_does_not_retry_unclassified() -> None:
137131
)
138132

139133

140-
@pytest.mark.asyncio
141134
async def test_standard_retry_does_not_retry_when_safety_unknown() -> None:
142135
strategy = StandardRetryStrategy()
143136
error = CallError(is_retry_safe=None)
@@ -146,7 +139,6 @@ async def test_standard_retry_does_not_retry_when_safety_unknown() -> None:
146139
await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error)
147140

148141

149-
@pytest.mark.asyncio
150142
async def test_standard_retry_does_not_retry_unsafe() -> None:
151143
strategy = StandardRetryStrategy()
152144
error = CallError(fault="client", is_retry_safe=False)
@@ -155,133 +147,141 @@ async def test_standard_retry_does_not_retry_unsafe() -> None:
155147
await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error)
156148

157149

158-
@pytest.mark.asyncio
159-
async def test_standard_retry_strategy_respects_max_attempts() -> None:
150+
async def test_standard_retry_after_overrides_backoff() -> None:
160151
strategy = StandardRetryStrategy()
161-
error = CallError(is_retry_safe=True)
152+
error = CallError(is_retry_safe=True, retry_after=5.5)
162153
token = await strategy.acquire_initial_retry_token()
163154
token = await strategy.refresh_retry_token_for_retry(
164155
token_to_renew=token, error=error
165156
)
166-
token = await strategy.refresh_retry_token_for_retry(
167-
token_to_renew=token, error=error
168-
)
169-
with pytest.raises(RetryError):
170-
await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error)
157+
assert token.retry_delay == 5.5
171158

172159

173-
@pytest.mark.asyncio
174-
async def test_retry_after_overrides_backoff() -> None:
160+
async def test_standard_retry_quota_consumed_accumulates() -> None:
175161
strategy = StandardRetryStrategy()
176-
error = CallError(is_retry_safe=True, retry_after=5)
162+
error = CallError(is_retry_safe=True)
177163
token = await strategy.acquire_initial_retry_token()
164+
178165
token = await strategy.refresh_retry_token_for_retry(
179166
token_to_renew=token, error=error
180167
)
181-
assert token.retry_delay == 5
168+
first_consumed = token.quota_consumed
169+
assert first_consumed == StandardRetryQuota.RETRY_COST
182170

171+
token = await strategy.refresh_retry_token_for_retry(
172+
token_to_renew=token, error=error
173+
)
174+
assert token.quota_consumed == first_consumed + StandardRetryQuota.RETRY_COST
183175

184-
@pytest.mark.asyncio
185-
async def test_retry_quota_acquire_when_exhausted(monkeypatch) -> None:
186-
monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 5, raising=False)
187-
monkeypatch.setattr(StandardRetryQuota, "RETRY_COST", 2, raising=False)
188176

189-
quota = StandardRetryQuota()
190-
assert quota._available_capacity == 5
177+
async def test_standard_retry_invalid_max_attempts() -> None:
178+
with pytest.raises(ValueError):
179+
StandardRetryStrategy(max_attempts=0)
191180

192-
# First acquire: 5 -> 3
193-
assert await quota.acquire(error=Exception()) == 2
194-
assert quota._available_capacity == 3
181+
with pytest.raises(ValueError):
182+
StandardRetryStrategy(max_attempts=-1)
195183

196-
# Second acquire: 3 -> 1
197-
assert await quota.acquire(error=Exception()) == 2
198-
assert quota._available_capacity == 1
199184

200-
# Third acquire needs 2 but only 1 remains -> should raise
201-
with pytest.raises(RetryError):
202-
await quota.acquire(error=Exception())
203-
assert quota._available_capacity == 1
185+
async def test_standard_retry_record_success_without_retry() -> None:
186+
strategy = StandardRetryStrategy()
187+
token = await strategy.acquire_initial_retry_token()
188+
initial_capacity = strategy._retry_quota.available_capacity # pyright: ignore[reportPrivateUsage]
204189

190+
await strategy.record_success(token=token)
205191

206-
@pytest.mark.asyncio
207-
async def test_retry_quota_release_zero_adds_increment(monkeypatch) -> None:
208-
monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 5, raising=False)
209-
monkeypatch.setattr(StandardRetryQuota, "RETRY_COST", 2, raising=False)
210-
monkeypatch.setattr(StandardRetryQuota, "NO_RETRY_INCREMENT", 1, raising=False)
192+
# Should increment by NO_RETRY_INCREMENT
193+
expected = min(
194+
initial_capacity + StandardRetryQuota.NO_RETRY_INCREMENT,
195+
StandardRetryQuota.INITIAL_RETRY_TOKENS,
196+
)
197+
assert strategy._retry_quota.available_capacity == expected # pyright: ignore[reportPrivateUsage]
211198

212-
quota = StandardRetryQuota()
213-
assert quota._available_capacity == 5
214199

215-
# First acquire: 5 -> 3
216-
assert await quota.acquire(error=Exception()) == 2
217-
assert quota._available_capacity == 3
200+
async def test_standard_retry_record_success_with_retry() -> None:
201+
strategy = StandardRetryStrategy()
202+
error = CallError(is_retry_safe=True)
203+
token = await strategy.acquire_initial_retry_token()
218204

219-
# release 0 should add NO_RETRY_INCREMENT: 3 -> 4
220-
await quota.release(release_amount=0)
221-
assert quota._available_capacity == 4
205+
token = await strategy.refresh_retry_token_for_retry(
206+
token_to_renew=token, error=error
207+
)
208+
capacity_after_retry = strategy._retry_quota.available_capacity # pyright: ignore[reportPrivateUsage]
222209

223-
# Next acquire should still work: 4 -> 2
224-
assert await quota.acquire(error=Exception()) == 2
225-
assert quota._available_capacity == 2
210+
await strategy.record_success(token=token)
211+
212+
# Capacity should increase by last_quota_acquired
213+
assert (
214+
strategy._retry_quota.available_capacity # pyright: ignore[reportPrivateUsage]
215+
== capacity_after_retry + token.last_quota_acquired
216+
)
226217

227218

228-
@pytest.mark.asyncio
229-
async def test_retry_quota_release_caps_at_max(monkeypatch) -> None:
219+
@pytest.fixture
220+
def retry_quota(monkeypatch: pytest.MonkeyPatch) -> StandardRetryQuota:
230221
monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 10, raising=False)
231222
monkeypatch.setattr(StandardRetryQuota, "RETRY_COST", 3, raising=False)
223+
monkeypatch.setattr(StandardRetryQuota, "NO_RETRY_INCREMENT", 1, raising=False)
224+
return StandardRetryQuota()
232225

233-
quota = StandardRetryQuota()
234-
assert quota._available_capacity == 10
235226

236-
# Drain some capacity: 10 -> 7 -> 4
237-
assert await quota.acquire(error=Exception()) == 3
238-
assert quota._available_capacity == 7
239-
assert await quota.acquire(error=Exception()) == 3
240-
assert quota._available_capacity == 4
227+
async def test_retry_quota_initial_state(
228+
retry_quota: StandardRetryQuota,
229+
) -> None:
230+
assert retry_quota.available_capacity == 10
231+
assert retry_quota._max_capacity == 10 # pyright: ignore[reportPrivateUsage]
241232

242-
# Release more than needed: 4 + 8 = 12. Should cap at max = 10
243-
await quota.release(release_amount=8)
244-
assert quota._available_capacity == 10
245233

246-
# Another acquire should succeed from max: 10 -> 7
247-
assert await quota.acquire(error=Exception()) == 3
248-
assert quota._available_capacity == 7
234+
async def test_retry_quota_acquire_success(
235+
retry_quota: StandardRetryQuota,
236+
) -> None:
237+
acquired = await retry_quota.acquire(error=Exception())
249238

239+
assert acquired == 3
240+
assert retry_quota.available_capacity == 7
250241

251-
@pytest.mark.asyncio
252-
async def test_retry_quota_releases_last_acquired_amount(monkeypatch) -> None:
253-
monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 10, raising=False)
254-
monkeypatch.setattr(StandardRetryQuota, "RETRY_COST", 5, raising=False)
255242

256-
strategy = StandardRetryStrategy()
257-
err = CallError(is_retry_safe=True)
258-
token = await strategy.acquire_initial_retry_token()
243+
async def test_retry_quota_acquire_when_exhausted(
244+
retry_quota: StandardRetryQuota,
245+
) -> None:
246+
# Drain capacity: 10 -> 7 -> 4 -> 1
247+
await retry_quota.acquire(error=Exception())
248+
await retry_quota.acquire(error=Exception())
249+
await retry_quota.acquire(error=Exception())
250+
assert retry_quota.available_capacity == 1
259251

260-
# Two retries: 10 -> 5 -> 0
261-
token = await strategy.refresh_retry_token_for_retry(
262-
token_to_renew=token, error=err
263-
)
264-
assert strategy._retry_quota._available_capacity == 5
265-
token = await strategy.refresh_retry_token_for_retry(
266-
token_to_renew=token, error=err
267-
)
268-
assert strategy._retry_quota._available_capacity == 0
252+
# Next acquire needs 3 but only 1 remains
253+
with pytest.raises(RetryError, match="Retry quota exceeded"):
254+
await retry_quota.acquire(error=Exception())
269255

270-
# Success returns ONLY the last acquired amount -> 5
271-
await strategy.record_success(token=token)
272-
assert strategy._retry_quota._available_capacity == 5
273256

257+
async def test_retry_quota_release_restores_capacity(
258+
retry_quota: StandardRetryQuota,
259+
) -> None:
260+
acquired = await retry_quota.acquire(error=Exception())
261+
assert retry_quota.available_capacity == 7
274262

275-
@pytest.mark.asyncio
276-
async def test_retry_quota_release_when_no_retry(monkeypatch) -> None:
277-
monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 10, raising=False)
278-
quota = StandardRetryQuota()
263+
await retry_quota.release(release_amount=acquired)
264+
assert retry_quota.available_capacity == 10
279265

280-
await quota.acquire(error=Exception())
281-
assert quota._available_capacity == 5
282-
before = quota._available_capacity
283266

284-
await quota.release(release_amount=0)
285-
# Should increment by NO_RETRY_INCREMENT = 1
286-
assert quota._available_capacity == min(before + 1, quota._max_capacity)
287-
assert quota._available_capacity == 6
267+
async def test_retry_quota_release_zero_adds_increment(
268+
retry_quota: StandardRetryQuota,
269+
) -> None:
270+
await retry_quota.acquire(error=Exception())
271+
assert retry_quota.available_capacity == 7
272+
273+
await retry_quota.release(release_amount=0)
274+
assert retry_quota.available_capacity == 8
275+
276+
277+
async def test_retry_quota_release_caps_at_max(
278+
retry_quota: StandardRetryQuota,
279+
) -> None:
280+
# Drain some capacity
281+
await retry_quota.acquire(error=Exception())
282+
await retry_quota.acquire(error=Exception())
283+
assert retry_quota.available_capacity == 4
284+
285+
# Release more than drained. Should cap at max
286+
await retry_quota.release(release_amount=20)
287+
assert retry_quota.available_capacity == 10

0 commit comments

Comments
 (0)