Skip to content

Commit 692f8d1

Browse files
committed
Fix unit test and refactor to share logic between send and send_streaming methods
1 parent 3e71ae2 commit 692f8d1

File tree

2 files changed

+91
-94
lines changed

2 files changed

+91
-94
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 76 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -169,23 +169,25 @@ async def _run_event_stream(
169169
await self.agent_executor.execute(request, queue)
170170
await queue.close()
171171

172-
async def on_message_send(
172+
async def _setup_message_execution(
173173
self,
174174
params: MessageSendParams,
175175
context: ServerCallContext | None = None,
176-
) -> Message | Task:
177-
"""Default handler for 'message/send' interface (non-streaming).
176+
) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]:
177+
"""Common setup logic for both streaming and non-streaming message handling.
178178
179-
Starts the agent execution for the message and waits for the final
180-
result (Task or Message).
179+
Returns:
180+
A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
181181
"""
182+
# Create task manager and validate existing task
182183
task_manager = TaskManager(
183184
task_id=params.message.taskId,
184185
context_id=params.message.contextId,
185186
task_store=self.task_store,
186187
initial_message=params.message,
187188
)
188189
task: Task | None = await task_manager.get_task()
190+
189191
if task:
190192
if task.status.state in TERMINAL_TASK_STATES:
191193
raise ServerError(
@@ -207,6 +209,8 @@ async def on_message_send(
207209
await self._push_notifier.set_info(
208210
task.id, params.configuration.pushNotificationConfig
209211
)
212+
213+
# Build request context
210214
request_context = await self._request_context_builder.build(
211215
params=params,
212216
task_id=task.id if task else None,
@@ -223,13 +227,49 @@ async def on_message_send(
223227
result_aggregator = ResultAggregator(task_manager)
224228
# TODO: to manage the non-blocking flows.
225229
producer_task = asyncio.create_task(
226-
self._run_event_stream(
227-
request_context,
228-
queue,
229-
)
230+
self._run_event_stream(request_context, queue)
230231
)
231232
await self._register_producer(task_id, producer_task)
232233

234+
return task_manager, task_id, queue, result_aggregator, producer_task
235+
236+
def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
237+
"""Validates that agent-generated task ID matches the expected task ID."""
238+
if task_id != event_task_id:
239+
logger.error(
240+
f'Agent generated task_id={event_task_id} does not match the RequestContext task_id={task_id}.'
241+
)
242+
raise ServerError(
243+
InternalError(message='Task ID mismatch in agent response')
244+
)
245+
246+
async def _send_push_notification_if_needed(
247+
self, task_id: str, result_aggregator: ResultAggregator
248+
) -> None:
249+
"""Sends push notification if configured and task is available."""
250+
if self._push_notifier and task_id:
251+
latest_task = await result_aggregator.current_result
252+
if isinstance(latest_task, Task):
253+
await self._push_notifier.send_notification(latest_task)
254+
255+
async def on_message_send(
256+
self,
257+
params: MessageSendParams,
258+
context: ServerCallContext | None = None,
259+
) -> Message | Task:
260+
"""Default handler for 'message/send' interface (non-streaming).
261+
262+
Starts the agent execution for the message and waits for the final
263+
result (Task or Message).
264+
"""
265+
(
266+
task_manager,
267+
task_id,
268+
queue,
269+
result_aggregator,
270+
producer_task,
271+
) = await self._setup_message_execution(params, context)
272+
233273
consumer = EventConsumer(queue)
234274
producer_task.add_done_callback(consumer.agent_task_callback)
235275

@@ -242,18 +282,12 @@ async def on_message_send(
242282
if not result:
243283
raise ServerError(error=InternalError())
244284

245-
if isinstance(result, Task) and task_id != result.id:
246-
logger.error(
247-
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
248-
)
249-
raise ServerError(
250-
InternalError(message='Task ID mismatch in agent response')
251-
)
285+
if isinstance(result, Task):
286+
self._validate_task_id_match(task_id, result.id)
252287

253-
if self._push_notifier and task_id:
254-
latest_task = await result_aggregator.current_result
255-
if isinstance(latest_task, Task):
256-
await self._push_notifier.send_notification(latest_task)
288+
await self._send_push_notification_if_needed(
289+
task_id, result_aggregator
290+
)
257291

258292
finally:
259293
if interrupted:
@@ -276,85 +310,34 @@ async def on_message_send_stream(
276310
Starts the agent execution and yields events as they are produced
277311
by the agent.
278312
"""
279-
task_manager = TaskManager(
280-
task_id=params.message.taskId,
281-
context_id=params.message.contextId,
282-
task_store=self.task_store,
283-
initial_message=params.message,
284-
)
285-
task: Task | None = await task_manager.get_task()
286-
287-
if task:
288-
if task.status.state in TERMINAL_TASK_STATES:
289-
raise ServerError(
290-
error=InvalidParamsError(
291-
message=f'Task {task.id} is in terminal state: {task.status.state}'
292-
)
293-
)
294-
295-
task = task_manager.update_with_message(params.message, task)
296-
if self.should_add_push_info(params):
297-
assert isinstance(self._push_notifier, PushNotifier)
298-
assert isinstance(
299-
params.configuration, MessageSendConfiguration
300-
)
301-
assert isinstance(
302-
params.configuration.pushNotificationConfig,
303-
PushNotificationConfig,
304-
)
305-
await self._push_notifier.set_info(
306-
task.id, params.configuration.pushNotificationConfig
307-
)
308-
else:
309-
queue = EventQueue()
310-
result_aggregator = ResultAggregator(task_manager)
311-
request_context = await self._request_context_builder.build(
312-
params=params,
313-
task_id=task.id if task else None,
314-
context_id=params.message.contextId,
315-
task=task,
316-
context=context,
317-
)
318-
319-
task_id = cast('str', request_context.task_id)
320-
queue = await self._queue_manager.create_or_tap(task_id)
321-
producer_task = asyncio.create_task(
322-
self._run_event_stream(
323-
request_context,
324-
queue,
325-
)
326-
)
327-
await self._register_producer(task_id, producer_task)
313+
(
314+
task_manager,
315+
task_id,
316+
queue,
317+
result_aggregator,
318+
producer_task,
319+
) = await self._setup_message_execution(params, context)
328320

329321
try:
330322
consumer = EventConsumer(queue)
331323
producer_task.add_done_callback(consumer.agent_task_callback)
332324
async for event in result_aggregator.consume_and_emit(consumer):
333325
if isinstance(event, Task):
334-
if task_id != event.id:
335-
logger.error(
336-
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
337-
)
338-
raise ServerError(
339-
InternalError(
340-
message='Task ID mismatch in agent response'
341-
)
342-
)
343-
344-
if (
345-
self._push_notifier
346-
and params.configuration
347-
and params.configuration.pushNotificationConfig
348-
):
349-
await self._push_notifier.set_info(
350-
task_id,
351-
params.configuration.pushNotificationConfig,
352-
)
353-
354-
if self._push_notifier and task_id:
355-
latest_task = await result_aggregator.current_result
356-
if isinstance(latest_task, Task):
357-
await self._push_notifier.send_notification(latest_task)
326+
self._validate_task_id_match(task_id, event.id)
327+
328+
if (
329+
self._push_notifier
330+
and params.configuration
331+
and params.configuration.pushNotificationConfig
332+
):
333+
await self._push_notifier.set_info(
334+
task_id,
335+
params.configuration.pushNotificationConfig,
336+
)
337+
338+
await self._send_push_notification_if_needed(
339+
task_id, result_aggregator
340+
)
358341
yield event
359342
finally:
360343
await self._cleanup_producer(producer_task, task_id)

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,15 @@ async def test_on_message_send_with_push_notification():
361361
False,
362362
)
363363

364+
# Mock the current_result property to return the final task result
365+
async def get_current_result():
366+
return final_task_result
367+
368+
# Configure the 'current_result' property on the type of the mock instance
369+
type(mock_result_aggregator_instance).current_result = PropertyMock(
370+
return_value=get_current_result()
371+
)
372+
364373
with (
365374
patch(
366375
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
@@ -380,6 +389,9 @@ async def test_on_message_send_with_push_notification():
380389
)
381390

382391
mock_push_notifier.set_info.assert_awaited_once_with(task_id, push_config)
392+
mock_push_notifier.send_notification.assert_awaited_once_with(
393+
final_task_result
394+
)
383395
# Other assertions for full flow if needed (e.g., agent execution)
384396
mock_agent_executor.execute.assert_awaited_once()
385397

@@ -1139,12 +1151,14 @@ async def consume_stream():
11391151
texts = [p.root.text for e in events for p in e.status.message.parts]
11401152
assert texts == ['Event 0', 'Event 1', 'Event 2']
11411153

1154+
11421155
TERMINAL_TASK_STATES = {
11431156
TaskState.completed,
11441157
TaskState.canceled,
11451158
TaskState.failed,
11461159
TaskState.rejected,
1147-
}
1160+
}
1161+
11481162

11491163
@pytest.mark.asyncio
11501164
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)

0 commit comments

Comments
 (0)