Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 79 additions & 89 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
TaskState.rejected,
}


@trace_class(kind=SpanKind.SERVER)
class DefaultRequestHandler(RequestHandler):
"""Default request handler for all incoming requests.
Expand Down Expand Up @@ -168,23 +169,25 @@ async def _run_event_stream(
await self.agent_executor.execute(request, queue)
await queue.close()

async def on_message_send(
async def _setup_message_execution(
self,
params: MessageSendParams,
context: ServerCallContext | None = None,
) -> Message | Task:
"""Default handler for 'message/send' interface (non-streaming).
) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]:
"""Common setup logic for both streaming and non-streaming message handling.

Starts the agent execution for the message and waits for the final
result (Task or Message).
Returns:
A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
"""
# Create task manager and validate existing task
task_manager = TaskManager(
task_id=params.message.taskId,
context_id=params.message.contextId,
task_store=self.task_store,
initial_message=params.message,
)
task: Task | None = await task_manager.get_task()

if task:
if task.status.state in TERMINAL_TASK_STATES:
raise ServerError(
Expand All @@ -206,6 +209,8 @@ async def on_message_send(
await self._push_notifier.set_info(
task.id, params.configuration.pushNotificationConfig
)

# Build request context
request_context = await self._request_context_builder.build(
params=params,
task_id=task.id if task else None,
Expand All @@ -222,13 +227,49 @@ async def on_message_send(
result_aggregator = ResultAggregator(task_manager)
# TODO: to manage the non-blocking flows.
producer_task = asyncio.create_task(
self._run_event_stream(
request_context,
queue,
)
self._run_event_stream(request_context, queue)
)
await self._register_producer(task_id, producer_task)

return task_manager, task_id, queue, result_aggregator, producer_task

def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
"""Validates that agent-generated task ID matches the expected task ID."""
if task_id != event_task_id:
logger.error(
f'Agent generated task_id={event_task_id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(message='Task ID mismatch in agent response')
)

async def _send_push_notification_if_needed(
self, task_id: str, result_aggregator: ResultAggregator
) -> None:
"""Sends push notification if configured and task is available."""
if self._push_notifier and task_id:
latest_task = await result_aggregator.current_result
if isinstance(latest_task, Task):
await self._push_notifier.send_notification(latest_task)

async def on_message_send(
self,
params: MessageSendParams,
context: ServerCallContext | None = None,
) -> Message | Task:
"""Default handler for 'message/send' interface (non-streaming).

Starts the agent execution for the message and waits for the final
result (Task or Message).
"""
(
task_manager,
task_id,
queue,
result_aggregator,
producer_task,
) = await self._setup_message_execution(params, context)

consumer = EventConsumer(queue)
producer_task.add_done_callback(consumer.agent_task_callback)

Expand All @@ -241,13 +282,13 @@ async def on_message_send(
if not result:
raise ServerError(error=InternalError())

if isinstance(result, Task) and task_id != result.id:
logger.error(
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(message='Task ID mismatch in agent response')
)
if isinstance(result, Task):
self._validate_task_id_match(task_id, result.id)

await self._send_push_notification_if_needed(
task_id, result_aggregator
)

except Exception as e:
logger.error(f'Agent execution failed. Error: {e}')
raise
Expand All @@ -272,85 +313,34 @@ async def on_message_send_stream(
Starts the agent execution and yields events as they are produced
by the agent.
"""
task_manager = TaskManager(
task_id=params.message.taskId,
context_id=params.message.contextId,
task_store=self.task_store,
initial_message=params.message,
)
task: Task | None = await task_manager.get_task()

if task:
if task.status.state in TERMINAL_TASK_STATES:
raise ServerError(
error=InvalidParamsError(
message=f'Task {task.id} is in terminal state: {task.status.state}'
)
)

task = task_manager.update_with_message(params.message, task)
if self.should_add_push_info(params):
assert isinstance(self._push_notifier, PushNotifier)
assert isinstance(
params.configuration, MessageSendConfiguration
)
assert isinstance(
params.configuration.pushNotificationConfig,
PushNotificationConfig,
)
await self._push_notifier.set_info(
task.id, params.configuration.pushNotificationConfig
)
else:
queue = EventQueue()
result_aggregator = ResultAggregator(task_manager)
request_context = await self._request_context_builder.build(
params=params,
task_id=task.id if task else None,
context_id=params.message.contextId,
task=task,
context=context,
)

task_id = cast('str', request_context.task_id)
queue = await self._queue_manager.create_or_tap(task_id)
producer_task = asyncio.create_task(
self._run_event_stream(
request_context,
queue,
)
)
await self._register_producer(task_id, producer_task)
(
task_manager,
task_id,
queue,
result_aggregator,
producer_task,
) = await self._setup_message_execution(params, context)

try:
consumer = EventConsumer(queue)
producer_task.add_done_callback(consumer.agent_task_callback)
async for event in result_aggregator.consume_and_emit(consumer):
if isinstance(event, Task):
if task_id != event.id:
logger.error(
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(
message='Task ID mismatch in agent response'
)
)

if (
self._push_notifier
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_notifier.set_info(
task_id,
params.configuration.pushNotificationConfig,
)

if self._push_notifier and task_id:
latest_task = await result_aggregator.current_result
if isinstance(latest_task, Task):
await self._push_notifier.send_notification(latest_task)
self._validate_task_id_match(task_id, event.id)

if (
self._push_notifier
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_notifier.set_info(
task_id,
params.configuration.pushNotificationConfig,
)

await self._send_push_notification_if_needed(
task_id, result_aggregator
)
yield event
finally:
await self._cleanup_producer(producer_task, task_id)
Expand Down
16 changes: 15 additions & 1 deletion tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,15 @@ async def test_on_message_send_with_push_notification():
False,
)

# Mock the current_result property to return the final task result
async def get_current_result():
return final_task_result

# Configure the 'current_result' property on the type of the mock instance
type(mock_result_aggregator_instance).current_result = PropertyMock(
return_value=get_current_result()
)

with (
patch(
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
Expand All @@ -380,6 +389,9 @@ async def test_on_message_send_with_push_notification():
)

mock_push_notifier.set_info.assert_awaited_once_with(task_id, push_config)
mock_push_notifier.send_notification.assert_awaited_once_with(
final_task_result
)
# Other assertions for full flow if needed (e.g., agent execution)
mock_agent_executor.execute.assert_awaited_once()

Expand Down Expand Up @@ -1139,12 +1151,14 @@ async def consume_stream():
texts = [p.root.text for e in events for p in e.status.message.parts]
assert texts == ['Event 0', 'Event 1', 'Event 2']


TERMINAL_TASK_STATES = {
TaskState.completed,
TaskState.canceled,
TaskState.failed,
TaskState.rejected,
}
}


@pytest.mark.asyncio
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
Expand Down
Loading