File tree Expand file tree Collapse file tree 3 files changed +17
-3
lines changed
Expand file tree Collapse file tree 3 files changed +17
-3
lines changed Original file line number Diff line number Diff line change 11from collections .abc import AsyncIterator
2+ from typing import Any
23
34from a2a .client .client import (
45 Client ,
@@ -47,6 +48,7 @@ async def send_message(
4748 request : Message ,
4849 * ,
4950 context : ClientCallContext | None = None ,
51+ request_metadata : dict [str , Any ] | None = None ,
5052 ) -> AsyncIterator [ClientEvent | Message ]:
5153 """Sends a message to the agent.
5254
@@ -57,6 +59,7 @@ async def send_message(
5759 Args:
5860 request: The message to send to the agent.
5961 context: The client call context.
62+ request_metadata: Extensions Metadata attached to the request.
6063
6164 Yields:
6265 An async iterator of `ClientEvent` or a final `Message` response.
@@ -70,7 +73,9 @@ async def send_message(
7073 else None
7174 ),
7275 )
73- params = MessageSendParams (message = request , configuration = config )
76+ params = MessageSendParams (
77+ message = request , configuration = config , metadata = request_metadata
78+ )
7479
7580 if not self ._config .streaming or not self ._card .capabilities .streaming :
7681 response = await self ._transport .send_message (
Original file line number Diff line number Diff line change @@ -110,6 +110,7 @@ async def send_message(
110110 request : Message ,
111111 * ,
112112 context : ClientCallContext | None = None ,
113+ request_metadata : dict [str , Any ] | None = None ,
113114 ) -> AsyncIterator [ClientEvent | Message ]:
114115 """Sends a message to the server.
115116
Original file line number Diff line number Diff line change @@ -73,9 +73,14 @@ async def create_stream(*args, **kwargs):
7373
7474 mock_transport .send_message_streaming .return_value = create_stream ()
7575
76- events = [event async for event in base_client .send_message (sample_message )]
76+ meta = {'test' : 1 }
77+ stream = base_client .send_message (sample_message , request_metadata = meta )
78+ events = [event async for event in stream ]
7779
7880 mock_transport .send_message_streaming .assert_called_once ()
81+ assert (
82+ mock_transport .send_message_streaming .call_args [0 ][0 ].metadata == meta
83+ )
7984 assert not mock_transport .send_message .called
8085 assert len (events ) == 1
8186 assert events [0 ][0 ].id == 'task-123'
@@ -92,9 +97,12 @@ async def test_send_message_non_streaming(
9297 status = TaskStatus (state = TaskState .completed ),
9398 )
9499
95- events = [event async for event in base_client .send_message (sample_message )]
100+ meta = {'test' : 1 }
101+ stream = base_client .send_message (sample_message , request_metadata = meta )
102+ events = [event async for event in stream ]
96103
97104 mock_transport .send_message .assert_called_once ()
105+ assert mock_transport .send_message .call_args [0 ][0 ].metadata == meta
98106 assert not mock_transport .send_message_streaming .called
99107 assert len (events ) == 1
100108 assert events [0 ][0 ].id == 'task-456'
You can’t perform that action at this time.
0 commit comments