diff --git a/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py b/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py index 2a1e4e161..c06c4e4e2 100644 --- a/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py +++ b/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py @@ -410,15 +410,33 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]: assistant_message_ids, ) - if not message_batch: - if assistant_message_ids: + next_role = ( + getattr(unseen_messages[index], "role", None) + if index < total_unseen + else None + ) + has_upcoming_tool_batch = next_role == "tool" + + if message_batch: + if assistant_message_ids and has_upcoming_tool_batch: + message_ids = self._collect_message_ids(message_batch) + if message_ids: + self._session_manager.mark_messages_processed( + app_name, + input.thread_id, + message_ids, + ) skip_tool_message_batch = True - continue - else: + continue + skip_tool_message_batch = False - async for event in self._start_new_execution(input, message_batch=message_batch): - yield event + async for event in self._start_new_execution(input, message_batch=message_batch): + yield event + else: + if assistant_message_ids: + skip_tool_message_batch = True + continue async def _ensure_session_exists(self, app_name: str, user_id: str, session_id: str, initial_state: dict): """Ensure a session exists, creating it if necessary via session manager.""" diff --git a/integrations/adk-middleware/python/tests/test_tool_result_flow.py b/integrations/adk-middleware/python/tests/test_tool_result_flow.py index 168bb3295..6620d072e 100644 --- a/integrations/adk-middleware/python/tests/test_tool_result_flow.py +++ b/integrations/adk-middleware/python/tests/test_tool_result_flow.py @@ -641,6 +641,102 @@ async def mock_start_new_execution(input_data, *, tool_results=None, message_bat processed_ids = ag_ui_adk._session_manager.get_processed_message_ids(app_name, input_data.thread_id) assert "assistant_tool" in processed_ids + @pytest.mark.asyncio + async def test_run_does_not_restart_execution_for_replayed_prompt_with_tool_results(self, ag_ui_adk): + """Replayed prompts with new IDs should not spawn duplicate executions when tool results are present.""" + replayed_user = UserMessage( + id="user_replay_v2", + role="user", + content="Check status", + ) + + assistant_call = AssistantMessage( + id="assistant_tool_v2", + role="assistant", + content=None, + tool_calls=[ + ToolCall( + id="call_replay", + function=FunctionCall(name="test_tool", arguments="{}"), + ) + ], + ) + + tool_result = ToolMessage( + id="tool_result_v2", + role="tool", + content='{"result": "value"}', + tool_call_id="call_replay", + ) + + input_data = RunAgentInput( + thread_id="thread_replayed_prompt", + run_id="run_replayed_prompt", + messages=[ + replayed_user, + assistant_call, + tool_result, + ], + tools=[], + context=[], + state={}, + forwarded_props={}, + ) + + app_name = ag_ui_adk._get_app_name(input_data) + user_id = ag_ui_adk._get_user_id(input_data) + await ag_ui_adk._ensure_session_exists(app_name, user_id, input_data.thread_id, input_data.state) + await ag_ui_adk._add_pending_tool_call_with_context( + input_data.thread_id, + "call_replay", + app_name, + user_id, + ) + + start_calls = [] + + async def mock_start_new_execution(input_data, *, tool_results=None, message_batch=None): + start_calls.append((tool_results, message_batch)) + yield RunStartedEvent( + type=EventType.RUN_STARTED, + thread_id=input_data.thread_id, + run_id=input_data.run_id, + ) + yield RunFinishedEvent( + type=EventType.RUN_FINISHED, + thread_id=input_data.thread_id, + run_id=input_data.run_id, + ) + + with patch.object( + ag_ui_adk, + '_start_new_execution', + side_effect=mock_start_new_execution, + ) as start_mock, patch.object( + ag_ui_adk, + '_handle_tool_result_submission', + wraps=ag_ui_adk._handle_tool_result_submission, + ) as handle_mock: + events = [] + async for event in ag_ui_adk.run(input_data): + events.append(event) + + assert [event.type for event in events] == [ + EventType.RUN_STARTED, + EventType.RUN_FINISHED, + ] + + assert start_mock.call_count == 1 + assert len(start_calls) == 1 + tool_results, message_batch = start_calls[0] + assert tool_results is not None + assert message_batch is None + assert handle_mock.call_count == 1 + assert handle_mock.call_args.kwargs.get('include_message_batch') is False + + processed_ids = ag_ui_adk._session_manager.get_processed_message_ids(app_name, input_data.thread_id) + assert "user_replay_v2" in processed_ids + @pytest.mark.asyncio async def test_run_preserves_order_for_user_then_tool(self, ag_ui_adk): """Verify user updates are handled before subsequent tool messages."""