diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index f3b584d4..a1b8d565 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -33,9 +33,7 @@ InvalidParamsError, ListTaskPushNotificationConfigParams, Message, - MessageSendConfiguration, MessageSendParams, - PushNotificationConfig, Task, TaskIdParams, TaskNotFoundError, @@ -202,18 +200,6 @@ async def _setup_message_execution( ) task = task_manager.update_with_message(params.message, task) - if self.should_add_push_info(params): - assert self._push_config_store is not None - assert isinstance( - params.configuration, MessageSendConfiguration - ) - assert isinstance( - params.configuration.pushNotificationConfig, - PushNotificationConfig, - ) - await self._push_config_store.set_info( - task.id, params.configuration.pushNotificationConfig - ) # Build request context request_context = await self._request_context_builder.build( @@ -228,6 +214,16 @@ async def _setup_message_execution( # Always assign a task ID. We may not actually upgrade to a task, but # dictating the task ID at this layer is useful for tracking running # agents. + + if ( + self._push_config_store + and params.configuration + and params.configuration.pushNotificationConfig + ): + await self._push_config_store.set_info( + task_id, params.configuration.pushNotificationConfig + ) + queue = await self._queue_manager.create_or_tap(task_id) result_aggregator = ResultAggregator(task_manager) # TODO: to manage the non-blocking flows. @@ -333,16 +329,6 @@ async def on_message_send_stream( if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) - if ( - self._push_config_store - and params.configuration - and params.configuration.pushNotificationConfig - ): - await self._push_config_store.set_info( - task_id, - params.configuration.pushNotificationConfig, - ) - await self._send_push_notification_if_needed( task_id, result_aggregator ) @@ -509,11 +495,3 @@ async def on_delete_task_push_notification_config( await self._push_config_store.delete_info( params.id, params.pushNotificationConfigId ) - - def should_add_push_info(self, params: MessageSendParams) -> bool: - """Determines if push notification info should be set for a task.""" - return bool( - self._push_config_store - and params.configuration - and params.configuration.pushNotificationConfig - ) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index fdf100f7..0b4b677b 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -401,6 +401,90 @@ async def get_current_result(): mock_agent_executor.execute.assert_awaited_once() +@pytest.mark.asyncio +async def test_on_message_send_with_push_notification_no_existing_Task(): + """Test on_message_send for new task sets push notification info if provided.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + task_id = 'push_task_1' + context_id = 'push_ctx_1' + + mock_task_store.get.return_value = ( + None # Simulate new task scenario for TaskManager + ) + + # Mock _request_context_builder.build to return a context with the generated/confirmed IDs + mock_request_context = MagicMock(spec=RequestContext) + mock_request_context.task_id = task_id + mock_request_context.context_id = context_id + mock_request_context_builder.build.return_value = mock_request_context + + request_handler = DefaultRequestHandler( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + push_config_store=mock_push_notification_store, + request_context_builder=mock_request_context_builder, + ) + + push_config = PushNotificationConfig(url='http://callback.com/push') + message_config = MessageSendConfiguration( + pushNotificationConfig=push_config, + acceptedOutputModes=['text/plain'], # Added required field + ) + params = MessageSendParams( + message=Message( + role=Role.user, + messageId='msg_push', + parts=[], + taskId=task_id, + contextId=context_id, + ), + configuration=message_config, + ) + + # Mock ResultAggregator and its consume_and_break_on_interrupt + mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) + final_task_result = create_sample_task( + task_id=task_id, context_id=context_id, status_state=TaskState.completed + ) + mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( + final_task_result, + False, + ) + + # Mock the current_result property to return the final task result + async def get_current_result(): + return final_task_result + + # Configure the 'current_result' property on the type of the mock instance + type(mock_result_aggregator_instance).current_result = PropertyMock( + return_value=get_current_result() + ) + + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.ResultAggregator', + return_value=mock_result_aggregator_instance, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=None, + ), + ): + await request_handler.on_message_send( + params, create_server_call_context() + ) + + mock_push_notification_store.set_info.assert_awaited_once_with( + task_id, push_config + ) + # Other assertions for full flow if needed (e.g., agent execution) + mock_agent_executor.execute.assert_awaited_once() + + @pytest.mark.asyncio async def test_on_message_send_no_result_from_aggregator(): """Test on_message_send when aggregator returns (None, False)."""