"""A transport decorator that adds retry logic with exponential backoff."""

# NOTE: the package ``a2a/client/transports/__init__.py`` re-exports
# ``RetryTransport`` and ``default_retry_predicate`` from this module.

import asyncio
import inspect
import logging
import random

from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, TypeVar

import httpx

from a2a.client.client import ClientCallContext
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
from a2a.client.transports.base import ClientTransport
from a2a.types.a2a_pb2 import (
    AgentCard,
    CancelTaskRequest,
    DeleteTaskPushNotificationConfigRequest,
    GetExtendedAgentCardRequest,
    GetTaskPushNotificationConfigRequest,
    GetTaskRequest,
    ListTaskPushNotificationConfigsRequest,
    ListTaskPushNotificationConfigsResponse,
    ListTasksRequest,
    ListTasksResponse,
    SendMessageRequest,
    SendMessageResponse,
    StreamResponse,
    SubscribeToTaskRequest,
    Task,
    TaskPushNotificationConfig,
)


logger = logging.getLogger(__name__)

T = TypeVar('T')

# Decides whether a failed call should be attempted again.
RetryPredicate = Callable[[Exception], bool]
# Called as ``on_retry(attempt, error, delay)`` before each retry sleep; may be
# a plain function or a coroutine function.
OnRetryCallback = Callable[[int, Exception, float], Awaitable[None] | None]


def default_retry_predicate(error: Exception) -> bool:  # noqa: PLR0911
    """Determines if an error is retryable based on its type and cause.

    Retryable conditions:
    - A2AClientTimeoutError (always)
    - A2AClientError caused by httpx.RequestError (network errors)
    - A2AClientError caused by httpx.HTTPStatusError with status 429, 502, 503, 504
    - A2AClientError caused by grpc.aio.AioRpcError with UNAVAILABLE or RESOURCE_EXHAUSTED

    Non-retryable:
    - Domain-specific errors (TaskNotFoundError, etc.) — inherit A2AError, not A2AClientError
    - A2AClientError caused by json.JSONDecodeError or SSEError
    - A2AClientError with no recognized __cause__
    - Any non-A2AClientError exception

    Args:
        error: The exception raised by the wrapped transport call.

    Returns:
        True if the call should be retried, False otherwise.
    """
    if isinstance(error, A2AClientTimeoutError):
        return True

    if not isinstance(error, A2AClientError):
        return False

    # Only the explicit ``raise ... from cause`` link is inspected; an implicit
    # ``__context__`` is not considered.
    cause = error.__cause__
    if cause is None:
        return False

    if isinstance(cause, httpx.RequestError):
        return True
    if isinstance(cause, httpx.HTTPStatusError):
        return cause.response.status_code in {429, 502, 503, 504}

    # grpc is an optional dependency, so it is imported lazily and its absence
    # simply means "not a gRPC error".
    try:
        import grpc  # noqa: PLC0415

        if isinstance(cause, grpc.aio.AioRpcError):
            return cause.code() in {
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.RESOURCE_EXHAUSTED,
            }
    except ImportError:
        pass

    return False


