Skip to content

Commit 12b4a1d

Browse files
authored
feat: add metadata to send message request (#532)
Extended `client.send_message` to take `metadata` parameter which gets attached to `MessageSendParams`.
1 parent 5268218 commit 12b4a1d

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

src/a2a/client/base_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import AsyncIterator
2+
from typing import Any
23

34
from a2a.client.client import (
45
Client,
@@ -47,6 +48,7 @@ async def send_message(
4748
request: Message,
4849
*,
4950
context: ClientCallContext | None = None,
51+
request_metadata: dict[str, Any] | None = None,
5052
) -> AsyncIterator[ClientEvent | Message]:
5153
"""Sends a message to the agent.
5254
@@ -57,6 +59,7 @@ async def send_message(
5759
Args:
5860
request: The message to send to the agent.
5961
context: The client call context.
62+
request_metadata: Extensions Metadata attached to the request.
6063
6164
Yields:
6265
An async iterator of `ClientEvent` or a final `Message` response.
@@ -70,7 +73,9 @@ async def send_message(
7073
else None
7174
),
7275
)
73-
params = MessageSendParams(message=request, configuration=config)
76+
params = MessageSendParams(
77+
message=request, configuration=config, metadata=request_metadata
78+
)
7479

7580
if not self._config.streaming or not self._card.capabilities.streaming:
7681
response = await self._transport.send_message(

src/a2a/client/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ async def send_message(
110110
request: Message,
111111
*,
112112
context: ClientCallContext | None = None,
113+
request_metadata: dict[str, Any] | None = None,
113114
) -> AsyncIterator[ClientEvent | Message]:
114115
"""Sends a message to the server.
115116

tests/client/test_base_client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,14 @@ async def create_stream(*args, **kwargs):
7373

7474
mock_transport.send_message_streaming.return_value = create_stream()
7575

76-
events = [event async for event in base_client.send_message(sample_message)]
76+
meta = {'test': 1}
77+
stream = base_client.send_message(sample_message, request_metadata=meta)
78+
events = [event async for event in stream]
7779

7880
mock_transport.send_message_streaming.assert_called_once()
81+
assert (
82+
mock_transport.send_message_streaming.call_args[0][0].metadata == meta
83+
)
7984
assert not mock_transport.send_message.called
8085
assert len(events) == 1
8186
assert events[0][0].id == 'task-123'
@@ -92,9 +97,12 @@ async def test_send_message_non_streaming(
9297
status=TaskStatus(state=TaskState.completed),
9398
)
9499

95-
events = [event async for event in base_client.send_message(sample_message)]
100+
meta = {'test': 1}
101+
stream = base_client.send_message(sample_message, request_metadata=meta)
102+
events = [event async for event in stream]
96103

97104
mock_transport.send_message.assert_called_once()
105+
assert mock_transport.send_message.call_args[0][0].metadata == meta
98106
assert not mock_transport.send_message_streaming.called
99107
assert len(events) == 1
100108
assert events[0][0].id == 'task-456'

0 commit comments

Comments
 (0)