Skip to content

Commit d60a15e

Browse files
Preserve system messages when filtering assistant transcripts
1 parent be47f9b commit d60a15e

File tree

1 file changed

+38
-14
lines changed
  • integrations/adk-middleware/python/src/ag_ui_adk

1 file changed

+38
-14
lines changed

integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
367367
index = 0
368368
total_unseen = len(unseen_messages)
369369
app_name = self._get_app_name(input)
370+
skip_tool_message_batch = False
370371

371372
while index < total_unseen:
372373
current = unseen_messages[index]
@@ -378,8 +379,13 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
378379
tool_batch.append(unseen_messages[index])
379380
index += 1
380381

381-
async for event in self._handle_tool_result_submission(input, tool_messages=tool_batch):
382+
async for event in self._handle_tool_result_submission(
383+
input,
384+
tool_messages=tool_batch,
385+
include_message_batch=not skip_tool_message_batch,
386+
):
382387
yield event
388+
skip_tool_message_batch = False
383389
else:
384390
message_batch: List[Any] = []
385391
assistant_message_ids: List[str] = []
@@ -405,7 +411,11 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
405411
)
406412

407413
if not message_batch:
414+
if assistant_message_ids:
415+
skip_tool_message_batch = True
408416
continue
417+
else:
418+
skip_tool_message_batch = False
409419

410420
async for event in self._start_new_execution(input, message_batch=message_batch):
411421
yield event
@@ -498,17 +508,22 @@ async def _is_tool_result_submission(
498508
if not unseen_messages:
499509
return False
500510

501-
return all(getattr(message, "role", None) == "tool" for message in unseen_messages)
511+
last_message = unseen_messages[-1]
512+
return getattr(last_message, "role", None) == "tool"
502513

503514
async def _handle_tool_result_submission(
504515
self,
505516
input: RunAgentInput,
517+
*,
506518
tool_messages: Optional[List[Any]] = None,
519+
include_message_batch: bool = True,
507520
) -> AsyncGenerator[BaseEvent, None]:
508521
"""Handle tool result submission for existing execution.
509522
510523
Args:
511524
input: The run input containing tool results
525+
tool_messages: Optional pre-filtered tool messages to consider
526+
include_message_batch: Whether to forward the candidate messages to the execution
512527
513528
Yields:
514529
AG-UI events from continued execution
@@ -548,7 +563,12 @@ async def _handle_tool_result_submission(
548563
# Since all tools are long-running, all tool results are standalone
549564
# and should start new executions with the tool results
550565
logger.info(f"Starting new execution for tool result in thread {thread_id}")
551-
async for event in self._start_new_execution(input, tool_results=tool_results):
566+
message_batch = candidate_messages if include_message_batch else None
567+
async for event in self._start_new_execution(
568+
input,
569+
tool_results=tool_results,
570+
message_batch=message_batch,
571+
):
552572
yield event
553573

554574
except Exception as e:
@@ -896,17 +916,21 @@ def instruction_provider_wrapper_sync(*args, **kwargs):
896916

897917
# Create background task
898918
logger.debug(f"Creating background task for thread {input.thread_id}")
899-
task = asyncio.create_task(
900-
self._run_adk_in_background(
901-
input=input,
902-
adk_agent=adk_agent,
903-
user_id=user_id,
904-
app_name=app_name,
905-
event_queue=event_queue,
906-
tool_results=tool_results,
907-
message_batch=message_batch,
908-
)
909-
)
919+
run_kwargs = {
920+
"input": input,
921+
"adk_agent": adk_agent,
922+
"user_id": user_id,
923+
"app_name": app_name,
924+
"event_queue": event_queue,
925+
}
926+
927+
if tool_results is not None:
928+
run_kwargs["tool_results"] = tool_results
929+
930+
if message_batch is not None:
931+
run_kwargs["message_batch"] = message_batch
932+
933+
task = asyncio.create_task(self._run_adk_in_background(**run_kwargs))
910934
logger.debug(f"Background task created for thread {input.thread_id}: {task}")
911935

912936
return ExecutionState(

0 commit comments

Comments
 (0)