Skip to content

Commit 48ea2ae

Browse files
committed
feat: update extension handling in transports and tests, migrate utility functions to common module
1 parent edd7982 commit 48ea2ae

File tree

9 files changed

+136
-131
lines changed

9 files changed

+136
-131
lines changed

src/a2a/client/base_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(
4040
):
4141
super().__init__(consumers, middleware, extensions)
4242
self._card = card
43-
config.extensions = extensions
4443
self._config = config
4544
self._transport = transport
4645

src/a2a/client/transports/jsonrpc.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
2020
from a2a.client.transports.base import ClientTransport
21-
from a2a.client.transports.utils import get_http_args, update_extension_header
21+
from a2a.extensions.common import update_extension_header
2222
from a2a.types import (
2323
AgentCard,
2424
CancelTaskRequest,
@@ -106,6 +106,11 @@ async def _apply_interceptors(
106106
)
107107
return final_request_payload, final_http_kwargs
108108

109+
def _get_http_args(
110+
self, context: ClientCallContext | None
111+
) -> dict[str, Any] | None:
112+
return context.state.get('http_kwargs') if context else None
113+
109114
async def send_message(
110115
self,
111116
request: MessageSendParams,
@@ -117,7 +122,7 @@ async def send_message(
117122
payload, modified_kwargs = await self._apply_interceptors(
118123
'message/send',
119124
rpc_request.model_dump(mode='json', exclude_none=True),
120-
get_http_args(context),
125+
self._get_http_args(context),
121126
context,
122127
)
123128
modified_kwargs = update_extension_header(
@@ -144,7 +149,7 @@ async def send_message_streaming(
144149
payload, modified_kwargs = await self._apply_interceptors(
145150
'message/stream',
146151
rpc_request.model_dump(mode='json', exclude_none=True),
147-
get_http_args(context),
152+
self._get_http_args(context),
148153
context,
149154
)
150155

@@ -217,7 +222,7 @@ async def get_task(
217222
payload, modified_kwargs = await self._apply_interceptors(
218223
'tasks/get',
219224
rpc_request.model_dump(mode='json', exclude_none=True),
220-
get_http_args(context),
225+
self._get_http_args(context),
221226
context,
222227
)
223228
modified_kwargs = update_extension_header(
@@ -240,7 +245,7 @@ async def cancel_task(
240245
payload, modified_kwargs = await self._apply_interceptors(
241246
'tasks/cancel',
242247
rpc_request.model_dump(mode='json', exclude_none=True),
243-
get_http_args(context),
248+
self._get_http_args(context),
244249
context,
245250
)
246251
modified_kwargs = update_extension_header(
@@ -265,7 +270,7 @@ async def set_task_callback(
265270
payload, modified_kwargs = await self._apply_interceptors(
266271
'tasks/pushNotificationConfig/set',
267272
rpc_request.model_dump(mode='json', exclude_none=True),
268-
get_http_args(context),
273+
self._get_http_args(context),
269274
context,
270275
)
271276
modified_kwargs = update_extension_header(
@@ -292,7 +297,7 @@ async def get_task_callback(
292297
payload, modified_kwargs = await self._apply_interceptors(
293298
'tasks/pushNotificationConfig/get',
294299
rpc_request.model_dump(mode='json', exclude_none=True),
295-
get_http_args(context),
300+
self._get_http_args(context),
296301
context,
297302
)
298303
modified_kwargs = update_extension_header(
@@ -319,7 +324,7 @@ async def resubscribe(
319324
payload, modified_kwargs = await self._apply_interceptors(
320325
'tasks/resubscribe',
321326
rpc_request.model_dump(mode='json', exclude_none=True),
322-
get_http_args(context),
327+
self._get_http_args(context),
323328
context,
324329
)
325330
modified_kwargs = update_extension_header(
@@ -363,7 +368,7 @@ async def get_card(
363368
if not card:
364369
resolver = A2ACardResolver(self.httpx_client, self.url)
365370
card = await resolver.get_agent_card(
366-
http_kwargs=get_http_args(context)
371+
http_kwargs=self._get_http_args(context)
367372
)
368373
self._needs_extended_card = (
369374
card.supports_authenticated_extended_card
@@ -377,7 +382,7 @@ async def get_card(
377382
payload, modified_kwargs = await self._apply_interceptors(
378383
request.method,
379384
request.model_dump(mode='json', exclude_none=True),
380-
get_http_args(context),
385+
self._get_http_args(context),
381386
context,
382387
)
383388
modified_kwargs = update_extension_header(

src/a2a/client/transports/rest.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
1414
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1515
from a2a.client.transports.base import ClientTransport
16-
from a2a.client.transports.utils import get_http_args, update_extension_header
16+
from a2a.extensions.common import update_extension_header
1717
from a2a.grpc import a2a_pb2
1818
from a2a.types import (
1919
AgentCard,
@@ -76,6 +76,11 @@ async def _apply_interceptors(
7676
# TODO: Implement interceptors for other transports
7777
return final_request_payload, final_http_kwargs
7878

79+
def _get_http_args(
80+
self, context: ClientCallContext | None
81+
) -> dict[str, Any] | None:
82+
return context.state.get('http_kwargs') if context else None
83+
7984
async def _prepare_send_message(
8085
self, request: MessageSendParams, context: ClientCallContext | None
8186
) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -93,7 +98,7 @@ async def _prepare_send_message(
9398
payload = MessageToDict(pb)
9499
payload, modified_kwargs = await self._apply_interceptors(
95100
payload,
96-
get_http_args(context),
101+
self._get_http_args(context),
97102
context,
98103
)
99104
modified_kwargs = update_extension_header(
@@ -209,7 +214,7 @@ async def get_task(
209214
"""Retrieves the current state and history of a specific task."""
210215
_payload, modified_kwargs = await self._apply_interceptors(
211216
request.model_dump(mode='json', exclude_none=True),
212-
get_http_args(context),
217+
self._get_http_args(context),
213218
context,
214219
)
215220
modified_kwargs = update_extension_header(
@@ -237,7 +242,7 @@ async def cancel_task(
237242
payload = MessageToDict(pb)
238243
payload, modified_kwargs = await self._apply_interceptors(
239244
payload,
240-
get_http_args(context),
245+
self._get_http_args(context),
241246
context,
242247
)
243248
modified_kwargs = update_extension_header(
@@ -264,7 +269,7 @@ async def set_task_callback(
264269
)
265270
payload = MessageToDict(pb)
266271
payload, modified_kwargs = await self._apply_interceptors(
267-
payload, get_http_args(context), context
272+
payload, self._get_http_args(context), context
268273
)
269274
modified_kwargs = update_extension_header(
270275
modified_kwargs, self.extensions
@@ -291,7 +296,7 @@ async def get_task_callback(
291296
payload = MessageToDict(pb)
292297
payload, modified_kwargs = await self._apply_interceptors(
293298
payload,
294-
get_http_args(context),
299+
self._get_http_args(context),
295300
context,
296301
)
297302
modified_kwargs = update_extension_header(
@@ -315,7 +320,7 @@ async def resubscribe(
315320
Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
316321
]:
317322
"""Reconnects to get task updates."""
318-
http_kwargs = get_http_args(context) or {}
323+
http_kwargs = self._get_http_args(context) or {}
319324
http_kwargs.setdefault('timeout', None)
320325
modified_kwargs = update_extension_header(http_kwargs, self.extensions)
321326

@@ -351,7 +356,7 @@ async def get_card(
351356
if not card:
352357
resolver = A2ACardResolver(self.httpx_client, self.url)
353358
card = await resolver.get_agent_card(
354-
http_kwargs=get_http_args(context)
359+
http_kwargs=self._get_http_args(context)
355360
)
356361
self._needs_extended_card = (
357362
card.supports_authenticated_extended_card
@@ -363,7 +368,7 @@ async def get_card(
363368

364369
_, modified_kwargs = await self._apply_interceptors(
365370
{},
366-
get_http_args(context),
371+
self._get_http_args(context),
367372
context,
368373
)
369374
modified_kwargs = update_extension_header(

src/a2a/client/transports/utils.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

src/a2a/extensions/common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
from a2a.types import AgentCard, AgentExtension
24

35

@@ -25,3 +27,18 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
2527
return ext
2628

2729
return None
30+
31+
32+
def update_extension_header(
33+
http_kwargs: dict[str, Any], extensions: list[str] | None
34+
) -> dict[str, Any]:
35+
if extensions:
36+
headers = http_kwargs.setdefault('headers', {})
37+
existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '')
38+
39+
existing_extensions = get_requested_extensions(
40+
[existing_extensions_str]
41+
)
42+
all_extensions = existing_extensions.union(extensions)
43+
headers[HTTP_EXTENSION_HEADER] = ','.join(all_extensions)
44+
return http_kwargs

tests/client/transports/test_jsonrpc_client.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,10 @@ async def test_send_message_with_extensions(
794794
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
795795
):
796796
"""Test that send_message adds extension headers when extensions are provided."""
797-
extensions = ['test_extension_1', 'test_extension_2']
797+
extensions = [
798+
'https://example.com/test-ext/v1',
799+
'https://example.com/test-ext/v2',
800+
]
798801
client = JsonRpcTransport(
799802
httpx_client=mock_httpx_client,
800803
agent_card=mock_agent_card,
@@ -827,8 +830,8 @@ async def test_send_message_with_extensions(
827830
actual_extensions = set(actual_extensions_list)
828831

829832
expected_extensions = {
830-
'test_extension_1',
831-
'test_extension_2',
833+
'https://example.com/test-ext/v1',
834+
'https://example.com/test-ext/v2',
832835
}
833836
assert len(actual_extensions_list) == 2
834837
assert actual_extensions == expected_extensions
@@ -842,7 +845,7 @@ async def test_send_message_streaming_with_extensions(
842845
mock_agent_card: MagicMock,
843846
):
844847
"""Test X-A2A-Extensions header in send_message_streaming."""
845-
extensions = ['test_extension']
848+
extensions = ['https://example.com/test-ext/v1']
846849
client = JsonRpcTransport(
847850
httpx_client=mock_httpx_client,
848851
agent_card=mock_agent_card,
@@ -866,4 +869,6 @@ async def test_send_message_streaming_with_extensions(
866869

867870
headers = kwargs.get('headers', {})
868871
assert HTTP_EXTENSION_HEADER in headers
869-
assert headers[HTTP_EXTENSION_HEADER] == 'test_extension'
872+
assert (
873+
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v1'
874+
)

tests/client/transports/test_rest_client.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ async def test_send_message_with_extensions(
3838
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
3939
):
4040
"""Test that send_message adds extensions to headers."""
41-
extensions = ['test_extension_1', 'test_extension_2']
41+
extensions = [
42+
'https://example.com/test-ext/v1',
43+
'https://example.com/test-ext/v2',
44+
]
4245
client = RestTransport(
4346
httpx_client=mock_httpx_client,
4447
extensions=extensions,
@@ -71,8 +74,8 @@ async def test_send_message_with_extensions(
7174
actual_extensions = set(actual_extensions_list)
7275

7376
expected_extensions = {
74-
'test_extension_1',
75-
'test_extension_2',
77+
'https://example.com/test-ext/v1',
78+
'https://example.com/test-ext/v2',
7679
}
7780
assert len(actual_extensions_list) == 2
7881
assert actual_extensions == expected_extensions
@@ -86,7 +89,7 @@ async def test_send_message_streaming_with_extensions(
8689
mock_agent_card: MagicMock,
8790
):
8891
"""Test X-A2A-Extensions header in send_message_streaming."""
89-
extensions = ['test_extension']
92+
extensions = ['https://example.com/test-ext/v1']
9093
client = RestTransport(
9194
httpx_client=mock_httpx_client,
9295
agent_card=mock_agent_card,
@@ -110,4 +113,6 @@ async def test_send_message_streaming_with_extensions(
110113

111114
headers = kwargs.get('headers', {})
112115
assert HTTP_EXTENSION_HEADER in headers
113-
assert headers[HTTP_EXTENSION_HEADER] == 'test_extension'
116+
assert (
117+
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v1'
118+
)

0 commit comments

Comments
 (0)