Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 170 additions & 76 deletions integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,15 +356,37 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]:
Yields:
AG-UI protocol events
"""
# Check if this is a tool result submission for an existing execution
if self._is_tool_result_submission(input):
# Handle tool results for existing execution
async for event in self._handle_tool_result_submission(input):
yield event
else:
# Start new execution for regular requests
unseen_messages = await self._get_unseen_messages(input)

if not unseen_messages:
# No unseen messages – fall through to normal execution handling
async for event in self._start_new_execution(input):
yield event
return

index = 0
total_unseen = len(unseen_messages)

while index < total_unseen:
current = unseen_messages[index]
role = getattr(current, "role", None)

if role == "tool":
tool_batch: List[Any] = []
while index < total_unseen and getattr(unseen_messages[index], "role", None) == "tool":
tool_batch.append(unseen_messages[index])
index += 1

async for event in self._handle_tool_result_submission(input, tool_messages=tool_batch):
yield event
else:
message_batch: List[Any] = []
while index < total_unseen and getattr(unseen_messages[index], "role", None) != "tool":
message_batch.append(unseen_messages[index])
index += 1

async for event in self._start_new_execution(input, message_batch=message_batch):
yield event

async def _ensure_session_exists(self, app_name: str, user_id: str, session_id: str, initial_state: dict):
"""Ensure a session exists, creating it if necessary via session manager."""
Expand All @@ -389,40 +411,77 @@ async def _ensure_session_exists(self, app_name: str, user_id: str, session_id:
logger.error(f"Failed to ensure session {session_id}: {e}")
raise

async def _convert_latest_message(self, input: RunAgentInput) -> Optional[types.Content]:
async def _convert_latest_message(
self,
input: RunAgentInput,
messages: Optional[List[Any]] = None,
) -> Optional[types.Content]:
"""Convert the latest user message to ADK Content format."""
if not input.messages:
target_messages = messages if messages is not None else input.messages

if not target_messages:
return None

# Get the latest user message
for message in reversed(input.messages):
if message.role == "user" and message.content:
for message in reversed(target_messages):
if getattr(message, "role", None) == "user" and getattr(message, "content", None):
return types.Content(
role="user",
parts=[types.Part(text=message.content)]
)

return None


def _is_tool_result_submission(self, input: RunAgentInput) -> bool:
async def _get_unseen_messages(self, input: RunAgentInput) -> List[Any]:
"""Return messages that have not yet been processed for this session."""
if not input.messages:
return []

app_name = self._get_app_name(input)
session_id = input.thread_id
processed_ids = self._session_manager.get_processed_message_ids(app_name, session_id)

unseen_reversed: List[Any] = []

for message in reversed(input.messages):
message_id = getattr(message, "id", None)
if message_id and message_id in processed_ids:
break
unseen_reversed.append(message)

unseen_reversed.reverse()
return unseen_reversed

def _collect_message_ids(self, messages: List[Any]) -> List[str]:
"""Extract message IDs from messages, skipping those without IDs."""
return [getattr(message, "id") for message in messages if getattr(message, "id", None)]

async def _is_tool_result_submission(
self,
input: RunAgentInput,
unseen_messages: Optional[List[Any]] = None,
) -> bool:
"""Check if this request contains tool results.

Args:
input: The run input

unseen_messages: Optional list of unseen messages to inspect

Returns:
True if the last message is a tool result
True if all unseen messages are tool results
"""
if not input.messages:
unseen_messages = unseen_messages if unseen_messages is not None else await self._get_unseen_messages(input)

if not unseen_messages:
return False

last_message = input.messages[-1]
return hasattr(last_message, 'role') and last_message.role == "tool"


return all(getattr(message, "role", None) == "tool" for message in unseen_messages)

async def _handle_tool_result_submission(
self,
input: RunAgentInput
self,
input: RunAgentInput,
tool_messages: Optional[List[Any]] = None,
) -> AsyncGenerator[BaseEvent, None]:
"""Handle tool result submission for existing execution.

