Skip to content

Commit 41f4456

Browse files
committed
test(client): add tests for call-site MessageSendConfiguration merge behavior in BaseClient.send_message
1 parent 051ab20 commit 41f4456

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

src/a2a/client/base_client.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ async def send_message(
4646
self,
4747
request: Message,
4848
*,
49+
configuration: MessageSendConfiguration | None = None,
4950
context: ClientCallContext | None = None,
5051
) -> AsyncIterator[ClientEvent | Message]:
5152
"""Sends a message to the agent.
@@ -56,12 +57,13 @@ async def send_message(
5657
5758
Args:
5859
request: The message to send to the agent.
60+
configuration: Optional per-call overrides for message sending behavior.
5961
context: The client call context.
6062
6163
Yields:
6264
An async iterator of `ClientEvent` or a final `Message` response.
6365
"""
64-
config = MessageSendConfiguration(
66+
base_config = MessageSendConfiguration(
6567
accepted_output_modes=self._config.accepted_output_modes,
6668
blocking=not self._config.polling,
6769
push_notification_config=(
@@ -70,6 +72,12 @@ async def send_message(
7072
else None
7173
),
7274
)
75+
if configuration is not None:
76+
overrides = configuration.model_dump(exclude_unset=True, exclude_none=True)
77+
config = base_config.model_copy(update=overrides)
78+
else:
79+
config = base_config
80+
7381
params = MessageSendParams(message=request, configuration=config)
7482

7583
if not self._config.streaming or not self._card.capabilities.streaming:

tests/client/test_base_client.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from a2a.types import (
99
AgentCapabilities,
1010
AgentCard,
11+
MessageSendConfiguration,
1112
Message,
1213
Part,
1314
Role,
@@ -115,3 +116,82 @@ async def test_send_message_non_streaming_agent_capability_false(
115116
assert not mock_transport.send_message_streaming.called
116117
assert len(events) == 1
117118
assert events[0][0].id == 'task-789'
119+
120+
121+
@pytest.mark.asyncio
122+
async def test_send_message_uses_callsite_configuration_partial_override_non_streaming(
123+
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
124+
):
125+
base_client._config.streaming = False
126+
mock_transport.send_message.return_value = Task(
127+
id='task-cfg-ns-1',
128+
context_id='ctx-cfg-ns-1',
129+
status=TaskStatus(state=TaskState.completed),
130+
)
131+
132+
cfg = MessageSendConfiguration(history_length=2)
133+
events = [ev async for ev in base_client.send_message(sample_message, configuration=cfg)]
134+
135+
mock_transport.send_message.assert_called_once()
136+
assert not mock_transport.send_message_streaming.called
137+
assert len(events) == 1 and events[0][0].id == 'task-cfg-ns-1'
138+
139+
params = mock_transport.send_message.await_args.args[0]
140+
assert params.configuration.history_length == 2
141+
assert params.configuration.blocking == (not base_client._config.polling)
142+
assert params.configuration.accepted_output_modes == base_client._config.accepted_output_modes
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_send_message_ignores_none_fields_in_callsite_configuration_non_streaming(
147+
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
148+
):
149+
base_client._config.streaming = False
150+
mock_transport.send_message.return_value = Task(
151+
id='task-cfg-ns-2',
152+
context_id='ctx-cfg-ns-2',
153+
status=TaskStatus(state=TaskState.completed),
154+
)
155+
156+
cfg = MessageSendConfiguration(history_length=None, blocking=None)
157+
events = [ev async for ev in base_client.send_message(sample_message, configuration=cfg)]
158+
159+
mock_transport.send_message.assert_called_once()
160+
assert len(events) == 1 and events[0][0].id == 'task-cfg-ns-2'
161+
162+
params = mock_transport.send_message.await_args.args[0]
163+
assert params.configuration.history_length is None
164+
assert params.configuration.blocking == (not base_client._config.polling)
165+
assert params.configuration.accepted_output_modes == base_client._config.accepted_output_modes
166+
167+
168+
@pytest.mark.asyncio
169+
async def test_send_message_uses_callsite_configuration_partial_override_streaming(
170+
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
171+
):
172+
base_client._config.streaming = True
173+
base_client._card.capabilities.streaming = True
174+
175+
async def create_stream(*args, **kwargs):
176+
yield Task(
177+
id='task-cfg-s-1',
178+
context_id='ctx-cfg-s-1',
179+
status=TaskStatus(state=TaskState.completed),
180+
)
181+
182+
mock_transport.send_message_streaming.return_value = create_stream()
183+
184+
cfg = MessageSendConfiguration(history_length=0)
185+
events = [ev async for ev in base_client.send_message(sample_message, configuration=cfg)]
186+
187+
mock_transport.send_message_streaming.assert_called_once()
188+
assert not mock_transport.send_message.called
189+
assert len(events) == 1
190+
first = events[0][0] if isinstance(events[0], tuple) else events[0]
191+
assert first.id == 'task-cfg-s-1'
192+
193+
params = mock_transport.send_message_streaming.call_args.args[0]
194+
assert params.configuration.history_length == 0
195+
assert params.configuration.blocking == (not base_client._config.polling)
196+
assert params.configuration.accepted_output_modes == base_client._config.accepted_output_modes
197+

0 commit comments

Comments
 (0)