diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java index 4aae8893f..b5fa537ea 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java @@ -62,8 +62,8 @@ public final class ConfigGenerator implements Runnable { .nullable(false) .initialize(writer -> { writer.addDependency(SmithyPythonDependency.SMITHY_CORE); - writer.addImport("smithy_core.retries", "SimpleRetryStrategy"); - writer.write("self.retry_strategy = retry_strategy or SimpleRetryStrategy()"); + writer.addImport("smithy_core.retries", "StandardRetryStrategy"); + writer.write("self.retry_strategy = retry_strategy or StandardRetryStrategy()"); }) .build(), ConfigProperty.builder() diff --git a/packages/smithy-aws-core/CHANGES.md b/packages/smithy-aws-core/CHANGES.md index bad47df65..6cff38194 100644 --- a/packages/smithy-aws-core/CHANGES.md +++ b/packages/smithy-aws-core/CHANGES.md @@ -12,6 +12,7 @@ ### Features * Added a hand-written implmentation for the `restJson1` protocol. +* Added a new retry mode `standard` and made it the default retry strategy. ## v0.0.3 diff --git a/packages/smithy-core/src/smithy_core/aio/client.py b/packages/smithy-core/src/smithy_core/aio/client.py index bf27c440c..87146de5e 100644 --- a/packages/smithy-core/src/smithy_core/aio/client.py +++ b/packages/smithy-core/src/smithy_core/aio/client.py @@ -12,7 +12,7 @@ from ..auth import AuthParams from ..deserializers import DeserializeableShape, ShapeDeserializer from ..endpoints import EndpointResolverParams -from ..exceptions import RetryError, SmithyError +from ..exceptions import ClientTimeoutError, RetryError, SmithyError from ..interceptors import ( InputContext, Interceptor, @@ -330,7 +330,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape]( return await self._handle_attempt(call, request_context, request_future) retry_strategy = call.retry_strategy - retry_token = retry_strategy.acquire_initial_retry_token( + retry_token = await retry_strategy.acquire_initial_retry_token( token_scope=call.retry_scope ) @@ -349,7 +349,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape]( if isinstance(output_context.response, Exception): try: - retry_strategy.refresh_retry_token_for_retry( + retry_token = await retry_strategy.refresh_retry_token_for_retry( token_to_renew=retry_token, error=output_context.response, ) @@ -364,7 +364,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape]( await seek(request_context.transport_request.body, 0) else: - retry_strategy.record_success(token=retry_token) + await retry_strategy.record_success(token=retry_token) return output_context async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape]( @@ -448,24 +448,32 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape]( _LOGGER.debug("Sending request %s", request_context.transport_request) - if request_future is not None: - # If we have an input event stream (or duplex event stream) then we - # need to let the client return ASAP so that it can start sending - # events. So here we start the transport send in a background task - # then set the result of the request future. It's important to sequence - # it just like that so that the client gets a stream that's ready - # to send. - transport_task = asyncio.create_task( - self.transport.send(request=request_context.transport_request) - ) - request_future.set_result(request_context) - transport_response = await transport_task - else: - # If we don't have an input stream, there's no point in creating a - # task, so we just immediately await the coroutine. - transport_response = await self.transport.send( - request=request_context.transport_request - ) + try: + if request_future is not None: + # If we have an input event stream (or duplex event stream) then we + # need to let the client return ASAP so that it can start sending + # events. So here we start the transport send in a background task + # then set the result of the request future. It's important to sequence + # it just like that so that the client gets a stream that's ready + # to send. + transport_task = asyncio.create_task( + self.transport.send(request=request_context.transport_request) + ) + request_future.set_result(request_context) + transport_response = await transport_task + else: + # If we don't have an input stream, there's no point in creating a + # task, so we just immediately await the coroutine. + transport_response = await self.transport.send( + request=request_context.transport_request + ) + except Exception as e: + error_info = self.transport.get_error_info(e) + if error_info.is_timeout_error: + raise ClientTimeoutError( + message=f"Client timeout occurred: {e}", fault=error_info.fault + ) from e + raise _LOGGER.debug("Received response: %s", transport_response) diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py index 31d772125..e3f8974be 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py @@ -1,7 +1,8 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from collections.abc import AsyncIterable, Callable -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable from ...documents import TypeRegistry from ...endpoints import EndpointResolverParams @@ -10,6 +11,18 @@ from ...interfaces import StreamingBlob as SyncStreamingBlob from .eventstream import EventPublisher, EventReceiver + +@dataclass(frozen=True) +class ErrorInfo: + """Information about an error from a transport.""" + + is_timeout_error: bool + """Whether this error represents a timeout condition.""" + + fault: Literal["client", "server"] = "client" + """Whether the client or server is at fault.""" + + if TYPE_CHECKING: from typing_extensions import TypeForm @@ -86,7 +99,23 @@ async def resolve_endpoint(self, params: EndpointResolverParams[Any]) -> Endpoin class ClientTransport[I: Request, O: Response](Protocol): - """Protocol-agnostic representation of a client tranport (e.g. an HTTP client).""" + """Protocol-agnostic representation of a client transport (e.g. an HTTP client). + + Transport implementations must define the get_error_info method to determine which + exceptions represent timeout conditions for that transport. + """ + + def get_error_info(self, exception: Exception, **kwargs) -> ErrorInfo: + """Get information about an exception. + + Args: + exception: The exception to analyze + **kwargs: Additional context for analysis + + Returns: + ErrorInfo with timeout and fault information. + """ + ... async def send(self, request: I) -> O: """Send a request over the transport and receive the response.""" diff --git a/packages/smithy-core/src/smithy_core/exceptions.py b/packages/smithy-core/src/smithy_core/exceptions.py index 0e28bd530..53e4aacb3 100644 --- a/packages/smithy-core/src/smithy_core/exceptions.py +++ b/packages/smithy-core/src/smithy_core/exceptions.py @@ -50,6 +50,9 @@ class CallError(SmithyError): is_throttling_error: bool = False """Whether the error is a throttling error.""" + is_timeout_error: bool = False + """Whether the error represents a timeout condition.""" + def __post_init__(self): super().__init__(self.message) @@ -61,6 +64,20 @@ class ModeledError(CallError): fault: Fault = "client" +@dataclass(kw_only=True) +class ClientTimeoutError(CallError): + """Exception raised when a client-side timeout occurs. + + This error indicates that the client transport layer encountered a timeout while + attempting to communicate with the server. This typically occurs when network + requests take longer than the configured timeout period. + """ + + fault: Fault = "client" + is_timeout_error: bool = True + is_retry_safe: bool = True + + class SerializationError(SmithyError): """Base exception type for exceptions raised during serialization.""" diff --git a/packages/smithy-core/src/smithy_core/interfaces/retries.py b/packages/smithy-core/src/smithy_core/interfaces/retries.py index a5c9d428b..ab7bbdeed 100644 --- a/packages/smithy-core/src/smithy_core/interfaces/retries.py +++ b/packages/smithy-core/src/smithy_core/interfaces/retries.py @@ -61,7 +61,7 @@ class RetryStrategy(Protocol): max_attempts: int """Upper limit on total attempt count (initial attempt plus retries).""" - def acquire_initial_retry_token( + async def acquire_initial_retry_token( self, *, token_scope: str | None = None ) -> RetryToken: """Called before any retries (for the first attempt at the operation). @@ -74,7 +74,7 @@ def acquire_initial_retry_token( """ ... - def refresh_retry_token_for_retry( + async def refresh_retry_token_for_retry( self, *, token_to_renew: RetryToken, error: Exception ) -> RetryToken: """Replace an existing retry token from a failed attempt with a new token. @@ -91,7 +91,7 @@ def refresh_retry_token_for_retry( """ ... - def record_success(self, *, token: RetryToken) -> None: + async def record_success(self, *, token: RetryToken) -> None: """Return token after successful completion of an operation. Upon successful completion of the operation, a user calls this function to diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index 06bf6f988..c79d6b3ac 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -1,5 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio import random from collections.abc import Callable from dataclasses import dataclass @@ -204,7 +205,7 @@ def __init__( self.backoff_strategy = backoff_strategy or ExponentialRetryBackoffStrategy() self.max_attempts = max_attempts - def acquire_initial_retry_token( + async def acquire_initial_retry_token( self, *, token_scope: str | None = None ) -> SimpleRetryToken: """Called before any retries (for the first attempt at the operation). @@ -214,7 +215,7 @@ def acquire_initial_retry_token( retry_delay = self.backoff_strategy.compute_next_backoff_delay(0) return SimpleRetryToken(retry_count=0, retry_delay=retry_delay) - def refresh_retry_token_for_retry( + async def refresh_retry_token_for_retry( self, *, token_to_renew: retries_interface.RetryToken, @@ -240,5 +241,158 @@ def refresh_retry_token_for_retry( else: raise RetryError(f"Error is not retryable: {error}") from error - def record_success(self, *, token: retries_interface.RetryToken) -> None: + async def record_success(self, *, token: retries_interface.RetryToken) -> None: """Not used by this retry strategy.""" + + +@dataclass(kw_only=True) +class StandardRetryToken: + retry_count: int + """Retry count is the total number of attempts minus the initial attempt.""" + + retry_delay: float + """Delay in seconds to wait before the retry attempt.""" + + quota_consumed: int = 0 + """The total amount of quota consumed.""" + + last_quota_acquired: int = 0 + """The amount of last quota acquired.""" + + +class StandardRetryStrategy(retries_interface.RetryStrategy): + def __init__(self, *, max_attempts: int = 3): + """Standard retry strategy using truncated binary exponential backoff with full + jitter. + + :param max_attempts: Upper limit on total number of attempts made, including + initial attempt and retries. + """ + self.backoff_strategy = ExponentialRetryBackoffStrategy( + backoff_scale_value=1, + jitter_type=ExponentialBackoffJitterType.FULL, + ) + self.max_attempts = max_attempts + self._retry_quota = StandardRetryQuota() + + async def acquire_initial_retry_token( + self, *, token_scope: str | None = None + ) -> StandardRetryToken: + """Called before any retries (for the first attempt at the operation). + + :param token_scope: This argument is ignored by this retry strategy. + """ + retry_delay = self.backoff_strategy.compute_next_backoff_delay(0) + return StandardRetryToken(retry_count=0, retry_delay=retry_delay) + + async def refresh_retry_token_for_retry( + self, + *, + token_to_renew: StandardRetryToken, + error: Exception, + ) -> StandardRetryToken: + """Replace an existing retry token from a failed attempt with a new token. + + This retry strategy always returns a token until the attempt count stored in + the new token exceeds the ``max_attempts`` value. + + :param token_to_renew: The token used for the previous failed attempt. + :param error: The error that triggered the need for a retry. + :raises RetryError: If no further retry attempts are allowed. + """ + if isinstance(error, retries_interface.ErrorRetryInfo) and error.is_retry_safe: + retry_count = token_to_renew.retry_count + 1 + if retry_count >= self.max_attempts: + raise RetryError( + f"Reached maximum number of allowed attempts: {self.max_attempts}" + ) from error + + # Acquire additional quota for this retry attempt + # (may raise a RetryError if none is available) + quota_acquired = await self._retry_quota.acquire(error=error) + total_quota = token_to_renew.quota_consumed + quota_acquired + + if error.retry_after is not None: + retry_delay = error.retry_after + else: + retry_delay = self.backoff_strategy.compute_next_backoff_delay( + retry_count + ) + + return StandardRetryToken( + retry_count=retry_count, + retry_delay=retry_delay, + quota_consumed=total_quota, + last_quota_acquired=quota_acquired, + ) + else: + raise RetryError(f"Error is not retryable: {error}") from error + + async def record_success(self, *, token: StandardRetryToken) -> None: + """Return token after successful completion of an operation. + + Releases retry tokens back to the retry quota based on the previous amount + consumed. + + :param token: The token used for the previous successful attempt. + """ + await self._retry_quota.release(release_amount=token.last_quota_acquired) + + +class StandardRetryQuota: + """Retry quota used by :py:class:`StandardRetryStrategy`.""" + + INITIAL_RETRY_TOKENS = 500 + RETRY_COST = 5 + NO_RETRY_INCREMENT = 1 + TIMEOUT_RETRY_COST = 10 + + def __init__(self): + self._max_capacity = self.INITIAL_RETRY_TOKENS + self._available_capacity = self.INITIAL_RETRY_TOKENS + self._lock = asyncio.Lock() + + async def acquire(self, *, error: Exception) -> int: + """Attempt to acquire a certain amount of capacity. + + If there's no sufficient amount of capacity available, raise an exception. + Otherwise, we return the amount of capacity successfully allocated. + """ + # TODO: update `is_timeout` when `is_timeout_error` is implemented + is_timeout = False + capacity_amount = self.TIMEOUT_RETRY_COST if is_timeout else self.RETRY_COST + + async with self._lock: + if capacity_amount > self._available_capacity: + raise RetryError("Retry quota exceeded") + self._available_capacity -= capacity_amount + return capacity_amount + + async def release(self, *, release_amount: int) -> None: + """Release capacity back to the retry quota. + + The capacity being released will be truncated if necessary to ensure the max + capacity is never exceeded. + """ + increment = self.NO_RETRY_INCREMENT if release_amount == 0 else release_amount + + if self._available_capacity == self._max_capacity: + return + + async with self._lock: + self._available_capacity = min( + self._available_capacity + increment, self._max_capacity + ) + + +class RetryStrategyMode(Enum): + """Enumeration of available retry strategies.""" + + SIMPLE = "simple" + STANDARD = "standard" + + +RETRY_MODE_MAP = { + RetryStrategyMode.SIMPLE: SimpleRetryStrategy, + RetryStrategyMode.STANDARD: StandardRetryStrategy, +} diff --git a/packages/smithy-core/tests/unit/test_retries.py b/packages/smithy-core/tests/unit/test_retries.py index 0b3c23be4..48b3b9286 100644 --- a/packages/smithy-core/tests/unit/test_retries.py +++ b/packages/smithy-core/tests/unit/test_retries.py @@ -4,7 +4,12 @@ import pytest from smithy_core.exceptions import CallError, RetryError from smithy_core.retries import ExponentialBackoffJitterType as EBJT -from smithy_core.retries import ExponentialRetryBackoffStrategy, SimpleRetryStrategy +from smithy_core.retries import ( + ExponentialRetryBackoffStrategy, + SimpleRetryStrategy, + StandardRetryQuota, + StandardRetryStrategy, +) @pytest.mark.parametrize( @@ -54,49 +59,229 @@ def test_exponential_backoff_strategy( assert delay_actual == pytest.approx(delay_expected) # type: ignore +@pytest.mark.asyncio @pytest.mark.parametrize("max_attempts", [2, 3, 10]) -def test_simple_retry_strategy(max_attempts: int) -> None: +async def test_simple_retry_strategy(max_attempts: int) -> None: strategy = SimpleRetryStrategy( backoff_strategy=ExponentialRetryBackoffStrategy(backoff_scale_value=5), max_attempts=max_attempts, ) error = CallError(is_retry_safe=True) - token = strategy.acquire_initial_retry_token() + token = await strategy.acquire_initial_retry_token() for _ in range(max_attempts - 1): - token = strategy.refresh_retry_token_for_retry( + token = await strategy.refresh_retry_token_for_retry( token_to_renew=token, error=error ) with pytest.raises(RetryError): - strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) -def test_simple_retry_does_not_retry_unclassified() -> None: +@pytest.mark.asyncio +async def test_simple_retry_does_not_retry_unclassified() -> None: strategy = SimpleRetryStrategy( backoff_strategy=ExponentialRetryBackoffStrategy(backoff_scale_value=5), max_attempts=2, ) - token = strategy.acquire_initial_retry_token() + token = await strategy.acquire_initial_retry_token() with pytest.raises(RetryError): - strategy.refresh_retry_token_for_retry(token_to_renew=token, error=Exception()) + await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=Exception() + ) -def test_simple_retry_does_not_retry_when_safety_unknown() -> None: +@pytest.mark.asyncio +async def test_simple_retry_does_not_retry_when_safety_unknown() -> None: strategy = SimpleRetryStrategy( backoff_strategy=ExponentialRetryBackoffStrategy(backoff_scale_value=5), max_attempts=2, ) error = CallError(is_retry_safe=None) - token = strategy.acquire_initial_retry_token() + token = await strategy.acquire_initial_retry_token() with pytest.raises(RetryError): - strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) -def test_simple_retry_does_not_retry_unsafe() -> None: +@pytest.mark.asyncio +async def test_simple_retry_does_not_retry_unsafe() -> None: strategy = SimpleRetryStrategy( backoff_strategy=ExponentialRetryBackoffStrategy(backoff_scale_value=5), max_attempts=2, ) error = CallError(fault="client", is_retry_safe=False) - token = strategy.acquire_initial_retry_token() + token = await strategy.acquire_initial_retry_token() + with pytest.raises(RetryError): + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("max_attempts", [2, 3, 10]) +async def test_standard_retry_strategy(max_attempts: int) -> None: + strategy = StandardRetryStrategy(max_attempts=max_attempts) + error = CallError(is_retry_safe=True) + token = await strategy.acquire_initial_retry_token() + for _ in range(max_attempts - 1): + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + with pytest.raises(RetryError): + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + + +@pytest.mark.asyncio +async def test_standard_retry_does_not_retry_unclassified() -> None: + strategy = StandardRetryStrategy() + token = await strategy.acquire_initial_retry_token() with pytest.raises(RetryError): - strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=Exception() + ) + + +@pytest.mark.asyncio +async def test_standard_retry_does_not_retry_when_safety_unknown() -> None: + strategy = StandardRetryStrategy() + error = CallError(is_retry_safe=None) + token = await strategy.acquire_initial_retry_token() + with pytest.raises(RetryError): + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + + +@pytest.mark.asyncio +async def test_standard_retry_does_not_retry_unsafe() -> None: + strategy = StandardRetryStrategy() + error = CallError(fault="client", is_retry_safe=False) + token = await strategy.acquire_initial_retry_token() + with pytest.raises(RetryError): + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + + +@pytest.mark.asyncio +async def test_standard_retry_strategy_respects_max_attempts() -> None: + strategy = StandardRetryStrategy() + error = CallError(is_retry_safe=True) + token = await strategy.acquire_initial_retry_token() + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + with pytest.raises(RetryError): + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + + +@pytest.mark.asyncio +async def test_retry_after_overrides_backoff() -> None: + strategy = StandardRetryStrategy() + error = CallError(is_retry_safe=True, retry_after=5) + token = await strategy.acquire_initial_retry_token() + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + assert token.retry_delay == 5 + + +@pytest.mark.asyncio +async def test_retry_quota_acquire_when_exhausted(monkeypatch) -> None: + monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 5, raising=False) + monkeypatch.setattr(StandardRetryQuota, "RETRY_COST", 2, raising=False) + + quota = StandardRetryQuota() + assert quota._available_capacity == 5 + + # First acquire: 5 -> 3 + assert await quota.acquire(error=Exception()) == 2 + assert quota._available_capacity == 3 + + # Second acquire: 3 -> 1 + assert await quota.acquire(error=Exception()) == 2 + assert quota._available_capacity == 1 + + # Third acquire needs 2 but only 1 remains -> should raise + with pytest.raises(RetryError): + await quota.acquire(error=Exception()) + assert quota._available_capacity == 1 + + +@pytest.mark.asyncio +async def test_retry_quota_release_zero_adds_increment(monkeypatch) -> None: + monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 5, raising=False) + monkeypatch.setattr(StandardRetryQuota, "RETRY_COST", 2, raising=False) + monkeypatch.setattr(StandardRetryQuota, "NO_RETRY_INCREMENT", 1, raising=False) + + quota = StandardRetryQuota() + assert quota._available_capacity == 5 + + # First acquire: 5 -> 3 + assert await quota.acquire(error=Exception()) == 2 + assert quota._available_capacity == 3 + + # release 0 should add NO_RETRY_INCREMENT: 3 -> 4 + await quota.release(release_amount=0) + assert quota._available_capacity == 4 + + # Next acquire should still work: 4 -> 2 + assert await quota.acquire(error=Exception()) == 2 + assert quota._available_capacity == 2 + + +@pytest.mark.asyncio +async def test_retry_quota_release_caps_at_max(monkeypatch) -> None: + monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 10, raising=False) + monkeypatch.setattr(StandardRetryQuota, "RETRY_COST", 3, raising=False) + + quota = StandardRetryQuota() + assert quota._available_capacity == 10 + + # Drain some capacity: 10 -> 7 -> 4 + assert await quota.acquire(error=Exception()) == 3 + assert quota._available_capacity == 7 + assert await quota.acquire(error=Exception()) == 3 + assert quota._available_capacity == 4 + + # Release more than needed: 4 + 8 = 12. Should cap at max = 10 + await quota.release(release_amount=8) + assert quota._available_capacity == 10 + + # Another acquire should succeed from max: 10 -> 7 + assert await quota.acquire(error=Exception()) == 3 + assert quota._available_capacity == 7 + + +@pytest.mark.asyncio +async def test_retry_quota_releases_last_acquired_amount(monkeypatch) -> None: + monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 10, raising=False) + monkeypatch.setattr(StandardRetryQuota, "RETRY_COST", 5, raising=False) + + strategy = StandardRetryStrategy() + err = CallError(is_retry_safe=True) + token = await strategy.acquire_initial_retry_token() + + # Two retries: 10 -> 5 -> 0 + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=err + ) + assert strategy._retry_quota._available_capacity == 5 + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=err + ) + assert strategy._retry_quota._available_capacity == 0 + + # Success returns ONLY the last acquired amount -> 5 + await strategy.record_success(token=token) + assert strategy._retry_quota._available_capacity == 5 + + +@pytest.mark.asyncio +async def test_retry_quota_release_when_no_retry(monkeypatch) -> None: + monkeypatch.setattr(StandardRetryQuota, "INITIAL_RETRY_TOKENS", 10, raising=False) + quota = StandardRetryQuota() + + await quota.acquire(error=Exception()) + assert quota._available_capacity == 5 + before = quota._available_capacity + + await quota.release(release_amount=0) + # Should increment by NO_RETRY_INCREMENT = 1 + assert quota._available_capacity == min(before + 1, quota._max_capacity) + assert quota._available_capacity == 6 diff --git a/packages/smithy-http/src/smithy_http/aio/aiohttp.py b/packages/smithy-http/src/smithy_http/aio/aiohttp.py index 83f4c191f..d1935ada6 100644 --- a/packages/smithy-http/src/smithy_http/aio/aiohttp.py +++ b/packages/smithy-http/src/smithy_http/aio/aiohttp.py @@ -20,7 +20,7 @@ except ImportError: HAS_AIOHTTP = False # type: ignore -from smithy_core.aio.interfaces import StreamingBlob +from smithy_core.aio.interfaces import ErrorInfo, StreamingBlob from smithy_core.aio.types import AsyncBytesReader from smithy_core.aio.utils import async_list from smithy_core.exceptions import MissingDependencyError @@ -52,6 +52,14 @@ def __post_init__(self) -> None: class AIOHTTPClient(HTTPClient): """Implementation of :py:class:`.interfaces.HTTPClient` using aiohttp.""" + def get_error_info(self, exception: Exception, **kwargs) -> ErrorInfo: + """Get information about aiohttp errors.""" + + if isinstance(exception, TimeoutError): + return ErrorInfo(is_timeout_error=True) + + return ErrorInfo(is_timeout_error=False) + def __init__( self, *, diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py index 028161279..6c492d525 100644 --- a/packages/smithy-http/src/smithy_http/aio/crt.py +++ b/packages/smithy-http/src/smithy_http/aio/crt.py @@ -12,6 +12,8 @@ from io import BufferedIOBase, BytesIO from typing import TYPE_CHECKING, Any, cast +from awscrt.exceptions import AwsCrtError + if TYPE_CHECKING: # Both of these are types that essentially are "castable to bytes/memoryview" # Unfortunately they're not exposed anywhere so we have to import them from @@ -33,6 +35,7 @@ HAS_CRT = False # type: ignore from smithy_core import interfaces as core_interfaces +from smithy_core.aio.interfaces import ErrorInfo from smithy_core.aio.types import AsyncBytesReader from smithy_core.aio.utils import close from smithy_core.exceptions import MissingDependencyError @@ -205,6 +208,22 @@ class AWSCRTHTTPClient(http_aio_interfaces.HTTPClient): _HTTP_PORT = 80 _HTTPS_PORT = 443 + def get_error_info(self, exception: Exception, **kwargs) -> ErrorInfo: + """Get information about CRT errors.""" + + timeout_indicators = ( + "AWS_IO_SOCKET_TIMEOUT", + "AWS_IO_CHANNEL_ERROR_SOCKET_TIMEOUT", + "AWS_ERROR_HTTP_REQUEST_TIMEOUT", + ) + if isinstance(exception, TimeoutError): + return ErrorInfo(is_timeout_error=True, fault="client") + + if isinstance(exception, AwsCrtError) and exception.name in timeout_indicators: + return ErrorInfo(is_timeout_error=True, fault="client") + + return ErrorInfo(is_timeout_error=False) + def __init__( self, eventloop: _AWSCRTEventLoop | None = None, diff --git a/packages/smithy-http/src/smithy_http/aio/protocols.py b/packages/smithy-http/src/smithy_http/aio/protocols.py index cf25036fe..992f72d35 100644 --- a/packages/smithy-http/src/smithy_http/aio/protocols.py +++ b/packages/smithy-http/src/smithy_http/aio/protocols.py @@ -215,7 +215,6 @@ async def _create_error( ) return error_shape.deserialize(deserializer) - is_throttle = response.status == 429 message = ( f"Unknown error for operation {operation.schema.id} " f"- status: {response.status}" @@ -224,11 +223,22 @@ async def _create_error( message += f" - id: {error_id}" if response.reason is not None: message += f" - reason: {response.status}" + + if response.status == 408: + is_timeout = True + fault = "server" + else: + is_timeout = False + fault = "client" if response.status < 500 else "server" + + is_throttle = response.status == 429 + return CallError( message=message, - fault="client" if response.status < 500 else "server", + fault=fault, is_throttling_error=is_throttle, - is_retry_safe=is_throttle or None, + is_timeout_error=is_timeout, + is_retry_safe=is_throttle or is_timeout or None, ) def _matches_content_type(self, response: HTTPResponse) -> bool: diff --git a/packages/smithy-http/tests/unit/aio/test_protocols.py b/packages/smithy-http/tests/unit/aio/test_protocols.py index ecdb15cfa..665989865 100644 --- a/packages/smithy-http/tests/unit/aio/test_protocols.py +++ b/packages/smithy-http/tests/unit/aio/test_protocols.py @@ -2,23 +2,24 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any +from unittest.mock import Mock import pytest from smithy_core import URI from smithy_core.documents import TypeRegistry from smithy_core.endpoints import Endpoint -from smithy_core.interfaces import TypedProperties from smithy_core.interfaces import URI as URIInterface from smithy_core.schemas import APIOperation from smithy_core.shapes import ShapeID +from smithy_core.types import TypedProperties from smithy_http import Fields -from smithy_http.aio import HTTPRequest +from smithy_http.aio import HTTPRequest, HTTPResponse from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface -from smithy_http.aio.protocols import HttpClientProtocol +from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol -class TestProtocol(HttpClientProtocol): +class MockProtocol(HttpClientProtocol): _id = ShapeID("ns.foo#bar") @property @@ -125,7 +126,7 @@ def deserialize_response( def test_http_protocol_joins_uris( request_uri: URI, endpoint_uri: URI, expected: URI ) -> None: - protocol = TestProtocol() + protocol = MockProtocol() request = HTTPRequest( destination=request_uri, method="GET", @@ -135,3 +136,28 @@ def test_http_protocol_joins_uris( updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint) actual = updated_request.destination assert actual == expected + + +@pytest.mark.asyncio +async def test_http_408_creates_timeout_error() -> None: + """Test that HTTP 408 creates a timeout error with server fault.""" + protocol = Mock(spec=HttpBindingClientProtocol) + protocol.error_identifier = Mock() + protocol.error_identifier.identify.return_value = None + + response = HTTPResponse(status=408, fields=Fields()) + + error = await HttpBindingClientProtocol._create_error( + protocol, + operation=Mock(), + request=HTTPRequest( + destination=URI(host="example.com"), method="POST", fields=Fields() + ), + response=response, + response_body=b"", + error_registry=TypeRegistry({}), + context=TypedProperties(), + ) + + assert error.is_timeout_error is True + assert error.fault == "server" diff --git a/packages/smithy-http/tests/unit/aio/test_timeout_errors.py b/packages/smithy-http/tests/unit/aio/test_timeout_errors.py new file mode 100644 index 000000000..4bec2d2be --- /dev/null +++ b/packages/smithy-http/tests/unit/aio/test_timeout_errors.py @@ -0,0 +1,63 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from smithy_core.aio.interfaces import ErrorInfo + +try: + from smithy_http.aio.aiohttp import AIOHTTPClient + + HAS_AIOHTTP = True +except ImportError: + HAS_AIOHTTP = False + +try: + from smithy_http.aio.crt import AWSCRTHTTPClient + + HAS_CRT = True +except ImportError: + HAS_CRT = False + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not available") +class TestAIOHTTPTimeoutErrorHandling: + """Test timeout error handling for AIOHTTPClient.""" + + @pytest.fixture + async def client(self): + return AIOHTTPClient() + + @pytest.mark.asyncio + async def test_timeout_error_detection(self, client): + """Test timeout error detection for standard TimeoutError.""" + timeout_err = TimeoutError("Connection timed out") + result = client.get_error_info(timeout_err) + assert result == ErrorInfo(is_timeout_error=True, fault="client") + + @pytest.mark.asyncio + async def test_non_timeout_error_detection(self, client): + """Test non-timeout error detection.""" + other_err = ValueError("Not a timeout") + result = client.get_error_info(other_err) + assert result == ErrorInfo(is_timeout_error=False, fault="client") + + +@pytest.mark.skipif(not HAS_CRT, reason="AWS CRT not available") +class TestAWSCRTTimeoutErrorHandling: + """Test timeout error handling for AWSCRTHTTPClient.""" + + @pytest.fixture + def client(self): + return AWSCRTHTTPClient() + + def test_timeout_error_detection(self, client): + """Test timeout error detection for standard TimeoutError.""" + timeout_err = TimeoutError("Connection timed out") + result = client.get_error_info(timeout_err) + assert result == ErrorInfo(is_timeout_error=True, fault="client") + + def test_non_timeout_error_detection(self, client): + """Test non-timeout error detection.""" + other_err = ValueError("Not a timeout") + result = client.get_error_info(other_err) + assert result == ErrorInfo(is_timeout_error=False, fault="client")