Skip to content

Commit e52c533

Browse files
Merge pull request #115 from Contextable/codex/filter-assistant-messages-in-adkagent.run
Filter assistant transcripts from ADK message batching
2 parents 6002818 + 22874e1 commit e52c533

File tree

2 files changed

+135
-2
lines changed

2 files changed

+135
-2
lines changed

integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
366366

367367
index = 0
368368
total_unseen = len(unseen_messages)
369+
app_name = self._get_app_name(input)
369370

370371
while index < total_unseen:
371372
current = unseen_messages[index]
@@ -381,10 +382,31 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
381382
yield event
382383
else:
383384
message_batch: List[Any] = []
385+
assistant_message_ids: List[str] = []
386+
384387
while index < total_unseen and getattr(unseen_messages[index], "role", None) != "tool":
385-
message_batch.append(unseen_messages[index])
388+
candidate = unseen_messages[index]
389+
candidate_role = getattr(candidate, "role", None)
390+
391+
if candidate_role == "assistant":
392+
message_id = getattr(candidate, "id", None)
393+
if message_id:
394+
assistant_message_ids.append(message_id)
395+
else:
396+
message_batch.append(candidate)
397+
386398
index += 1
387399

400+
if assistant_message_ids:
401+
self._session_manager.mark_messages_processed(
402+
app_name,
403+
input.thread_id,
404+
assistant_message_ids,
405+
)
406+
407+
if not message_batch:
408+
continue
409+
388410
async for event in self._start_new_execution(input, message_batch=message_batch):
389411
yield event
390412

integrations/adk-middleware/python/tests/test_tool_result_flow.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from ag_ui.core import (
99
RunAgentInput, BaseEvent, EventType, Tool as AGUITool,
10-
UserMessage, ToolMessage, RunStartedEvent, RunFinishedEvent, RunErrorEvent
10+
UserMessage, ToolMessage, RunStartedEvent, RunFinishedEvent, RunErrorEvent,
11+
AssistantMessage, ToolCall, FunctionCall,
1112
)
1213

1314
from ag_ui_adk import ADKAgent
@@ -530,6 +531,116 @@ async def mock_start_new_execution(input_data, *, tool_results=None, message_bat
530531
assert len(tool_messages) == 1
531532
assert getattr(tool_messages[0], 'id', None) == "tool_1"
532533

534+
@pytest.mark.asyncio
535+
async def test_run_skips_assistant_history_before_tool_result(self, ag_ui_adk):
536+
"""Assistant tool call history should not trigger a new execution before tool results arrive."""
537+
assistant_call = AssistantMessage(
538+
id="assistant_tool",
539+
role="assistant",
540+
content=None,
541+
tool_calls=[
542+
ToolCall(
543+
id="call_1",
544+
function=FunctionCall(name="test_tool", arguments="{}"),
545+
)
546+
],
547+
)
548+
549+
tool_result = ToolMessage(
550+
id="tool_result",
551+
role="tool",
552+
content='{"result": "value"}',
553+
tool_call_id="call_1",
554+
)
555+
556+
input_data = RunAgentInput(
557+
thread_id="thread_assistant_tool",
558+
run_id="run_assistant_tool",
559+
messages=[
560+
UserMessage(id="user_initial", role="user", content="Initial question"),
561+
assistant_call,
562+
tool_result,
563+
],
564+
tools=[],
565+
context=[],
566+
state={},
567+
forwarded_props={},
568+
)
569+
570+
# Mark the initial user message as already processed so only the assistant call and tool result are unseen
571+
app_name = ag_ui_adk._get_app_name(input_data)
572+
ag_ui_adk._session_manager.mark_messages_processed(app_name, input_data.thread_id, ["user_initial"])
573+
574+
start_calls = []
575+
576+
async def mock_start_new_execution(input_data, *, tool_results=None, message_batch=None):
577+
start_calls.append((tool_results, message_batch))
578+
579+
call_id = None
580+
if tool_results:
581+
call_id = tool_results[0]['message'].tool_call_id
582+
elif message_batch:
583+
for message in message_batch:
584+
tool_calls = getattr(message, "tool_calls", None)
585+
if tool_calls:
586+
call_id = tool_calls[0].id
587+
break
588+
589+
if call_id:
590+
await ag_ui_adk._add_pending_tool_call_with_context(
591+
input_data.thread_id,
592+
call_id,
593+
ag_ui_adk._get_app_name(input_data),
594+
ag_ui_adk._get_user_id(input_data),
595+
)
596+
597+
yield RunStartedEvent(
598+
type=EventType.RUN_STARTED,
599+
thread_id=input_data.thread_id,
600+
run_id=input_data.run_id,
601+
)
602+
yield RunFinishedEvent(
603+
type=EventType.RUN_FINISHED,
604+
thread_id=input_data.thread_id,
605+
run_id=input_data.run_id,
606+
)
607+
608+
with patch.object(
609+
ag_ui_adk,
610+
'_start_new_execution',
611+
side_effect=mock_start_new_execution,
612+
) as start_mock, patch.object(
613+
ag_ui_adk,
614+
'_handle_tool_result_submission',
615+
wraps=ag_ui_adk._handle_tool_result_submission,
616+
), patch.object(
617+
ag_ui_adk,
618+
'_add_pending_tool_call_with_context',
619+
new_callable=AsyncMock,
620+
) as pending_mock:
621+
events = []
622+
async for event in ag_ui_adk.run(input_data):
623+
events.append(event)
624+
625+
assert [event.type for event in events] == [
626+
EventType.RUN_STARTED,
627+
EventType.RUN_FINISHED,
628+
]
629+
630+
assert start_mock.call_count == 1
631+
assert len(start_calls) == 1
632+
first_tool_results, first_batch = start_calls[0]
633+
assert first_tool_results is not None
634+
assert first_batch is None
635+
assert first_tool_results[0]['message'].id == "tool_result"
636+
637+
assert pending_mock.await_count == 1
638+
pending_call = pending_mock.await_args_list[0]
639+
assert pending_call.args[1] == "call_1"
640+
641+
processed_ids = ag_ui_adk._session_manager.get_processed_message_ids(app_name, input_data.thread_id)
642+
assert "assistant_tool" in processed_ids
643+
533644
@pytest.mark.asyncio
534645
async def test_run_preserves_order_for_user_then_tool(self, ag_ui_adk):
535646
"""Verify user updates are handled before subsequent tool messages."""

0 commit comments

Comments
 (0)