class RetryTransport(ClientTransport):
    """A transport decorator that adds retry logic with exponential backoff.

    Wraps any ClientTransport and retries failed operations that match
    the retry predicate. Streaming methods (send_message_streaming,
    subscribe) only retry pre-stream failures; once the first event
    is yielded, errors propagate without retry.

    Every operation except ``close`` goes through the retry machinery, and
    ``close`` is delegated to the wrapped transport as-is.

    NOTE(review): ``send_message`` is retried on timeouts like every other
    call; confirm the remote agent tolerates a duplicated request.
    """

    def __init__(  # noqa: PLR0913
        self,
        base: ClientTransport,
        *,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 30.0,
        jitter: bool = True,
        retry_predicate: RetryPredicate | None = None,
        on_retry: OnRetryCallback | None = None,
    ) -> None:
        """Initializes the retrying wrapper.

        Args:
            base: The transport whose calls are retried.
            max_retries: Maximum number of retries after the first attempt;
                0 disables retrying.
            base_delay: Delay in seconds before the first retry; doubled for
                each subsequent retry.
            max_delay: Upper bound in seconds for any single delay.
            jitter: If True, each delay is drawn uniformly from
                ``[0, capped_delay]`` instead of using the capped delay itself.
            retry_predicate: Decides whether an error is retryable. Defaults
                to ``default_retry_predicate``.
            on_retry: Optional sync or async callback invoked as
                ``on_retry(attempt, error, delay)`` before each retry sleep,
                where ``attempt`` is 1-indexed.

        Raises:
            ValueError: If ``max_retries`` is negative or either delay is not
                strictly positive.
        """
        if max_retries < 0:
            raise ValueError('max_retries must be >= 0')
        if base_delay <= 0:
            raise ValueError('base_delay must be > 0')
        if max_delay <= 0:
            raise ValueError('max_delay must be > 0')
        self._base = base
        self._max_retries = max_retries
        self._base_delay = base_delay
        self._max_delay = max_delay
        self._jitter = jitter
        # ``is not None`` rather than ``or`` so that a callable object which
        # happens to be falsy is still honoured.
        self._retry_predicate = (
            retry_predicate
            if retry_predicate is not None
            else default_retry_predicate
        )
        self._on_retry = on_retry

    def _calculate_delay(self, attempt: int) -> float:
        """Calculates the delay for a given retry attempt using exponential backoff.

        Args:
            attempt: The retry attempt number (1-indexed).

        Returns:
            The delay in seconds before the next retry.
        """
        # Float exponentiation keeps the result identical for realistic
        # attempt numbers, while an integer power would raise OverflowError
        # ("int too large to convert to float") once attempt - 1 >= 1024.
        try:
            backoff = self._base_delay * 2.0 ** (attempt - 1)
        except OverflowError:
            backoff = self._max_delay
        delay = min(backoff, self._max_delay)
        if self._jitter:
            delay = random.uniform(0, delay)  # noqa: S311
        return delay

    def _should_retry(self, retries_done: int, error: Exception) -> bool:
        """Returns True if budget remains and the predicate accepts the error.

        The predicate is not consulted once the retry budget is exhausted.
        """
        return retries_done < self._max_retries and self._retry_predicate(
            error
        )

    async def _backoff(
        self, attempt: int, error: Exception, method_name: str
    ) -> None:
        """Logs the retry, notifies ``on_retry`` and sleeps for the backoff delay.

        Args:
            attempt: The retry attempt number (1-indexed).
            error: The exception that triggered the retry.
            method_name: Name of the method being called, used for logging.
        """
        delay = self._calculate_delay(attempt)
        logger.warning(
            'Retry %d/%d for %s after %.2fs: %s',
            attempt,
            self._max_retries,
            method_name,
            delay,
            error,
        )
        if self._on_retry is not None:
            # The callback may be sync or async; only await when needed.
            result: Any = self._on_retry(attempt, error, delay)
            if inspect.isawaitable(result):
                await result
        await asyncio.sleep(delay)

    @staticmethod
    async def _close_stream(stream: object) -> None:
        """Closes ``stream`` if it exposes an awaitable ``aclose()``.

        Tolerates ``None`` and objects without ``aclose`` so it can be called
        unconditionally from a ``finally`` clause.
        """
        closer = getattr(stream, 'aclose', None)
        if closer is None:
            return
        outcome = closer()
        if inspect.isawaitable(outcome):
            await outcome

    async def _execute_with_retry(
        self,
        operation: Callable[[], Awaitable[T]],
        method_name: str,
    ) -> T:
        """Executes an async operation with retry logic.

        Args:
            operation: A zero-argument async callable that performs the transport call.
            method_name: Name of the method being called, used for logging.

        Returns:
            The result of the operation.

        Raises:
            Exception: The error of the last attempt once retries are
                exhausted, or the first error the predicate rejects.
        """
        retries_done = 0
        while True:
            try:
                return await operation()
            except Exception as e:  # noqa: PERF203
                if not self._should_retry(retries_done, e):
                    raise
                retries_done += 1
                await self._backoff(retries_done, e, method_name)

    async def _execute_streaming_with_retry(
        self,
        operation: Callable[[], AsyncGenerator[StreamResponse]],
        method_name: str,
    ) -> AsyncGenerator[StreamResponse]:
        """Executes a streaming operation with retry logic for pre-stream failures.

        Retries only apply before the first event is yielded. Once streaming
        has started, errors propagate to the caller without retry.

        Args:
            operation: A zero-argument callable returning an async generator.
            method_name: Name of the method being called, used for logging.

        Yields:
            StreamResponse events from the underlying transport.
        """
        retries_done = 0
        while True:
            started = False
            stream: AsyncGenerator[StreamResponse] | None = None
            try:
                stream = operation()
                async for event in stream:
                    started = True
                    yield event
            except Exception as e:
                # Events already delivered cannot be replayed, so a
                # mid-stream failure is never retried.
                if started or not self._should_retry(retries_done, e):
                    raise
                retries_done += 1
                await self._backoff(retries_done, e, method_name)
            else:
                return
            finally:
                # Release the underlying stream deterministically, including
                # when the consumer stops iterating early (GeneratorExit).
                await self._close_stream(stream)

    async def send_message(
        self,
        request: SendMessageRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> SendMessageResponse:
        """Sends a non-streaming message request to the agent."""
        return await self._execute_with_retry(
            lambda: self._base.send_message(request, context=context),
            'send_message',
        )

    async def send_message_streaming(
        self,
        request: SendMessageRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AsyncGenerator[StreamResponse]:
        """Sends a streaming message request to the agent and yields responses."""
        stream = self._execute_streaming_with_retry(
            lambda: self._base.send_message_streaming(request, context=context),
            'send_message_streaming',
        )
        try:
            async for event in stream:
                yield event
        finally:
            # Propagate early termination down to the wrapped transport.
            await stream.aclose()

    async def get_task(
        self,
        request: GetTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> Task:
        """Retrieves the current state and history of a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.get_task(request, context=context),
            'get_task',
        )

    async def list_tasks(
        self,
        request: ListTasksRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> ListTasksResponse:
        """Retrieves tasks for an agent."""
        return await self._execute_with_retry(
            lambda: self._base.list_tasks(request, context=context),
            'list_tasks',
        )

    async def cancel_task(
        self,
        request: CancelTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> Task:
        """Requests the agent to cancel a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.cancel_task(request, context=context),
            'cancel_task',
        )

    async def create_task_push_notification_config(
        self,
        request: TaskPushNotificationConfig,
        *,
        context: ClientCallContext | None = None,
    ) -> TaskPushNotificationConfig:
        """Sets or updates the push notification configuration for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.create_task_push_notification_config(
                request, context=context
            ),
            'create_task_push_notification_config',
        )

    async def get_task_push_notification_config(
        self,
        request: GetTaskPushNotificationConfigRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> TaskPushNotificationConfig:
        """Retrieves the push notification configuration for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.get_task_push_notification_config(
                request, context=context
            ),
            'get_task_push_notification_config',
        )

    async def list_task_push_notification_configs(
        self,
        request: ListTaskPushNotificationConfigsRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> ListTaskPushNotificationConfigsResponse:
        """Lists push notification configurations for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.list_task_push_notification_configs(
                request, context=context
            ),
            'list_task_push_notification_configs',
        )

    async def delete_task_push_notification_config(
        self,
        request: DeleteTaskPushNotificationConfigRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> None:
        """Deletes the push notification configuration for a specific task."""
        await self._execute_with_retry(
            lambda: self._base.delete_task_push_notification_config(
                request, context=context
            ),
            'delete_task_push_notification_config',
        )

    async def subscribe(
        self,
        request: SubscribeToTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AsyncGenerator[StreamResponse]:
        """Reconnects to get task updates."""
        stream = self._execute_streaming_with_retry(
            lambda: self._base.subscribe(request, context=context),
            'subscribe',
        )
        try:
            async for event in stream:
                yield event
        finally:
            # Propagate early termination down to the wrapped transport.
            await stream.aclose()

    async def get_extended_agent_card(
        self,
        request: GetExtendedAgentCardRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AgentCard:
        """Retrieves the Extended AgentCard."""
        return await self._execute_with_retry(
            lambda: self._base.get_extended_agent_card(
                request, context=context
            ),
            'get_extended_agent_card',
        )

    async def close(self) -> None:
        """Closes the transport."""
        await self._base.close()
GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + Message, + Part, + SendMessageRequest, + SendMessageResponse, + StreamResponse, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, +) +from a2a.utils.constants import ( + PROTOCOL_VERSION_CURRENT, + VERSION_HEADER, + TransportProtocol, +) +from a2a.utils.errors import InternalError, TaskNotFoundError + + +@pytest.fixture +def mock_transport() -> AsyncMock: + return AsyncMock(spec=ClientTransport) + + +@pytest.fixture +def retry_transport(mock_transport: AsyncMock) -> RetryTransport: + return RetryTransport( + mock_transport, + max_retries=3, + base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + +class TestDefaultRetryPredicate: + def test_timeout_error_is_retryable(self) -> None: + error = A2AClientTimeoutError('timeout') + assert default_retry_predicate(error) is True + + def test_network_error_is_retryable(self) -> None: + cause = httpx.ConnectError('connection refused') + error = A2AClientError( + 'Network communication error: connection refused' + ) + error.__cause__ = cause + assert default_retry_predicate(error) is True + + @pytest.mark.parametrize('status_code', [429, 502, 503, 504]) + def test_retryable_http_status_codes(self, status_code: int) -> None: + request = httpx.Request('POST', 'http://example.com') + response = httpx.Response(status_code, request=request) + cause = httpx.HTTPStatusError( + 'error', request=request, response=response + ) + error = A2AClientError(f'HTTP Error {status_code}') + error.__cause__ = cause + assert default_retry_predicate(error) is True + + @pytest.mark.parametrize('status_code', [400, 401, 403, 404, 500]) + def test_non_retryable_http_status_codes(self, status_code: int) -> None: + request = httpx.Request('POST', 'http://example.com') + response = httpx.Response(status_code, request=request) + cause = httpx.HTTPStatusError( + 'error', request=request, response=response + ) + error = 
A2AClientError(f'HTTP Error {status_code}') + error.__cause__ = cause + assert default_retry_predicate(error) is False + + def test_json_decode_error_is_not_retryable(self) -> None: + cause = json.JSONDecodeError('msg', 'doc', 0) + error = A2AClientError('JSON Decode Error') + error.__cause__ = cause + assert default_retry_predicate(error) is False + + def test_domain_error_is_not_retryable(self) -> None: + error = TaskNotFoundError() + assert default_retry_predicate(error) is False + + def test_internal_error_is_not_retryable(self) -> None: + error = InternalError() + assert default_retry_predicate(error) is False + + def test_client_error_without_cause_is_not_retryable(self) -> None: + error = A2AClientError('some error') + assert default_retry_predicate(error) is False + + def test_non_a2a_error_is_not_retryable(self) -> None: + error = ValueError('not an A2A error') + assert default_retry_predicate(error) is False + + @pytest.mark.parametrize( + 'status_code, expected', + [ + ('UNAVAILABLE', True), + ('RESOURCE_EXHAUSTED', True), + ('NOT_FOUND', False), + ], + ) + def test_grpc_error_retriability( + self, status_code: str, expected: bool + ) -> None: + grpc = pytest.importorskip('grpc') + + class FakeAioRpcError(grpc.aio.AioRpcError, Exception): + def __init__(self, code: object) -> None: + self._code = code + + def code(self) -> object: + return self._code + + cause = FakeAioRpcError(getattr(grpc.StatusCode, status_code)) + error = A2AClientError(f'gRPC Error {status_code}') + error.__cause__ = cause + assert default_retry_predicate(error) is expected + + +class TestRetryTransport: + @pytest.mark.parametrize( + 'method_name, request_obj', + [ + ( + 'send_message', + SendMessageRequest(message=Message(parts=[Part(text='hello')])), + ), + ('get_task', GetTaskRequest(id='t1')), + ('list_tasks', ListTasksRequest()), + ('cancel_task', CancelTaskRequest(id='t1')), + ( + 'create_task_push_notification_config', + TaskPushNotificationConfig(task_id='t1'), + ), + ( + 
'get_task_push_notification_config', + GetTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + ), + ( + 'list_task_push_notification_configs', + ListTaskPushNotificationConfigsRequest(task_id='t1'), + ), + ( + 'delete_task_push_notification_config', + DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + ), + ('get_extended_agent_card', GetExtendedAgentCardRequest()), + ], + ) + @pytest.mark.asyncio + async def test_delegates_to_base_transport( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + method_name: str, + request_obj: object, + ) -> None: + await getattr(retry_transport, method_name)(request_obj) + getattr(mock_transport, method_name).assert_called_once_with( + request_obj, context=None + ) + + @pytest.mark.asyncio + async def test_retries_on_network_error( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + cause = httpx.ConnectError('refused') + error = A2AClientError('Network communication error: refused') + error.__cause__ = cause + + expected = Task() + mock_transport.get_task.side_effect = [error, expected] + result = await retry_transport.get_task(GetTaskRequest(id='t1')) + assert result == expected + assert mock_transport.get_task.call_count == 2 + + @pytest.mark.asyncio + async def test_no_retry_on_domain_error( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + mock_transport.get_task.side_effect = TaskNotFoundError() + with pytest.raises(TaskNotFoundError): + await retry_transport.get_task(GetTaskRequest(id='t1')) + assert mock_transport.get_task.call_count == 1 + + @pytest.mark.asyncio + async def test_no_retry_on_non_retryable_http_status( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + request = httpx.Request('POST', 'http://example.com') + response = httpx.Response(400, request=request) + cause = httpx.HTTPStatusError( + 'bad request', request=request, response=response + ) + error = 
A2AClientError('HTTP Error 400: bad request') + error.__cause__ = cause + + mock_transport.send_message.side_effect = error + with pytest.raises(A2AClientError): + await retry_transport.send_message(SendMessageRequest()) + assert mock_transport.send_message.call_count == 1 + + @pytest.mark.asyncio + async def test_exponential_backoff_timing( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=1.0, + max_delay=30.0, + jitter=False, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + with patch( + 'a2a.client.transports.retry.asyncio.sleep', + new_callable=AsyncMock, + ) as mock_sleep: + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + + assert mock_sleep.call_count == 3 + mock_sleep.assert_any_call(1.0) + mock_sleep.assert_any_call(2.0) + mock_sleep.assert_any_call(4.0) + + @pytest.mark.asyncio + async def test_max_delay_cap(self, mock_transport: AsyncMock) -> None: + transport = RetryTransport( + mock_transport, + max_retries=5, + base_delay=10.0, + max_delay=20.0, + jitter=False, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + with patch( + 'a2a.client.transports.retry.asyncio.sleep', + new_callable=AsyncMock, + ) as mock_sleep: + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + + for call_args in mock_sleep.call_args_list: + assert call_args[0][0] <= 20.0 + + @pytest.mark.asyncio + async def test_jitter_produces_randomized_delays( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=1.0, + max_delay=30.0, + jitter=True, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + with patch( + 'a2a.client.transports.retry.asyncio.sleep', + new_callable=AsyncMock, + ) as mock_sleep: + with 
pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + + for i, call_args in enumerate(mock_sleep.call_args_list): + delay = call_args[0][0] + max_possible = min(1.0 * (2**i), 30.0) + assert 0 <= delay <= max_possible + + @pytest.mark.asyncio + async def test_streaming_retries_pre_stream_failure( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + async def success_stream(*args: object, **kwargs: object) -> object: + yield StreamResponse() + yield StreamResponse() + + mock_transport.send_message_streaming.side_effect = [ + A2AClientTimeoutError('timeout'), + success_stream(), + ] + events = [ + event + async for event in retry_transport.send_message_streaming( + SendMessageRequest() + ) + ] + + assert len(events) == 2 + assert mock_transport.send_message_streaming.call_count == 2 + + @pytest.mark.asyncio + async def test_streaming_no_retry_mid_stream( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + async def failing_mid_stream(*args: object, **kwargs: object) -> object: + yield StreamResponse() + raise A2AClientTimeoutError('mid-stream timeout') + + mock_transport.send_message_streaming.return_value = ( + failing_mid_stream() + ) + + events: list[StreamResponse] = [] + with pytest.raises(A2AClientTimeoutError): + async for event in retry_transport.send_message_streaming( + SendMessageRequest() + ): + events.append(event) # noqa: PERF401 + + assert len(events) == 1 + assert mock_transport.send_message_streaming.call_count == 1 + + @pytest.mark.asyncio + async def test_subscribe_streaming_retries( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + async def success_stream(*args: object, **kwargs: object) -> object: + yield StreamResponse() + + mock_transport.subscribe.side_effect = [ + A2AClientTimeoutError('timeout'), + success_stream(), + ] + events = [ + event + async for event in retry_transport.subscribe( + 
SubscribeToTaskRequest(id='t1') + ) + ] + + assert len(events) == 1 + assert mock_transport.subscribe.call_count == 2 + + @pytest.mark.asyncio + async def test_streaming_max_retries_exhausted( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + mock_transport.send_message_streaming.side_effect = ( + A2AClientTimeoutError('timeout') + ) + with pytest.raises(A2AClientTimeoutError): + async for _ in retry_transport.send_message_streaming( + SendMessageRequest() + ): + pass + assert mock_transport.send_message_streaming.call_count == 4 + + @pytest.mark.asyncio + async def test_custom_retry_predicate( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + retry_predicate=lambda e: isinstance(e, TaskNotFoundError), + ) + expected = Task() + mock_transport.get_task.side_effect = [ + TaskNotFoundError(), + expected, + ] + result = await transport.get_task(GetTaskRequest(id='t1')) + assert result == expected + assert mock_transport.get_task.call_count == 2 + + @pytest.mark.asyncio + async def test_custom_predicate_rejects_normally_retryable( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=0.01, + retry_predicate=lambda e: False, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + assert mock_transport.send_message.call_count == 1 + + @pytest.mark.asyncio + async def test_on_retry_async_callback( + self, mock_transport: AsyncMock + ) -> None: + on_retry_mock = AsyncMock() + transport = RetryTransport( + mock_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + on_retry=on_retry_mock, + ) + error = A2AClientTimeoutError('timeout') + expected = SendMessageResponse() + mock_transport.send_message.side_effect = [error, expected] + + 
await transport.send_message(SendMessageRequest()) + + on_retry_mock.assert_called_once_with(1, error, 0.01) + + @pytest.mark.asyncio + async def test_on_retry_sync_callback( + self, mock_transport: AsyncMock + ) -> None: + calls: list[tuple[int, Exception, float]] = [] + + def sync_on_retry(attempt: int, error: Exception, delay: float) -> None: + calls.append((attempt, error, delay)) + + transport = RetryTransport( + mock_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + on_retry=sync_on_retry, + ) + error = A2AClientTimeoutError('timeout') + expected = SendMessageResponse() + mock_transport.send_message.side_effect = [error, expected] + + await transport.send_message(SendMessageRequest()) + + assert len(calls) == 1 + assert calls[0][0] == 1 + + @pytest.mark.asyncio + async def test_close_delegates_without_retry( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + await retry_transport.close() + mock_transport.close.assert_called_once() + + @pytest.mark.asyncio + async def test_context_passed_through( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + context = ClientCallContext(timeout=5.0) + request = SendMessageRequest( + message=Message(parts=[Part(text='hello')]) + ) + await retry_transport.send_message(request, context=context) + mock_transport.send_message.assert_called_once_with( + request, context=context + ) + + @pytest.mark.asyncio + async def test_streaming_delegates( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + async def mock_stream(*args: object, **kwargs: object) -> object: + yield StreamResponse() + + mock_transport.send_message_streaming.return_value = mock_stream() + request = SendMessageRequest() + events = [ + event + async for event in retry_transport.send_message_streaming(request) + ] + + assert len(events) == 1 + mock_transport.send_message_streaming.assert_called_once_with( + request, context=None + ) + + 
@pytest.mark.asyncio + async def test_end_to_end_retry_within_context_manager( + self, mock_transport: AsyncMock + ) -> None: + expected = SendMessageResponse() + mock_transport.send_message.side_effect = [ + A2AClientTimeoutError('timeout'), + expected, + ] + + async with RetryTransport( + mock_transport, max_retries=2, base_delay=0.01, jitter=False + ) as t: + assert t is not mock_transport + result = await t.send_message( + SendMessageRequest(message=Message(parts=[Part(text='hello')])) + ) + assert result == expected + assert mock_transport.send_message.call_count == 2 + mock_transport.close.assert_not_awaited() + + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_end_to_end_retry_exhaustion_within_context_manager( + self, mock_transport: AsyncMock + ) -> None: + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + with pytest.raises(A2AClientTimeoutError): + async with RetryTransport( + mock_transport, max_retries=2, base_delay=0.01, jitter=False + ) as t: + await t.send_message(SendMessageRequest()) + + assert mock_transport.send_message.call_count == 3 + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_max_retries_zero_disables_retry( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=0, + base_delay=0.01, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + assert mock_transport.send_message.call_count == 1 + + def test_invalid_max_retries_raises_value_error( + self, mock_transport: AsyncMock + ) -> None: + with pytest.raises(ValueError, match='max_retries must be >= 0'): + RetryTransport(mock_transport, max_retries=-1) + + def test_invalid_base_delay_raises_value_error( + self, mock_transport: AsyncMock + ) -> None: + with pytest.raises(ValueError, match='base_delay 
must be > 0'): + RetryTransport(mock_transport, base_delay=0) + + def test_invalid_max_delay_raises_value_error( + self, mock_transport: AsyncMock + ) -> None: + with pytest.raises(ValueError, match='max_delay must be > 0'): + RetryTransport(mock_transport, max_delay=-1) + + +class TestRetryTransportIntegration: + """E2E tests: RetryTransport wrapping real transports against real servers.""" + + @pytest.fixture + def mock_request_handler(self) -> AsyncMock: + handler = AsyncMock(spec=RequestHandler) + handler.on_get_task.return_value = Task( + id='task-retry-test', + context_id='ctx-retry-test', + ) + return handler + + @pytest.fixture + def agent_card(self) -> AgentCard: + return AgentCard( + name='Retry Test Agent', + description='Agent for retry integration tests.', + version='1.0.0', + capabilities=AgentCapabilities(streaming=False), + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + supported_interfaces=[ + AgentInterface( + protocol_binding=TransportProtocol.HTTP_JSON, + url='http://testserver', + ), + AgentInterface( + protocol_binding=TransportProtocol.JSONRPC, + url='http://testserver', + ), + ], + ) + + @pytest.mark.asyncio + async def test_retry_with_rest_transport_recovers_from_503( + self, + mock_request_handler: AsyncMock, + agent_card: AgentCard, + ) -> None: + """RetryTransport + real RestTransport + real Starlette server with transient 503s.""" + rest_routes = create_rest_routes( + agent_card, mock_request_handler, extended_agent_card=agent_card + ) + app = Starlette(routes=[*rest_routes]) + + # Wrap app with middleware that returns 503 for first 2 requests + failure_count = 0 + fail_limit = 2 + + async def transient_failure_app(scope, receive, send): + nonlocal failure_count + if scope['type'] == 'http' and failure_count < fail_limit: + failure_count += 1 + await send( + { + 'type': 'http.response.start', + 'status': 503, + 'headers': [ + [b'content-type', b'text/plain'], + ], + } + ) + await send( + { + 
'type': 'http.response.body', + 'body': b'Service Unavailable', + } + ) + return + await app(scope, receive, send) + + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=transient_failure_app), + headers={VERSION_HEADER: PROTOCOL_VERSION_CURRENT}, + ) + inner_transport = RestTransport( + httpx_client, agent_card, 'http://testserver' + ) + retry_transport = RetryTransport( + inner_transport, + max_retries=3, + base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + async with retry_transport: + result = await retry_transport.get_task( + GetTaskRequest(id='task-retry-test') + ) + + assert result.id == 'task-retry-test' + assert failure_count == fail_limit + mock_request_handler.on_get_task.assert_called_once() + + @pytest.mark.asyncio + async def test_retry_with_jsonrpc_transport_recovers_from_503( + self, + mock_request_handler: AsyncMock, + agent_card: AgentCard, + ) -> None: + """RetryTransport + real JsonRpcTransport + real Starlette server with transient 503s.""" + jsonrpc_routes = create_jsonrpc_routes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='/', + ) + app = Starlette(routes=[*jsonrpc_routes]) + + failure_count = 0 + fail_limit = 2 + + async def transient_failure_app(scope, receive, send): + nonlocal failure_count + if scope['type'] == 'http' and failure_count < fail_limit: + failure_count += 1 + await send( + { + 'type': 'http.response.start', + 'status': 503, + 'headers': [ + [b'content-type', b'text/plain'], + ], + } + ) + await send( + { + 'type': 'http.response.body', + 'body': b'Service Unavailable', + } + ) + return + await app(scope, receive, send) + + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=transient_failure_app), + headers={VERSION_HEADER: PROTOCOL_VERSION_CURRENT}, + ) + inner_transport = JsonRpcTransport( + httpx_client, agent_card, 'http://testserver' + ) + retry_transport = RetryTransport( + inner_transport, + max_retries=3, + 
base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + async with retry_transport: + result = await retry_transport.get_task( + GetTaskRequest(id='task-retry-test') + ) + + assert result.id == 'task-retry-test' + assert failure_count == fail_limit + mock_request_handler.on_get_task.assert_called_once() + + @pytest.mark.asyncio + async def test_retry_exhaustion_with_persistent_503( + self, + mock_request_handler: AsyncMock, + agent_card: AgentCard, + ) -> None: + """Verify that retries are exhausted when 503 persists beyond max_retries.""" + rest_routes = create_rest_routes( + agent_card, mock_request_handler, extended_agent_card=agent_card + ) + app = Starlette(routes=[*rest_routes]) + + # Always return 503 + async def always_fail_app(scope, receive, send): + if scope['type'] == 'http': + await send( + { + 'type': 'http.response.start', + 'status': 503, + 'headers': [ + [b'content-type', b'text/plain'], + ], + } + ) + await send( + { + 'type': 'http.response.body', + 'body': b'Service Unavailable', + } + ) + return + await app(scope, receive, send) + + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=always_fail_app) + ) + inner_transport = RestTransport( + httpx_client, agent_card, 'http://testserver' + ) + retry_transport = RetryTransport( + inner_transport, + max_retries=2, + base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + async with retry_transport: + with pytest.raises(A2AClientError, match='HTTP Error 503'): + await retry_transport.get_task( + GetTaskRequest(id='task-retry-test') + ) + + mock_request_handler.on_get_task.assert_not_called() diff --git a/tests/integration/test_retry_integration.py b/tests/integration/test_retry_integration.py new file mode 100644 index 00000000..4f18995f --- /dev/null +++ b/tests/integration/test_retry_integration.py @@ -0,0 +1,265 @@ +"""Integration tests for RetryTransport through the full client stack. 
+ +Tests RetryTransport composed with ClientFactory/BaseClient against +real Starlette servers, validating the full client -> retry -> transport -> server +path that end users would follow. +""" + +from unittest.mock import ANY, AsyncMock + +import httpx +import pytest + +from starlette.applications import Starlette + +from a2a.client.base_client import BaseClient +from a2a.client.client import ClientConfig +from a2a.client.client_factory import ClientFactory +from a2a.client.transports.retry import RetryTransport +from a2a.server.request_handlers import RequestHandler +from a2a.server.routes import create_jsonrpc_routes, create_rest_routes +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, + GetTaskRequest, + Message, + Part, + Role, + SendMessageRequest, + Task, + TaskState, + TaskStatus, +) +from a2a.utils.constants import TransportProtocol + + +# --- Test Constants --- + +TASK_RESPONSE = Task( + id='task-retry-integration', + context_id='ctx-retry-integration', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), +) + + +# --- Helpers --- + + +def _wrap_with_transient_503(app, fail_count: int = 2): + """Wraps an ASGI app to return 503 for the first N requests. + + Returns a tuple of (middleware_app, state_dict) where state_dict['count'] + tracks how many requests were intercepted. 
+ """ + state = {'count': 0} + + async def middleware(scope, receive, send): + if scope['type'] == 'http' and state['count'] < fail_count: + state['count'] += 1 + await send( + { + 'type': 'http.response.start', + 'status': 503, + 'headers': [[b'content-type', b'text/plain']], + } + ) + await send( + { + 'type': 'http.response.body', + 'body': b'Service Unavailable', + } + ) + return + await app(scope, receive, send) + + return middleware, state + + +# --- Test Fixtures --- + + +@pytest.fixture +def mock_request_handler() -> AsyncMock: + """Provides a mock RequestHandler with retry-relevant responses.""" + handler = AsyncMock(spec=RequestHandler) + + # Configure responses for retry tests + handler.on_get_task.return_value = TASK_RESPONSE + handler.on_message_send.return_value = TASK_RESPONSE + + return handler + + +@pytest.fixture +def agent_card() -> AgentCard: + """Provides a sample AgentCard for retry integration tests.""" + return AgentCard( + name='Retry Integration Agent', + description='Agent for retry integration testing.', + version='1.0.0', + capabilities=AgentCapabilities(streaming=False), + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + supported_interfaces=[ + AgentInterface( + protocol_binding=TransportProtocol.HTTP_JSON, + url='http://testserver', + ), + AgentInterface( + protocol_binding=TransportProtocol.JSONRPC, + url='http://testserver', + ), + ], + ) + + +# --- The Integration Tests --- + + +@pytest.mark.asyncio +async def test_retry_with_client_factory_rest( + mock_request_handler: AsyncMock, + agent_card: AgentCard, +) -> None: + """Full stack: ClientFactory -> BaseClient -> RetryTransport -> RestTransport -> server.""" + rest_routes = create_rest_routes( + agent_card, mock_request_handler, extended_agent_card=agent_card + ) + app = Starlette(routes=[*rest_routes]) + failing_app, state = _wrap_with_transient_503(app, fail_count=2) + + httpx_client = httpx.AsyncClient( + 
transport=httpx.ASGITransport(app=failing_app), + ) + + factory = ClientFactory( + config=ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ) + client = factory.create(agent_card) + + # Wrap the transport with RetryTransport + assert isinstance(client, BaseClient) + original_transport = client._transport + client._transport = RetryTransport( + original_transport, + max_retries=3, + base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + params = GetTaskRequest(id=TASK_RESPONSE.id) + result = await client.get_task(request=params) + + assert result.id == TASK_RESPONSE.id + assert state['count'] == 2 + mock_request_handler.on_get_task.assert_awaited_once_with(params, ANY) + + await client.close() + + +@pytest.mark.asyncio +async def test_retry_with_client_factory_jsonrpc( + mock_request_handler: AsyncMock, + agent_card: AgentCard, +) -> None: + """Full stack: ClientFactory -> BaseClient -> RetryTransport -> JsonRpcTransport -> server.""" + jsonrpc_routes = create_jsonrpc_routes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='/', + ) + app = Starlette(routes=[*jsonrpc_routes]) + failing_app, state = _wrap_with_transient_503(app, fail_count=2) + + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=failing_app), + ) + + factory = ClientFactory( + config=ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.JSONRPC], + ) + ) + client = factory.create(agent_card) + + # Wrap the transport with RetryTransport + assert isinstance(client, BaseClient) + original_transport = client._transport + client._transport = RetryTransport( + original_transport, + max_retries=3, + base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + params = GetTaskRequest(id=TASK_RESPONSE.id) + result = await client.get_task(request=params) + + assert result.id == TASK_RESPONSE.id + assert state['count'] == 2 + 
mock_request_handler.on_get_task.assert_awaited_once_with(params, ANY) + + await client.close() + + +@pytest.mark.asyncio +async def test_retry_send_message_blocking( + mock_request_handler: AsyncMock, + agent_card: AgentCard, +) -> None: + """Full stack: send_message through RetryTransport with transient failures.""" + rest_routes = create_rest_routes( + agent_card, mock_request_handler, extended_agent_card=agent_card + ) + app = Starlette(routes=[*rest_routes]) + failing_app, state = _wrap_with_transient_503(app, fail_count=1) + + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=failing_app), + ) + + factory = ClientFactory( + config=ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ) + client = factory.create(agent_card) + + # Disable streaming to force blocking call + assert isinstance(client, BaseClient) + client._config.streaming = False + original_transport = client._transport + client._transport = RetryTransport( + original_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + ) + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-retry-test', + parts=[Part(text='Hello retry')], + ) + params = SendMessageRequest(message=message_to_send) + + events = [event async for event in client.send_message(request=params)] + + assert len(events) == 1 + _, task = events[0] + assert task is not None + assert task.id == TASK_RESPONSE.id + assert state['count'] == 1 + mock_request_handler.on_message_send.assert_awaited_once_with(params, ANY) + + await client.close()