diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 4ef889ad..ff86a069 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -55,6 +55,7 @@ TaskState.rejected, } + @trace_class(kind=SpanKind.SERVER) class DefaultRequestHandler(RequestHandler): """Default request handler for all incoming requests. @@ -168,16 +169,17 @@ async def _run_event_stream( await self.agent_executor.execute(request, queue) await queue.close() - async def on_message_send( + async def _setup_message_execution( self, params: MessageSendParams, context: ServerCallContext | None = None, - ) -> Message | Task: - """Default handler for 'message/send' interface (non-streaming). + ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: + """Common setup logic for both streaming and non-streaming message handling. - Starts the agent execution for the message and waits for the final - result (Task or Message). + Returns: + A tuple of (task_manager, task_id, queue, result_aggregator, producer_task) """ + # Create task manager and validate existing task task_manager = TaskManager( task_id=params.message.taskId, context_id=params.message.contextId, @@ -185,6 +187,7 @@ async def on_message_send( initial_message=params.message, ) task: Task | None = await task_manager.get_task() + if task: if task.status.state in TERMINAL_TASK_STATES: raise ServerError( @@ -206,6 +209,8 @@ async def on_message_send( await self._push_notifier.set_info( task.id, params.configuration.pushNotificationConfig ) + + # Build request context request_context = await self._request_context_builder.build( params=params, task_id=task.id if task else None, @@ -222,13 +227,49 @@ async def on_message_send( result_aggregator = ResultAggregator(task_manager) # TODO: to manage the non-blocking flows. producer_task = asyncio.create_task( - self._run_event_stream( - request_context, - queue, - ) + self._run_event_stream(request_context, queue) ) await self._register_producer(task_id, producer_task) + return task_manager, task_id, queue, result_aggregator, producer_task + + def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: + """Validates that agent-generated task ID matches the expected task ID.""" + if task_id != event_task_id: + logger.error( + f'Agent generated task_id={event_task_id} does not match the RequestContext task_id={task_id}.' + ) + raise ServerError( + InternalError(message='Task ID mismatch in agent response') + ) + + async def _send_push_notification_if_needed( + self, task_id: str, result_aggregator: ResultAggregator + ) -> None: + """Sends push notification if configured and task is available.""" + if self._push_notifier and task_id: + latest_task = await result_aggregator.current_result + if isinstance(latest_task, Task): + await self._push_notifier.send_notification(latest_task) + + async def on_message_send( + self, + params: MessageSendParams, + context: ServerCallContext | None = None, + ) -> Message | Task: + """Default handler for 'message/send' interface (non-streaming). + + Starts the agent execution for the message and waits for the final + result (Task or Message). + """ + ( + task_manager, + task_id, + queue, + result_aggregator, + producer_task, + ) = await self._setup_message_execution(params, context) + consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) @@ -241,13 +282,13 @@ async def on_message_send( if not result: raise ServerError(error=InternalError()) - if isinstance(result, Task) and task_id != result.id: - logger.error( - f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.' - ) - raise ServerError( - InternalError(message='Task ID mismatch in agent response') - ) + if isinstance(result, Task): + self._validate_task_id_match(task_id, result.id) + + await self._send_push_notification_if_needed( + task_id, result_aggregator + ) + except Exception as e: logger.error(f'Agent execution failed. Error: {e}') raise @@ -272,85 +313,34 @@ async def on_message_send_stream( Starts the agent execution and yields events as they are produced by the agent. """ - task_manager = TaskManager( - task_id=params.message.taskId, - context_id=params.message.contextId, - task_store=self.task_store, - initial_message=params.message, - ) - task: Task | None = await task_manager.get_task() - - if task: - if task.status.state in TERMINAL_TASK_STATES: - raise ServerError( - error=InvalidParamsError( - message=f'Task {task.id} is in terminal state: {task.status.state}' - ) - ) - - task = task_manager.update_with_message(params.message, task) - if self.should_add_push_info(params): - assert isinstance(self._push_notifier, PushNotifier) - assert isinstance( - params.configuration, MessageSendConfiguration - ) - assert isinstance( - params.configuration.pushNotificationConfig, - PushNotificationConfig, - ) - await self._push_notifier.set_info( - task.id, params.configuration.pushNotificationConfig - ) - else: - queue = EventQueue() - result_aggregator = ResultAggregator(task_manager) - request_context = await self._request_context_builder.build( - params=params, - task_id=task.id if task else None, - context_id=params.message.contextId, - task=task, - context=context, - ) - - task_id = cast('str', request_context.task_id) - queue = await self._queue_manager.create_or_tap(task_id) - producer_task = asyncio.create_task( - self._run_event_stream( - request_context, - queue, - ) - ) - await self._register_producer(task_id, producer_task) + ( + task_manager, + task_id, + queue, + result_aggregator, + producer_task, + ) = await self._setup_message_execution(params, context) try: consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) async for event in result_aggregator.consume_and_emit(consumer): if isinstance(event, Task): - if task_id != event.id: - logger.error( - f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.' - ) - raise ServerError( - InternalError( - message='Task ID mismatch in agent response' - ) - ) - - if ( - self._push_notifier - and params.configuration - and params.configuration.pushNotificationConfig - ): - await self._push_notifier.set_info( - task_id, - params.configuration.pushNotificationConfig, - ) - - if self._push_notifier and task_id: - latest_task = await result_aggregator.current_result - if isinstance(latest_task, Task): - await self._push_notifier.send_notification(latest_task) + self._validate_task_id_match(task_id, event.id) + + if ( + self._push_notifier + and params.configuration + and params.configuration.pushNotificationConfig + ): + await self._push_notifier.set_info( + task_id, + params.configuration.pushNotificationConfig, + ) + + await self._send_push_notification_if_needed( + task_id, result_aggregator + ) yield event finally: await self._cleanup_producer(producer_task, task_id) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 6f67c0f8..dd713752 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -361,6 +361,15 @@ async def test_on_message_send_with_push_notification(): False, ) + # Mock the current_result property to return the final task result + async def get_current_result(): + return final_task_result + + # Configure the 'current_result' property on the type of the mock instance + type(mock_result_aggregator_instance).current_result = PropertyMock( + return_value=get_current_result() + ) + with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -380,6 +389,9 @@ async def test_on_message_send_with_push_notification(): ) mock_push_notifier.set_info.assert_awaited_once_with(task_id, push_config) + mock_push_notifier.send_notification.assert_awaited_once_with( + final_task_result + ) # Other assertions for full flow if needed (e.g., agent execution) mock_agent_executor.execute.assert_awaited_once() @@ -1139,12 +1151,14 @@ async def consume_stream(): texts = [p.root.text for e in events for p in e.status.message.parts] assert texts == ['Event 0', 'Event 1', 'Event 2'] + TERMINAL_TASK_STATES = { TaskState.completed, TaskState.canceled, TaskState.failed, TaskState.rejected, -} +} + @pytest.mark.asyncio @pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)