Expand All @@ -434,8 +493,9 @@ async def _handle_tool_result_submission(
"""
thread_id = input.thread_id

# Extract tool results that is send by the frontend
tool_results = await self._extract_tool_results(input)
# Extract tool results that are sent by the frontend
candidate_messages = tool_messages if tool_messages is not None else await self._get_unseen_messages(input)
tool_results = await self._extract_tool_results(input, candidate_messages)

# if the tool results are not sent by the fronted then call the tool function
if not tool_results:
Expand Down Expand Up @@ -466,7 +526,7 @@ async def _handle_tool_result_submission(
# Since all tools are long-running, all tool results are standalone
# and should start new executions with the tool results
logger.info(f"Starting new execution for tool result in thread {thread_id}")
async for event in self._start_new_execution(input):
async for event in self._start_new_execution(input, tool_results=tool_results):
yield event

except Exception as e:
Expand All @@ -477,45 +537,49 @@ async def _handle_tool_result_submission(
code="TOOL_RESULT_PROCESSING_ERROR"
)

async def _extract_tool_results(self, input: RunAgentInput) -> List[Dict]:
async def _extract_tool_results(
self,
input: RunAgentInput,
candidate_messages: Optional[List[Any]] = None,
) -> List[Dict]:
"""Extract tool messages with their names from input.
Only extracts the most recent tool message to avoid accumulation issues
where multiple tool results are sent to the LLM causing API errors.

Only extracts tool messages provided in candidate_messages. When no
candidates are supplied, all messages are considered.

Args:
input: The run input

candidate_messages: Optional subset of messages to inspect

Returns:
List of dicts containing tool name and message (single item for most recent)
List of dicts containing tool name and message ordered chronologically
"""
# Create a mapping of tool_call_id to tool name
tool_call_map = {}
for message in input.messages:
if hasattr(message, 'tool_calls') and message.tool_calls:
for tool_call in message.tool_calls:
tool_call_map[tool_call.id] = tool_call.function.name

# Find the most recent tool message (should be the last one in a tool result submission)
most_recent_tool_message = None
for message in reversed(input.messages):

messages_to_check = candidate_messages or input.messages
extracted_results: List[Dict] = []

for message in messages_to_check:
if hasattr(message, 'role') and message.role == "tool":
most_recent_tool_message = message
break

if most_recent_tool_message:
tool_name = tool_call_map.get(most_recent_tool_message.tool_call_id, "unknown")

# Debug: Log the extracted tool message
logger.debug(f"Extracted most recent ToolMessage: role={most_recent_tool_message.role}, tool_call_id={most_recent_tool_message.tool_call_id}, content='{most_recent_tool_message.content}'")

return [{
'tool_name': tool_name,
'message': most_recent_tool_message
}]

return []

tool_name = tool_call_map.get(getattr(message, 'tool_call_id', None), "unknown")
logger.debug(
"Extracted ToolMessage: role=%s, tool_call_id=%s, content='%s'",
getattr(message, 'role', None),
getattr(message, 'tool_call_id', None),
getattr(message, 'content', None),
)
extracted_results.append({
'tool_name': tool_name,
'message': message
})

return extracted_results

async def _stream_events(
self,
execution: ExecutionState
Expand Down Expand Up @@ -588,8 +652,11 @@ async def _stream_events(
break

async def _start_new_execution(
self,
input: RunAgentInput
self,
input: RunAgentInput,
*,
tool_results: Optional[List[Dict]] = None,
message_batch: Optional[List[Any]] = None,
) -> AsyncGenerator[BaseEvent, None]:
"""Start a new ADK execution with tool support.

Expand Down Expand Up @@ -631,7 +698,11 @@ async def _start_new_execution(
logger.debug(f"Previous execution completed with error: {e}")

# Start background execution
execution = await self._start_background_execution(input)
execution = await self._start_background_execution(
input,
tool_results=tool_results,
message_batch=message_batch,
)

# Store execution (replacing any previous one)
async with self._execution_lock:
Expand Down Expand Up @@ -703,8 +774,11 @@ async def _start_new_execution(
logger.info(f"Preserving execution for thread {input.thread_id} - has pending tool calls (HITL scenario)")

async def _start_background_execution(
self,
input: RunAgentInput
self,
input: RunAgentInput,
*,
tool_results: Optional[List[Dict]] = None,
message_batch: Optional[List[Any]] = None,
) -> ExecutionState:
"""Start ADK execution in background with tool support.

Expand Down Expand Up @@ -806,7 +880,9 @@ def instruction_provider_wrapper_sync(*args, **kwargs):
adk_agent=adk_agent,
user_id=user_id,
app_name=app_name,
event_queue=event_queue
event_queue=event_queue,
tool_results=tool_results,
message_batch=message_batch,
)
)
logger.debug(f"Background task created for thread {input.thread_id}: {task}")
Expand All @@ -823,7 +899,9 @@ async def _run_adk_in_background(
adk_agent: BaseAgent,
user_id: str,
app_name: str,
event_queue: asyncio.Queue
event_queue: asyncio.Queue,
tool_results: Optional[List[Dict]] = None,
message_batch: Optional[List[Any]] = None,
):
"""Run ADK agent in background, emitting events to queue.

Expand Down Expand Up @@ -860,20 +938,35 @@ async def _run_adk_in_background(


# Convert messages
unseen_messages = message_batch if message_batch is not None else await self._get_unseen_messages(input)

active_tool_results: Optional[List[Dict]] = tool_results
if active_tool_results is None and await self._is_tool_result_submission(input, unseen_messages):
active_tool_results = await self._extract_tool_results(input, unseen_messages)

if active_tool_results:
tool_messages = [result["message"] for result in active_tool_results]
message_ids = self._collect_message_ids(tool_messages)
if message_ids:
self._session_manager.mark_messages_processed(app_name, input.thread_id, message_ids)
elif unseen_messages:
message_ids = self._collect_message_ids(unseen_messages)
if message_ids:
self._session_manager.mark_messages_processed(app_name, input.thread_id, message_ids)

# only use this new_message if there is no tool response from the user
new_message = await self._convert_latest_message(input)
new_message = await self._convert_latest_message(input, unseen_messages if message_batch is not None else None)

# if there is a tool response submission by the user then we need to only pass the tool response to the adk runner
if self._is_tool_result_submission(input):
tool_results = await self._extract_tool_results(input)
if active_tool_results:
parts = []
for tool_msg in tool_results:
for tool_msg in active_tool_results:
tool_call_id = tool_msg['message'].tool_call_id
content = tool_msg['message'].content

# Debug: Log the actual tool message content we received
logger.debug(f"Received tool result for call {tool_call_id}: content='{content}', type={type(content)}")

# Parse JSON content, handling empty or invalid JSON gracefully
try:
if content and content.strip():
Expand All @@ -885,23 +978,24 @@ async def _run_adk_in_background(
except json.JSONDecodeError as json_error:
# Handle invalid JSON by providing detailed error result
result = {
"error": f"Invalid JSON in tool result: {str(json_error)}",
"error": f"Invalid JSON in tool result: {str(json_error)}",
"raw_content": content,
"error_type": "JSON_DECODE_ERROR",
"line": getattr(json_error, 'lineno', None),
"column": getattr(json_error, 'colno', None)
}
logger.error(f"Invalid JSON in tool result for call {tool_call_id}: {json_error} at line {getattr(json_error, 'lineno', '?')}, column {getattr(json_error, 'colno', '?')}")

updated_function_response_part = types.Part(
function_response=types.FunctionResponse(
id= tool_call_id,
name=tool_msg["tool_name"],
response=result,
function_response=types.FunctionResponse(
id=tool_call_id,
name=tool_msg["tool_name"],
response=result,
)
)
)
parts.append(updated_function_response_part)
new_message = types.Content(parts=parts, role='user')
new_message = types.Content(parts=parts, role='function')

# Create event translator
event_translator = EventTranslator()

Expand Down
Loading
Loading