Skip to content

Commit 697438f

Browse files
committed
feat: Add client-side extension support
This commit introduces support for clients to declare the extensions they support. - Adds an `extensions` list to `ClientConfig`. - Updates `ClientFactory` to pass `client_extensions` to `JsonRpcTransport` and `RestTransport`. - Adds `_update_extension_header` method to both transports to update the `X-A2A-Extensions` header. - Modifies `send_message` and `send_message_streaming` in `JsonRpcTransport` to include the extension headers. - Modifies `_prepare_send_message` in `RestTransport` to include the extension headers. - Adds tests for the extension header logic in both JSON-RPC and REST transports, including a new test file `test_rest_client.py`.
1 parent 01b421e commit 697438f

File tree

6 files changed

+406
-0
lines changed

6 files changed

+406
-0
lines changed

src/a2a/client/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ class ClientConfig:
6767
)
6868
"""Push notification callbacks to use for every request."""
6969

70+
extensions: list[str] = dataclasses.field(default_factory=list)
71+
"""A list of extension URIs the client supports."""
72+
7073

7174
UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None
7275
# Alias for emitted events from client

src/a2a/client/client_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _register_defaults(
7777
TransportProtocol.jsonrpc,
7878
lambda card, url, config, interceptors: JsonRpcTransport(
7979
config.httpx_client or httpx.AsyncClient(),
80+
config.extensions or None,
8081
card,
8182
url,
8283
interceptors,
@@ -87,6 +88,7 @@ def _register_defaults(
8788
TransportProtocol.http_json,
8889
lambda card, url, config, interceptors: RestTransport(
8990
config.httpx_client or httpx.AsyncClient(),
91+
config.extensions or None,
9092
card,
9193
url,
9294
interceptors,

src/a2a/client/transports/jsonrpc.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
2020
from a2a.client.transports.base import ClientTransport
21+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
2122
from a2a.types import (
2223
AgentCard,
2324
CancelTaskRequest,
@@ -59,6 +60,7 @@ class JsonRpcTransport(ClientTransport):
5960
def __init__(
6061
self,
6162
httpx_client: httpx.AsyncClient,
63+
client_extensions: list[str] | None = None,
6264
agent_card: AgentCard | None = None,
6365
url: str | None = None,
6466
interceptors: list[ClientCallInterceptor] | None = None,
@@ -72,6 +74,7 @@ def __init__(
7274
raise ValueError('Must provide either agent_card or url')
7375

7476
self.httpx_client = httpx_client
77+
self.client_extensions = client_extensions
7578
self.agent_card = agent_card
7679
self.interceptors = interceptors or []
7780
self._needs_extended_card = (
@@ -80,6 +83,20 @@ def __init__(
8083
else True
8184
)
8285

86+
def _update_extension_header(
87+
self, http_kwargs: dict[str, Any]
88+
) -> dict[str, Any]:
89+
if self.client_extensions:
90+
headers = http_kwargs.get('headers', {})
91+
existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '')
92+
split = (
93+
existing_extensions.split(', ') if existing_extensions else []
94+
)
95+
updated_extensions = list(set(self.client_extensions + split))
96+
headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions)
97+
http_kwargs['headers'] = headers
98+
return http_kwargs
99+
83100
async def _apply_interceptors(
84101
self,
85102
method_name: str,
@@ -122,6 +139,7 @@ async def send_message(
122139
self._get_http_args(context),
123140
context,
124141
)
142+
modified_kwargs = self._update_extension_header(modified_kwargs)
125143
response_data = await self._send_request(payload, modified_kwargs)
126144
response = SendMessageResponse.model_validate(response_data)
127145
if isinstance(response.root, JSONRPCErrorResponse):
@@ -147,6 +165,7 @@ async def send_message_streaming(
147165
context,
148166
)
149167

168+
modified_kwargs = self._update_extension_header(modified_kwargs)
150169
modified_kwargs.setdefault(
151170
'timeout', self.httpx_client.timeout.as_dict().get('read', None)
152171
)

src/a2a/client/transports/rest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
1414
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1515
from a2a.client.transports.base import ClientTransport
16+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
1617
from a2a.grpc import a2a_pb2
1718
from a2a.types import (
1819
AgentCard,
@@ -40,6 +41,7 @@ class RestTransport(ClientTransport):
4041
def __init__(
4142
self,
4243
httpx_client: httpx.AsyncClient,
44+
client_extensions: list[str] | None = None,
4345
agent_card: AgentCard | None = None,
4446
url: str | None = None,
4547
interceptors: list[ClientCallInterceptor] | None = None,
@@ -54,6 +56,7 @@ def __init__(
5456
if self.url.endswith('/'):
5557
self.url = self.url[:-1]
5658
self.httpx_client = httpx_client
59+
self.client_extensions = client_extensions
5760
self.agent_card = agent_card
5861
self.interceptors = interceptors or []
5962
self._needs_extended_card = (
@@ -62,6 +65,20 @@ def __init__(
6265
else True
6366
)
6467

68+
def _update_extension_header(
69+
self, http_kwargs: dict[str, Any]
70+
) -> dict[str, Any]:
71+
if self.client_extensions:
72+
headers = http_kwargs.get('headers', {})
73+
existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '')
74+
split = (
75+
existing_extensions.split(', ') if existing_extensions else []
76+
)
77+
updated_extensions = list(set(self.client_extensions + split))
78+
headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions)
79+
http_kwargs['headers'] = headers
80+
return http_kwargs
81+
6582
async def _apply_interceptors(
6683
self,
6784
request_payload: dict[str, Any],
@@ -98,6 +115,7 @@ async def _prepare_send_message(
98115
self._get_http_args(context),
99116
context,
100117
)
118+
modified_kwargs = self._update_extension_header(modified_kwargs)
101119
return payload, modified_kwargs
102120

103121
async def send_message(

tests/client/test_jsonrpc_client.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
create_text_message_object,
1818
)
1919
from a2a.client.transports.jsonrpc import JsonRpcTransport
20+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
2021
from a2a.types import (
2122
AgentCapabilities,
2223
AgentCard,
@@ -785,3 +786,181 @@ async def test_close(self, mock_httpx_client: AsyncMock):
785786
)
786787
await client.close()
787788
mock_httpx_client.aclose.assert_called_once()
789+
790+
791+
class TestJsonRpcTransportExtensions:
792+
def test_update_extension_header_no_initial_headers(
793+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
794+
):
795+
extensions = ['test_extension_1', 'test_extension_2']
796+
client = JsonRpcTransport(
797+
mock_httpx_client, extensions, mock_agent_card
798+
)
799+
http_kwargs = {}
800+
result_kwargs = client._update_extension_header(http_kwargs)
801+
actual_extensions = set(
802+
result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ')
803+
)
804+
expected_extensions = {'test_extension_1', 'test_extension_2'}
805+
assert actual_extensions == expected_extensions
806+
807+
def test_update_extension_header_with_existing_other_headers(
808+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
809+
):
810+
extensions = ['test_extension_1']
811+
client = JsonRpcTransport(
812+
mock_httpx_client, extensions, mock_agent_card
813+
)
814+
http_kwargs = {'headers': {'X_Other': 'Test'}}
815+
result_kwargs = client._update_extension_header(http_kwargs)
816+
assert (
817+
result_kwargs['headers'][HTTP_EXTENSION_HEADER]
818+
== 'test_extension_1'
819+
)
820+
assert result_kwargs['headers']['X_Other'] == 'Test'
821+
822+
def test_update_extension_header_merge_with_existing_extensions(
823+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
824+
):
825+
extensions = ['test_extension_1', 'test_extension_2']
826+
client = JsonRpcTransport(
827+
mock_httpx_client, extensions, mock_agent_card
828+
)
829+
http_kwargs = {
830+
'headers': {
831+
HTTP_EXTENSION_HEADER: 'test_extension_2, test_extension_3'
832+
}
833+
}
834+
result_kwargs = client._update_extension_header(http_kwargs)
835+
actual_extensions_list = result_kwargs['headers'][
836+
HTTP_EXTENSION_HEADER
837+
].split(', ')
838+
actual_extensions = set(actual_extensions_list)
839+
expected_extensions = {
840+
'test_extension_1',
841+
'test_extension_2',
842+
'test_extension_3',
843+
}
844+
assert len(actual_extensions_list) == 3
845+
assert actual_extensions == expected_extensions
846+
847+
def test_update_extension_header_no_client_extensions(
848+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
849+
):
850+
client = JsonRpcTransport(mock_httpx_client, None, mock_agent_card)
851+
http_kwargs = {'headers': {'X_Other': 'Test'}}
852+
result_kwargs = client._update_extension_header(http_kwargs)
853+
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
854+
assert result_kwargs['headers']['X_Other'] == 'Test'
855+
856+
def test_update_extension_header_empty_client_extensions(
857+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
858+
):
859+
client = JsonRpcTransport(mock_httpx_client, [], mock_agent_card)
860+
http_kwargs = {'headers': {'X_Other': 'Test'}}
861+
result_kwargs = client._update_extension_header(http_kwargs)
862+
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
863+
assert result_kwargs['headers']['X_Other'] == 'Test'
864+
865+
@pytest.mark.asyncio
866+
async def test_send_message_with_extensions(
867+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
868+
):
869+
"""Test that send_message adds extension headers when client_extensions are provided."""
870+
extensions = ['test_extension_1', 'test_extension_2']
871+
client = JsonRpcTransport(
872+
httpx_client=mock_httpx_client,
873+
client_extensions=extensions,
874+
agent_card=mock_agent_card,
875+
)
876+
params = MessageSendParams(
877+
message=create_text_message_object(content='Hello')
878+
)
879+
success_response = create_text_message_object(
880+
role=Role.agent, content='Hi there!'
881+
)
882+
rpc_response = SendMessageSuccessResponse(
883+
id='123', jsonrpc='2.0', result=success_response
884+
)
885+
# Mock the response from httpx_client.post
886+
mock_response = AsyncMock(spec=httpx.Response)
887+
mock_response.status_code = 200
888+
mock_response.json.return_value = rpc_response.model_dump(mode='json')
889+
mock_httpx_client.post.return_value = mock_response
890+
891+
await client.send_message(request=params)
892+
893+
mock_httpx_client.post.assert_called_once()
894+
_, mock_kwargs = mock_httpx_client.post.call_args
895+
headers = mock_kwargs.get('headers', {})
896+
assert HTTP_EXTENSION_HEADER in headers
897+
actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', '))
898+
expected_extensions = {'test_extension_1', 'test_extension_2'}
899+
assert actual_extensions == expected_extensions
900+
901+
@pytest.mark.asyncio
902+
async def test_send_message_no_extensions(
903+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
904+
):
905+
"""Test that send_message does not add extension headers when client_extensions is None."""
906+
client = JsonRpcTransport(
907+
httpx_client=mock_httpx_client,
908+
client_extensions=None,
909+
agent_card=mock_agent_card,
910+
)
911+
params = MessageSendParams(
912+
message=create_text_message_object(content='Hello')
913+
)
914+
success_response = create_text_message_object(
915+
role=Role.agent, content='Hi there!'
916+
)
917+
rpc_response = SendMessageSuccessResponse(
918+
id='123', jsonrpc='2.0', result=success_response
919+
)
920+
# Mock the response from httpx_client.post
921+
mock_response = AsyncMock(spec=httpx.Response)
922+
mock_response.status_code = 200
923+
mock_response.json.return_value = rpc_response.model_dump(mode='json')
924+
mock_httpx_client.post.return_value = mock_response
925+
926+
await client.send_message(request=params)
927+
928+
mock_httpx_client.post.assert_called_once()
929+
_, mock_kwargs = mock_httpx_client.post.call_args
930+
headers = mock_kwargs.get('headers', {})
931+
assert HTTP_EXTENSION_HEADER not in headers
932+
933+
@pytest.mark.asyncio
934+
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
935+
async def test_send_message_streaming_with_extensions(
936+
self,
937+
mock_aconnect_sse: AsyncMock,
938+
mock_httpx_client: AsyncMock,
939+
mock_agent_card: MagicMock,
940+
):
941+
"""Test X-A2A-Extensions header in send_message_streaming."""
942+
extensions = ['test_extension']
943+
client = JsonRpcTransport(
944+
httpx_client=mock_httpx_client,
945+
client_extensions=extensions,
946+
agent_card=mock_agent_card,
947+
)
948+
params = MessageSendParams(
949+
message=create_text_message_object(content='Hello stream')
950+
)
951+
952+
mock_event_source = AsyncMock(spec=EventSource)
953+
mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
954+
mock_aconnect_sse.return_value.__aenter__.return_value = (
955+
mock_event_source
956+
)
957+
958+
async for _ in client.send_message_streaming(request=params):
959+
pass
960+
961+
mock_aconnect_sse.assert_called_once()
962+
_, kwargs = mock_aconnect_sse.call_args
963+
964+
headers = kwargs.get('headers', {})
965+
assert HTTP_EXTENSION_HEADER in headers
966+
assert headers[HTTP_EXTENSION_HEADER] == 'test_extension'

0 commit comments

Comments
 (0)