Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,15 +410,33 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
assistant_message_ids,
)

if not message_batch:
if assistant_message_ids:
next_role = (
getattr(unseen_messages[index], "role", None)
if index < total_unseen
else None
)
has_upcoming_tool_batch = next_role == "tool"

if message_batch:
if assistant_message_ids and has_upcoming_tool_batch:
message_ids = self._collect_message_ids(message_batch)
if message_ids:
self._session_manager.mark_messages_processed(
app_name,
input.thread_id,
message_ids,
)
skip_tool_message_batch = True
continue
else:
continue

skip_tool_message_batch = False

async for event in self._start_new_execution(input, message_batch=message_batch):
yield event
async for event in self._start_new_execution(input, message_batch=message_batch):
yield event
else:
if assistant_message_ids:
skip_tool_message_batch = True
continue

async def _ensure_session_exists(self, app_name: str, user_id: str, session_id: str, initial_state: dict):
"""Ensure a session exists, creating it if necessary via session manager."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,102 @@ async def mock_start_new_execution(input_data, *, tool_results=None, message_bat
processed_ids = ag_ui_adk._session_manager.get_processed_message_ids(app_name, input_data.thread_id)
assert "assistant_tool" in processed_ids

@pytest.mark.asyncio
async def test_run_does_not_restart_execution_for_replayed_prompt_with_tool_results(self, ag_ui_adk):
"""Replayed prompts with new IDs should not spawn duplicate executions when tool results are present."""
replayed_user = UserMessage(
id="user_replay_v2",
role="user",
content="Check status",
)

assistant_call = AssistantMessage(
id="assistant_tool_v2",
role="assistant",
content=None,
tool_calls=[
ToolCall(
id="call_replay",
function=FunctionCall(name="test_tool", arguments="{}"),
)
],
)

tool_result = ToolMessage(
id="tool_result_v2",
role="tool",
content='{"result": "value"}',
tool_call_id="call_replay",
)

input_data = RunAgentInput(
thread_id="thread_replayed_prompt",
run_id="run_replayed_prompt",
messages=[
replayed_user,
assistant_call,
tool_result,
],
tools=[],
context=[],
state={},
forwarded_props={},
)

app_name = ag_ui_adk._get_app_name(input_data)
user_id = ag_ui_adk._get_user_id(input_data)
await ag_ui_adk._ensure_session_exists(app_name, user_id, input_data.thread_id, input_data.state)
await ag_ui_adk._add_pending_tool_call_with_context(
input_data.thread_id,
"call_replay",
app_name,
user_id,
)

start_calls = []

async def mock_start_new_execution(input_data, *, tool_results=None, message_batch=None):
start_calls.append((tool_results, message_batch))
yield RunStartedEvent(
type=EventType.RUN_STARTED,
thread_id=input_data.thread_id,
run_id=input_data.run_id,
)
yield RunFinishedEvent(
type=EventType.RUN_FINISHED,
thread_id=input_data.thread_id,
run_id=input_data.run_id,
)

with patch.object(
ag_ui_adk,
'_start_new_execution',
side_effect=mock_start_new_execution,
) as start_mock, patch.object(
ag_ui_adk,
'_handle_tool_result_submission',
wraps=ag_ui_adk._handle_tool_result_submission,
) as handle_mock:
events = []
async for event in ag_ui_adk.run(input_data):
events.append(event)

assert [event.type for event in events] == [
EventType.RUN_STARTED,
EventType.RUN_FINISHED,
]

assert start_mock.call_count == 1
assert len(start_calls) == 1
tool_results, message_batch = start_calls[0]
assert tool_results is not None
assert message_batch is None
assert handle_mock.call_count == 1
assert handle_mock.call_args.kwargs.get('include_message_batch') is False

processed_ids = ag_ui_adk._session_manager.get_processed_message_ids(app_name, input_data.thread_id)
assert "user_replay_v2" in processed_ids

@pytest.mark.asyncio
async def test_run_preserves_order_for_user_then_tool(self, ag_ui_adk):
"""Verify user updates are handled before subsequent tool messages."""
Expand Down
Loading