Skip to content

Commit 16ee453

Browse files
committed
add integration test for extensions. Add a test case to test_common.py. Change desription of update_extension_header
1 parent a97c5b3 commit 16ee453

File tree

3 files changed

+68
-2
lines changed

3 files changed

+68
-2
lines changed

src/a2a/extensions/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def update_extension_header(
3333
http_kwargs: dict[str, Any],
3434
extensions: list[str] | None,
3535
) -> dict[str, Any]:
36-
"""Update the X-A2A-Extensions header and update active extensions."""
36+
"""Update the X-A2A-Extensions header with active extensions."""
3737
if extensions is not None:
3838
headers = http_kwargs.setdefault('headers', {})
3939
headers[HTTP_EXTENSION_HEADER] = ','.join(extensions)

tests/extensions/test_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,12 @@ def test_update_extension_header_with_other_headers_extensions_none():
126126
result_kwargs = update_extension_header(http_kwargs, None)
127127
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
128128
assert result_kwargs['headers']['X_Other'] == 'Test'
129+
130+
131+
def test_update_extension_header_empty_header():
132+
extensions = ['ext']
133+
http_kwargs = {}
134+
result_kwargs = update_extension_header(http_kwargs, extensions)
135+
headers = result_kwargs.get('headers', {})
136+
assert HTTP_EXTENSION_HEADER in headers
137+
assert headers[HTTP_EXTENSION_HEADER] == 'ext'

tests/integration/test_client_server_integration.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import asyncio
22
from collections.abc import AsyncGenerator
33
from typing import NamedTuple
4-
from unittest.mock import ANY, AsyncMock
4+
from unittest.mock import ANY, AsyncMock, patch
55

66
import grpc
77
import httpx
88
import pytest
99
import pytest_asyncio
1010
from grpc.aio import Channel
1111

12+
from a2a.client import ClientConfig
13+
from a2a.client.base_client import BaseClient
1214
from a2a.client.transports import JsonRpcTransport, RestTransport
1315
from a2a.client.transports.base import ClientTransport
1416
from a2a.client.transports.grpc import GrpcTransport
@@ -767,3 +769,58 @@ def channel_factory(address: str) -> Channel:
767769
assert transport._needs_extended_card is False
768770

769771
await transport.close()
772+
773+
774+
@pytest.mark.asyncio
775+
async def test_base_client_sends_message_with_extensions(
776+
jsonrpc_setup: TransportSetup, agent_card: AgentCard
777+
) -> None:
778+
"""
779+
Integration test for BaseClient with JSON-RPC transport to ensure extensions are included in headers.
780+
"""
781+
transport = jsonrpc_setup.transport
782+
agent_card.capabilities.streaming = False
783+
784+
# Create a BaseClient instance
785+
client = BaseClient(
786+
card=agent_card,
787+
config=ClientConfig(streaming=False),
788+
transport=transport,
789+
consumers=[],
790+
middleware=[],
791+
)
792+
793+
message_to_send = Message(
794+
role=Role.user,
795+
message_id='msg-integration-test-extensions',
796+
parts=[Part(root=TextPart(text='Hello, extensions test!'))],
797+
)
798+
extensions = [
799+
'https://example.com/test-ext/v1',
800+
'https://example.com/test-ext/v2',
801+
]
802+
803+
with patch.object(
804+
transport, '_send_request', new_callable=AsyncMock
805+
) as mock_send_request:
806+
mock_send_request.return_value = {
807+
'id': '123',
808+
'jsonrpc': '2.0',
809+
'result': TASK_FROM_BLOCKING.model_dump(mode='json'),
810+
}
811+
812+
# Call send_message on the BaseClient
813+
async for _ in client.send_message(
814+
request=message_to_send, extensions=extensions
815+
):
816+
pass
817+
818+
mock_send_request.assert_called_once()
819+
call_args, _ = mock_send_request.call_args
820+
kwargs = call_args[1]
821+
headers = kwargs.get('headers', {})
822+
assert 'X-A2A-Extensions' in headers
823+
assert headers['X-A2A-Extensions'] == ','.join(extensions)
824+
825+
if hasattr(transport, 'close'):
826+
await transport.close()

0 commit comments

Comments
 (0)