diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 3d673b31..393d85ec 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -10,6 +10,7 @@ A2AClientError, A2AClientHTTPError, A2AClientJSONError, + A2AClientTimeoutError, ) from a2a.client.grpc_client import A2AGrpcClient from a2a.client.helpers import create_text_message_object @@ -22,6 +23,7 @@ 'A2AClientError', 'A2AClientHTTPError', 'A2AClientJSONError', + 'A2AClientTimeoutError', 'A2AGrpcClient', 'AuthInterceptor', 'ClientCallContext', diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 66a1e49b..66dfe0a4 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -10,7 +10,11 @@ from httpx_sse import SSEError, aconnect_sse from pydantic import ValidationError -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientTimeoutError, +) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.types import ( AgentCard, @@ -340,6 +344,8 @@ async def _send_request( ) response.raise_for_status() return response.json() + except httpx.ReadTimeout as e: + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except json.JSONDecodeError as e: diff --git a/src/a2a/client/errors.py b/src/a2a/client/errors.py index da02e582..5fe5512a 100644 --- a/src/a2a/client/errors.py +++ b/src/a2a/client/errors.py @@ -31,3 +31,16 @@ def __init__(self, message: str): """ self.message = message super().__init__(f'JSON Error: {message}') + + +class A2AClientTimeoutError(A2AClientError): + """Client exception for timeout errors during a request.""" + + def __init__(self, message: str): + """Initializes the A2AClientTimeoutError. + + Args: + message: A descriptive error message. + """ + self.message = message + super().__init__(f'Timeout Error: {message}') diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 5b6e9491..00ab8796 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -14,6 +14,7 @@ A2AClient, A2AClientHTTPError, A2AClientJSONError, + A2AClientTimeoutError, create_text_message_object, ) from a2a.types import ( @@ -1266,3 +1267,25 @@ async def test_cancel_task_error_response( mode='json', exclude_none=True ) == error_details.model_dump(exclude_none=True) assert response.root.id == 'err_cancel_req' + + @pytest.mark.asyncio + async def test_send_message_client_timeout( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + mock_httpx_client.post.side_effect = httpx.ReadTimeout( + 'Request timed out' + ) + client = A2AClient( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + + request = SendMessageRequest(id=123, params=params) + + with pytest.raises(A2AClientTimeoutError) as exc_info: + await client.send_message(request=request) + + assert 'Request timed out' in str(exc_info.value)