@@ -65,7 +65,14 @@ def grpc_transport(
6565) -> GrpcTransport :
6666 """Provides a GrpcTransport instance."""
6767 channel = AsyncMock ()
68- transport = GrpcTransport (channel = channel , agent_card = sample_agent_card )
68+ transport = GrpcTransport (
69+ channel = channel ,
70+ agent_card = sample_agent_card ,
71+ extensions = [
72+ 'https://example.com/test-ext/v1' ,
73+ 'https://example.com/test-ext/v2' ,
74+ ],
75+ )
6976 transport .stub = mock_grpc_stub
7077 return transport
7178
@@ -189,6 +196,13 @@ async def test_send_message_task_response(
189196 response = await grpc_transport .send_message (sample_message_send_params )
190197
191198 mock_grpc_stub .SendMessage .assert_awaited_once ()
199+ _ , kwargs = mock_grpc_stub .SendMessage .call_args
200+ assert kwargs ['metadata' ] == [
201+ (
202+ HTTP_EXTENSION_HEADER ,
203+ 'https://example.com/test-ext/v1, https://example.com/test-ext/v2' ,
204+ )
205+ ]
192206 assert isinstance (response , Task )
193207 assert response .id == sample_task .id
194208
@@ -208,6 +222,13 @@ async def test_send_message_message_response(
208222 response = await grpc_transport .send_message (sample_message_send_params )
209223
210224 mock_grpc_stub .SendMessage .assert_awaited_once ()
225+ _ , kwargs = mock_grpc_stub .SendMessage .call_args
226+ assert kwargs ['metadata' ] == [
227+ (
228+ HTTP_EXTENSION_HEADER ,
229+ 'https://example.com/test-ext/v1, https://example.com/test-ext/v2' ,
230+ )
231+ ]
211232 assert isinstance (response , Message )
212233 assert response .message_id == sample_message .message_id
213234 assert get_text_parts (response .parts ) == get_text_parts (
@@ -256,6 +277,13 @@ async def test_send_message_streaming( # noqa: PLR0913
256277 ]
257278
258279 mock_grpc_stub .SendStreamingMessage .assert_called_once ()
280+ _ , kwargs = mock_grpc_stub .SendStreamingMessage .call_args
281+ assert kwargs ['metadata' ] == [
282+ (
283+ HTTP_EXTENSION_HEADER ,
284+ 'https://example.com/test-ext/v1, https://example.com/test-ext/v2' ,
285+ )
286+ ]
259287 assert isinstance (responses [0 ], Message )
260288 assert responses [0 ].message_id == sample_message .message_id
261289 assert isinstance (responses [1 ], Task )
@@ -279,7 +307,13 @@ async def test_get_task(
279307 mock_grpc_stub .GetTask .assert_awaited_once_with (
280308 a2a_pb2 .GetTaskRequest (
281309 name = f'tasks/{ sample_task .id } ' , history_length = None
282- )
310+ ),
311+ metadata = [
312+ (
313+ HTTP_EXTENSION_HEADER ,
314+ 'https://example.com/test-ext/v1, https://example.com/test-ext/v2' ,
315+ )
316+ ],
283317 )
284318 assert response .id == sample_task .id
285319
@@ -298,7 +332,13 @@ async def test_get_task_with_history(
298332 mock_grpc_stub .GetTask .assert_awaited_once_with (
299333 a2a_pb2 .GetTaskRequest (
300334 name = f'tasks/{ sample_task .id } ' , history_length = history_len
301- )
335+ ),
336+ metadata = [
337+ (
338+ HTTP_EXTENSION_HEADER ,
339+ 'https://example.com/test-ext/v1, https://example.com/test-ext/v2' ,
340+ )
341+ ],
302342 )
303343
304344
@@ -317,7 +357,13 @@ async def test_cancel_task(
317357 response = await grpc_transport .cancel_task (params )
318358
319359 mock_grpc_stub .CancelTask .assert_awaited_once_with (
320- a2a_pb2 .CancelTaskRequest (name = f'tasks/{ sample_task .id } ' )
360+ a2a_pb2 .CancelTaskRequest (name = f'tasks/{ sample_task .id } ' ),
361+ metadata = [
362+ (
363+ HTTP_EXTENSION_HEADER ,
364+ 'https://example.com/test-ext/v1, https://example.com/test-ext/v2' ,
365+ )
366+ ],
321367 )
322368 assert response .status .state == TaskState .canceled
323369
@@ -346,7 +392,13 @@ async def test_set_task_callback_with_valid_task(
346392 config = proto_utils .ToProto .task_push_notification_config (
347393 sample_task_push_notification_config
348394 ),
349- )
395+ ),
396+ metadata = [
397+ (
398+ HTTP_EXTENSION_HEADER ,
399+ 'https://example.com/test-ext/v1, https://example.com/test-ext/v2' ,
400+ )
401+ ],
350402 )
351403 assert response .task_id == sample_task_push_notification_config .task_id
352404
@@ -403,7 +455,13 @@ async def test_get_task_callback_with_valid_task(
403455 f'tasks/{ params .id } /'
404456 f'pushNotificationConfigs/{ params .push_notification_config_id } '
405457 ),
406- )
458+ ),
459+ metadata = [
460+ (
461+ HTTP_EXTENSION_HEADER ,
462+ 'https://example.com/test-ext/v1, https://example.com/test-ext/v2' ,
463+ )
464+ ],
407465 )
408466 assert response .task_id == sample_task_push_notification_config .task_id
409467
@@ -435,61 +493,3 @@ async def test_get_task_callback_with_invalid_task(
435493 'Bad TaskPushNotificationConfig resource name'
436494 in exc_info .value .error .message
437495 )
438-
439-
440- @pytest .mark .asyncio
441- async def test_send_message_with_extensions (
442- mock_grpc_stub : AsyncMock ,
443- sample_agent_card : AgentCard ,
444- sample_message_send_params : MessageSendParams ,
445- sample_task : Task ,
446- ) -> None :
447- """Test send_message with extensions."""
448- extensions = ['test_extension_1' , 'test_extension_2' ]
449- channel = AsyncMock ()
450- transport = GrpcTransport (
451- channel = channel , agent_card = sample_agent_card , extensions = extensions
452- )
453- transport .stub = mock_grpc_stub
454-
455- mock_grpc_stub .SendMessage .return_value = a2a_pb2 .SendMessageResponse (
456- task = proto_utils .ToProto .task (sample_task )
457- )
458-
459- await transport .send_message (sample_message_send_params )
460-
461- mock_grpc_stub .SendMessage .assert_awaited_once ()
462- args , _ = mock_grpc_stub .SendMessage .call_args
463- request = args [0 ]
464- metadata = proto_utils .FromProto .metadata (request .metadata )
465- assert HTTP_EXTENSION_HEADER in metadata
466- assert metadata [HTTP_EXTENSION_HEADER ] == 'test_extension_1,test_extension_2'
467-
468-
469- @pytest .mark .asyncio
470- async def test_send_message_streaming_with_extensions (
471- mock_grpc_stub : AsyncMock ,
472- sample_agent_card : AgentCard ,
473- sample_message_send_params : MessageSendParams ,
474- ) -> None :
475- """Test send_message_streaming with extensions."""
476- extensions = ['test_extension_1' , 'test_extension_2' ]
477- channel = AsyncMock ()
478- transport = GrpcTransport (
479- channel = channel , agent_card = sample_agent_card , extensions = extensions
480- )
481- transport .stub = mock_grpc_stub
482-
483- stream = MagicMock ()
484- stream .read = AsyncMock (side_effect = [grpc .aio .EOF ])
485- mock_grpc_stub .SendStreamingMessage .return_value = stream
486-
487- async for _ in transport .send_message_streaming (sample_message_send_params ):
488- pass
489-
490- mock_grpc_stub .SendStreamingMessage .assert_called_once ()
491- args , _ = mock_grpc_stub .SendStreamingMessage .call_args
492- request = args [0 ]
493- metadata = proto_utils .FromProto .metadata (request .metadata )
494- assert HTTP_EXTENSION_HEADER in metadata
495- assert metadata [HTTP_EXTENSION_HEADER ] == 'test_extension_1,test_extension_2'
0 commit comments