|
17 | 17 | create_text_message_object, |
18 | 18 | ) |
19 | 19 | from a2a.client.transports.jsonrpc import JsonRpcTransport |
| 20 | +from a2a.extensions.common import HTTP_EXTENSION_HEADER |
20 | 21 | from a2a.types import ( |
21 | 22 | AgentCapabilities, |
22 | 23 | AgentCard, |
@@ -785,3 +786,181 @@ async def test_close(self, mock_httpx_client: AsyncMock): |
785 | 786 | ) |
786 | 787 | await client.close() |
787 | 788 | mock_httpx_client.aclose.assert_called_once() |
| 789 | + |
| 790 | + |
| 791 | +class TestJsonRpcTransportExtensions: |
| 792 | + def test_update_extension_header_no_initial_headers( |
| 793 | + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock |
| 794 | + ): |
| 795 | + extensions = ['test_extension_1', 'test_extension_2'] |
| 796 | + client = JsonRpcTransport( |
| 797 | + mock_httpx_client, extensions, mock_agent_card |
| 798 | + ) |
| 799 | + http_kwargs = {} |
| 800 | + result_kwargs = client._update_extension_header(http_kwargs) |
| 801 | + actual_extensions = set( |
| 802 | + result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ') |
| 803 | + ) |
| 804 | + expected_extensions = {'test_extension_1', 'test_extension_2'} |
| 805 | + assert actual_extensions == expected_extensions |
| 806 | + |
| 807 | + def test_update_extension_header_with_existing_other_headers( |
| 808 | + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock |
| 809 | + ): |
| 810 | + extensions = ['test_extension_1'] |
| 811 | + client = JsonRpcTransport( |
| 812 | + mock_httpx_client, extensions, mock_agent_card |
| 813 | + ) |
| 814 | + http_kwargs = {'headers': {'X_Other': 'Test'}} |
| 815 | + result_kwargs = client._update_extension_header(http_kwargs) |
| 816 | + assert ( |
| 817 | + result_kwargs['headers'][HTTP_EXTENSION_HEADER] |
| 818 | + == 'test_extension_1' |
| 819 | + ) |
| 820 | + assert result_kwargs['headers']['X_Other'] == 'Test' |
| 821 | + |
| 822 | + def test_update_extension_header_merge_with_existing_extensions( |
| 823 | + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock |
| 824 | + ): |
| 825 | + extensions = ['test_extension_1', 'test_extension_2'] |
| 826 | + client = JsonRpcTransport( |
| 827 | + mock_httpx_client, extensions, mock_agent_card |
| 828 | + ) |
| 829 | + http_kwargs = { |
| 830 | + 'headers': { |
| 831 | + HTTP_EXTENSION_HEADER: 'test_extension_2, test_extension_3' |
| 832 | + } |
| 833 | + } |
| 834 | + result_kwargs = client._update_extension_header(http_kwargs) |
| 835 | + actual_extensions_list = result_kwargs['headers'][ |
| 836 | + HTTP_EXTENSION_HEADER |
| 837 | + ].split(', ') |
| 838 | + actual_extensions = set(actual_extensions_list) |
| 839 | + expected_extensions = { |
| 840 | + 'test_extension_1', |
| 841 | + 'test_extension_2', |
| 842 | + 'test_extension_3', |
| 843 | + } |
| 844 | + assert len(actual_extensions_list) == 3 |
| 845 | + assert actual_extensions == expected_extensions |
| 846 | + |
| 847 | + def test_update_extension_header_no_client_extensions( |
| 848 | + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock |
| 849 | + ): |
| 850 | + client = JsonRpcTransport(mock_httpx_client, None, mock_agent_card) |
| 851 | + http_kwargs = {'headers': {'X_Other': 'Test'}} |
| 852 | + result_kwargs = client._update_extension_header(http_kwargs) |
| 853 | + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] |
| 854 | + assert result_kwargs['headers']['X_Other'] == 'Test' |
| 855 | + |
| 856 | + def test_update_extension_header_empty_client_extensions( |
| 857 | + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock |
| 858 | + ): |
| 859 | + client = JsonRpcTransport(mock_httpx_client, [], mock_agent_card) |
| 860 | + http_kwargs = {'headers': {'X_Other': 'Test'}} |
| 861 | + result_kwargs = client._update_extension_header(http_kwargs) |
| 862 | + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] |
| 863 | + assert result_kwargs['headers']['X_Other'] == 'Test' |
| 864 | + |
| 865 | + @pytest.mark.asyncio |
| 866 | + async def test_send_message_with_extensions( |
| 867 | + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock |
| 868 | + ): |
| 869 | + """Test that send_message adds extension headers when client_extensions are provided.""" |
| 870 | + extensions = ['test_extension_1', 'test_extension_2'] |
| 871 | + client = JsonRpcTransport( |
| 872 | + httpx_client=mock_httpx_client, |
| 873 | + client_extensions=extensions, |
| 874 | + agent_card=mock_agent_card, |
| 875 | + ) |
| 876 | + params = MessageSendParams( |
| 877 | + message=create_text_message_object(content='Hello') |
| 878 | + ) |
| 879 | + success_response = create_text_message_object( |
| 880 | + role=Role.agent, content='Hi there!' |
| 881 | + ) |
| 882 | + rpc_response = SendMessageSuccessResponse( |
| 883 | + id='123', jsonrpc='2.0', result=success_response |
| 884 | + ) |
| 885 | + # Mock the response from httpx_client.post |
| 886 | + mock_response = AsyncMock(spec=httpx.Response) |
| 887 | + mock_response.status_code = 200 |
| 888 | + mock_response.json.return_value = rpc_response.model_dump(mode='json') |
| 889 | + mock_httpx_client.post.return_value = mock_response |
| 890 | + |
| 891 | + await client.send_message(request=params) |
| 892 | + |
| 893 | + mock_httpx_client.post.assert_called_once() |
| 894 | + _, mock_kwargs = mock_httpx_client.post.call_args |
| 895 | + headers = mock_kwargs.get('headers', {}) |
| 896 | + assert HTTP_EXTENSION_HEADER in headers |
| 897 | + actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', ')) |
| 898 | + expected_extensions = {'test_extension_1', 'test_extension_2'} |
| 899 | + assert actual_extensions == expected_extensions |
| 900 | + |
| 901 | + @pytest.mark.asyncio |
| 902 | + async def test_send_message_no_extensions( |
| 903 | + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock |
| 904 | + ): |
| 905 | + """Test that send_message does not add extension headers when client_extensions is None.""" |
| 906 | + client = JsonRpcTransport( |
| 907 | + httpx_client=mock_httpx_client, |
| 908 | + client_extensions=None, |
| 909 | + agent_card=mock_agent_card, |
| 910 | + ) |
| 911 | + params = MessageSendParams( |
| 912 | + message=create_text_message_object(content='Hello') |
| 913 | + ) |
| 914 | + success_response = create_text_message_object( |
| 915 | + role=Role.agent, content='Hi there!' |
| 916 | + ) |
| 917 | + rpc_response = SendMessageSuccessResponse( |
| 918 | + id='123', jsonrpc='2.0', result=success_response |
| 919 | + ) |
| 920 | + # Mock the response from httpx_client.post |
| 921 | + mock_response = AsyncMock(spec=httpx.Response) |
| 922 | + mock_response.status_code = 200 |
| 923 | + mock_response.json.return_value = rpc_response.model_dump(mode='json') |
| 924 | + mock_httpx_client.post.return_value = mock_response |
| 925 | + |
| 926 | + await client.send_message(request=params) |
| 927 | + |
| 928 | + mock_httpx_client.post.assert_called_once() |
| 929 | + _, mock_kwargs = mock_httpx_client.post.call_args |
| 930 | + headers = mock_kwargs.get('headers', {}) |
| 931 | + assert HTTP_EXTENSION_HEADER not in headers |
| 932 | + |
| 933 | + @pytest.mark.asyncio |
| 934 | + @patch('a2a.client.transports.jsonrpc.aconnect_sse') |
| 935 | + async def test_send_message_streaming_with_extensions( |
| 936 | + self, |
| 937 | + mock_aconnect_sse: AsyncMock, |
| 938 | + mock_httpx_client: AsyncMock, |
| 939 | + mock_agent_card: MagicMock, |
| 940 | + ): |
| 941 | + """Test X-A2A-Extensions header in send_message_streaming.""" |
| 942 | + extensions = ['test_extension'] |
| 943 | + client = JsonRpcTransport( |
| 944 | + httpx_client=mock_httpx_client, |
| 945 | + client_extensions=extensions, |
| 946 | + agent_card=mock_agent_card, |
| 947 | + ) |
| 948 | + params = MessageSendParams( |
| 949 | + message=create_text_message_object(content='Hello stream') |
| 950 | + ) |
| 951 | + |
| 952 | + mock_event_source = AsyncMock(spec=EventSource) |
| 953 | + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) |
| 954 | + mock_aconnect_sse.return_value.__aenter__.return_value = ( |
| 955 | + mock_event_source |
| 956 | + ) |
| 957 | + |
| 958 | + async for _ in client.send_message_streaming(request=params): |
| 959 | + pass |
| 960 | + |
| 961 | + mock_aconnect_sse.assert_called_once() |
| 962 | + _, kwargs = mock_aconnect_sse.call_args |
| 963 | + |
| 964 | + headers = kwargs.get('headers', {}) |
| 965 | + assert HTTP_EXTENSION_HEADER in headers |
| 966 | + assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' |
0 commit comments