diff --git a/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py b/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py index a28eb769c..ce76a5b7c 100644 --- a/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py +++ b/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py @@ -356,15 +356,37 @@ async def run(self, input: RunAgentInput) -> AsyncGenerator[BaseEvent, None]: Yields: AG-UI protocol events """ - # Check if this is a tool result submission for an existing execution - if self._is_tool_result_submission(input): - # Handle tool results for existing execution - async for event in self._handle_tool_result_submission(input): - yield event - else: - # Start new execution for regular requests + unseen_messages = await self._get_unseen_messages(input) + + if not unseen_messages: + # No unseen messages – fall through to normal execution handling async for event in self._start_new_execution(input): yield event + return + + index = 0 + total_unseen = len(unseen_messages) + + while index < total_unseen: + current = unseen_messages[index] + role = getattr(current, "role", None) + + if role == "tool": + tool_batch: List[Any] = [] + while index < total_unseen and getattr(unseen_messages[index], "role", None) == "tool": + tool_batch.append(unseen_messages[index]) + index += 1 + + async for event in self._handle_tool_result_submission(input, tool_messages=tool_batch): + yield event + else: + message_batch: List[Any] = [] + while index < total_unseen and getattr(unseen_messages[index], "role", None) != "tool": + message_batch.append(unseen_messages[index]) + index += 1 + + async for event in self._start_new_execution(input, message_batch=message_batch): + yield event async def _ensure_session_exists(self, app_name: str, user_id: str, session_id: str, initial_state: dict): """Ensure a session exists, creating it if necessary via session manager.""" @@ -389,40 +411,77 @@ async def _ensure_session_exists(self, app_name: str, user_id: str, session_id: logger.error(f"Failed to ensure session {session_id}: {e}") raise - async def _convert_latest_message(self, input: RunAgentInput) -> Optional[types.Content]: + async def _convert_latest_message( + self, + input: RunAgentInput, + messages: Optional[List[Any]] = None, + ) -> Optional[types.Content]: """Convert the latest user message to ADK Content format.""" - if not input.messages: + target_messages = messages if messages is not None else input.messages + + if not target_messages: return None - + # Get the latest user message - for message in reversed(input.messages): - if message.role == "user" and message.content: + for message in reversed(target_messages): + if getattr(message, "role", None) == "user" and getattr(message, "content", None): return types.Content( role="user", parts=[types.Part(text=message.content)] ) - + return None - def _is_tool_result_submission(self, input: RunAgentInput) -> bool: + async def _get_unseen_messages(self, input: RunAgentInput) -> List[Any]: + """Return messages that have not yet been processed for this session.""" + if not input.messages: + return [] + + app_name = self._get_app_name(input) + session_id = input.thread_id + processed_ids = self._session_manager.get_processed_message_ids(app_name, session_id) + + unseen_reversed: List[Any] = [] + + for message in reversed(input.messages): + message_id = getattr(message, "id", None) + if message_id and message_id in processed_ids: + break + unseen_reversed.append(message) + + unseen_reversed.reverse() + return unseen_reversed + + def _collect_message_ids(self, messages: List[Any]) -> List[str]: + """Extract message IDs from messages, skipping those without IDs.""" + return [getattr(message, "id") for message in messages if getattr(message, "id", None)] + + async def _is_tool_result_submission( + self, + input: RunAgentInput, + unseen_messages: Optional[List[Any]] = None, + ) -> bool: """Check if this request contains tool results. - + Args: input: The run input - + unseen_messages: Optional list of unseen messages to inspect + Returns: - True if the last message is a tool result + True if all unseen messages are tool results """ - if not input.messages: + unseen_messages = unseen_messages if unseen_messages is not None else await self._get_unseen_messages(input) + + if not unseen_messages: return False - - last_message = input.messages[-1] - return hasattr(last_message, 'role') and last_message.role == "tool" - + + return all(getattr(message, "role", None) == "tool" for message in unseen_messages) + async def _handle_tool_result_submission( - self, - input: RunAgentInput + self, + input: RunAgentInput, + tool_messages: Optional[List[Any]] = None, ) -> AsyncGenerator[BaseEvent, None]: """Handle tool result submission for existing execution. @@ -434,8 +493,9 @@ async def _handle_tool_result_submission( """ thread_id = input.thread_id - # Extract tool results that is send by the frontend - tool_results = await self._extract_tool_results(input) + # Extract tool results that are sent by the frontend + candidate_messages = tool_messages if tool_messages is not None else await self._get_unseen_messages(input) + tool_results = await self._extract_tool_results(input, candidate_messages) # if the tool results are not sent by the fronted then call the tool function if not tool_results: @@ -466,7 +526,7 @@ async def _handle_tool_result_submission( # Since all tools are long-running, all tool results are standalone # and should start new executions with the tool results logger.info(f"Starting new execution for tool result in thread {thread_id}") - async for event in self._start_new_execution(input): + async for event in self._start_new_execution(input, tool_results=tool_results): yield event except Exception as e: @@ -477,17 +537,22 @@ async def _handle_tool_result_submission( code="TOOL_RESULT_PROCESSING_ERROR" ) - async def _extract_tool_results(self, input: RunAgentInput) -> List[Dict]: + async def _extract_tool_results( + self, + input: RunAgentInput, + candidate_messages: Optional[List[Any]] = None, + ) -> List[Dict]: """Extract tool messages with their names from input. - - Only extracts the most recent tool message to avoid accumulation issues - where multiple tool results are sent to the LLM causing API errors. - + + Only extracts tool messages provided in candidate_messages. When no + candidates are supplied, all messages are considered. + Args: input: The run input - + candidate_messages: Optional subset of messages to inspect + Returns: - List of dicts containing tool name and message (single item for most recent) + List of dicts containing tool name and message ordered chronologically """ # Create a mapping of tool_call_id to tool name tool_call_map = {} @@ -495,27 +560,26 @@ async def _extract_tool_results(self, input: RunAgentInput) -> List[Dict]: if hasattr(message, 'tool_calls') and message.tool_calls: for tool_call in message.tool_calls: tool_call_map[tool_call.id] = tool_call.function.name - - # Find the most recent tool message (should be the last one in a tool result submission) - most_recent_tool_message = None - for message in reversed(input.messages): + + messages_to_check = candidate_messages or input.messages + extracted_results: List[Dict] = [] + + for message in messages_to_check: if hasattr(message, 'role') and message.role == "tool": - most_recent_tool_message = message - break - - if most_recent_tool_message: - tool_name = tool_call_map.get(most_recent_tool_message.tool_call_id, "unknown") - - # Debug: Log the extracted tool message - logger.debug(f"Extracted most recent ToolMessage: role={most_recent_tool_message.role}, tool_call_id={most_recent_tool_message.tool_call_id}, content='{most_recent_tool_message.content}'") - - return [{ - 'tool_name': tool_name, - 'message': most_recent_tool_message - }] - - return [] - + tool_name = tool_call_map.get(getattr(message, 'tool_call_id', None), "unknown") + logger.debug( + "Extracted ToolMessage: role=%s, tool_call_id=%s, content='%s'", + getattr(message, 'role', None), + getattr(message, 'tool_call_id', None), + getattr(message, 'content', None), + ) + extracted_results.append({ + 'tool_name': tool_name, + 'message': message + }) + + return extracted_results + async def _stream_events( self, execution: ExecutionState @@ -588,8 +652,11 @@ async def _stream_events( break async def _start_new_execution( - self, - input: RunAgentInput + self, + input: RunAgentInput, + *, + tool_results: Optional[List[Dict]] = None, + message_batch: Optional[List[Any]] = None, ) -> AsyncGenerator[BaseEvent, None]: """Start a new ADK execution with tool support. @@ -631,7 +698,11 @@ async def _start_new_execution( logger.debug(f"Previous execution completed with error: {e}") # Start background execution - execution = await self._start_background_execution(input) + execution = await self._start_background_execution( + input, + tool_results=tool_results, + message_batch=message_batch, + ) # Store execution (replacing any previous one) async with self._execution_lock: @@ -703,8 +774,11 @@ async def _start_new_execution( logger.info(f"Preserving execution for thread {input.thread_id} - has pending tool calls (HITL scenario)") async def _start_background_execution( - self, - input: RunAgentInput + self, + input: RunAgentInput, + *, + tool_results: Optional[List[Dict]] = None, + message_batch: Optional[List[Any]] = None, ) -> ExecutionState: """Start ADK execution in background with tool support. @@ -806,7 +880,9 @@ def instruction_provider_wrapper_sync(*args, **kwargs): adk_agent=adk_agent, user_id=user_id, app_name=app_name, - event_queue=event_queue + event_queue=event_queue, + tool_results=tool_results, + message_batch=message_batch, ) ) logger.debug(f"Background task created for thread {input.thread_id}: {task}") @@ -823,7 +899,9 @@ async def _run_adk_in_background( adk_agent: BaseAgent, user_id: str, app_name: str, - event_queue: asyncio.Queue + event_queue: asyncio.Queue, + tool_results: Optional[List[Dict]] = None, + message_batch: Optional[List[Any]] = None, ): """Run ADK agent in background, emitting events to queue. @@ -860,20 +938,35 @@ async def _run_adk_in_background( # Convert messages + unseen_messages = message_batch if message_batch is not None else await self._get_unseen_messages(input) + + active_tool_results: Optional[List[Dict]] = tool_results + if active_tool_results is None and await self._is_tool_result_submission(input, unseen_messages): + active_tool_results = await self._extract_tool_results(input, unseen_messages) + + if active_tool_results: + tool_messages = [result["message"] for result in active_tool_results] + message_ids = self._collect_message_ids(tool_messages) + if message_ids: + self._session_manager.mark_messages_processed(app_name, input.thread_id, message_ids) + elif unseen_messages: + message_ids = self._collect_message_ids(unseen_messages) + if message_ids: + self._session_manager.mark_messages_processed(app_name, input.thread_id, message_ids) + # only use this new_message if there is no tool response from the user - new_message = await self._convert_latest_message(input) - + new_message = await self._convert_latest_message(input, unseen_messages if message_batch is not None else None) + # if there is a tool response submission by the user then we need to only pass the tool response to the adk runner - if self._is_tool_result_submission(input): - tool_results = await self._extract_tool_results(input) + if active_tool_results: parts = [] - for tool_msg in tool_results: + for tool_msg in active_tool_results: tool_call_id = tool_msg['message'].tool_call_id content = tool_msg['message'].content - + # Debug: Log the actual tool message content we received logger.debug(f"Received tool result for call {tool_call_id}: content='{content}', type={type(content)}") - + # Parse JSON content, handling empty or invalid JSON gracefully try: if content and content.strip(): @@ -885,23 +978,24 @@ async def _run_adk_in_background( except json.JSONDecodeError as json_error: # Handle invalid JSON by providing detailed error result result = { - "error": f"Invalid JSON in tool result: {str(json_error)}", + "error": f"Invalid JSON in tool result: {str(json_error)}", "raw_content": content, "error_type": "JSON_DECODE_ERROR", "line": getattr(json_error, 'lineno', None), "column": getattr(json_error, 'colno', None) } logger.error(f"Invalid JSON in tool result for call {tool_call_id}: {json_error} at line {getattr(json_error, 'lineno', '?')}, column {getattr(json_error, 'colno', '?')}") - + updated_function_response_part = types.Part( - function_response=types.FunctionResponse( - id= tool_call_id, - name=tool_msg["tool_name"], - response=result, + function_response=types.FunctionResponse( + id=tool_call_id, + name=tool_msg["tool_name"], + response=result, + ) ) - ) parts.append(updated_function_response_part) - new_message = types.Content(parts=parts, role='user') + new_message = types.Content(parts=parts, role='function') + # Create event translator event_translator = EventTranslator() diff --git a/integrations/adk-middleware/python/src/ag_ui_adk/session_manager.py b/integrations/adk-middleware/python/src/ag_ui_adk/session_manager.py index fdb7e3125..bd0a8b6bd 100644 --- a/integrations/adk-middleware/python/src/ag_ui_adk/session_manager.py +++ b/integrations/adk-middleware/python/src/ag_ui_adk/session_manager.py @@ -2,7 +2,7 @@ """Session manager that adds production features to ADK's native session service.""" -from typing import Dict, Optional, Set, Any, Union +from typing import Dict, Optional, Set, Any, Union, Iterable import asyncio import logging import time @@ -67,6 +67,7 @@ def __init__( # Minimal tracking: just keys and user counts self._session_keys: Set[str] = set() # "app_name:session_id" keys self._user_sessions: Dict[str, Set[str]] = {} # user_id -> set of session_keys + self._processed_message_ids: Dict[str, Set[str]] = {} self._cleanup_task: Optional[asyncio.Task] = None self._initialized = True @@ -108,7 +109,7 @@ async def get_or_create_session( Returns the ADK session object directly. """ - session_key = f"{app_name}:{session_id}" + session_key = self._make_session_key(app_name, session_id) # Check user limits before creating if session_key not in self._session_keys and self._max_per_user: @@ -504,19 +505,40 @@ async def bulk_update_user_state( def _track_session(self, session_key: str, user_id: str): """Track a session key for enumeration.""" self._session_keys.add(session_key) - + if user_id not in self._user_sessions: self._user_sessions[user_id] = set() self._user_sessions[user_id].add(session_key) - + def _untrack_session(self, session_key: str, user_id: str): """Remove session tracking.""" self._session_keys.discard(session_key) - + self._processed_message_ids.pop(session_key, None) + if user_id in self._user_sessions: self._user_sessions[user_id].discard(session_key) if not self._user_sessions[user_id]: del self._user_sessions[user_id] + + def _make_session_key(self, app_name: str, session_id: str) -> str: + return f"{app_name}:{session_id}" + + def get_processed_message_ids(self, app_name: str, session_id: str) -> Set[str]: + session_key = self._make_session_key(app_name, session_id) + return set(self._processed_message_ids.get(session_key, set())) + + def mark_messages_processed( + self, + app_name: str, + session_id: str, + message_ids: Iterable[str], + ) -> None: + session_key = self._make_session_key(app_name, session_id) + processed_ids = self._processed_message_ids.setdefault(session_key, set()) + + for message_id in message_ids: + if message_id: + processed_ids.add(message_id) async def _remove_oldest_user_session(self, user_id: str): """Remove the oldest session for a user based on lastUpdateTime.""" @@ -544,7 +566,7 @@ async def _remove_oldest_user_session(self, user_id: str): logger.error(f"Error checking session {session_key}: {e}") if oldest_session: - session_key = f"{oldest_session.app_name}:{oldest_session.id}" + session_key = self._make_session_key(oldest_session.app_name, oldest_session.id) await self._delete_session(oldest_session) logger.info(f"Removed oldest session for user {user_id}: {session_key}") diff --git a/integrations/adk-middleware/python/tests/test_tool_result_flow.py b/integrations/adk-middleware/python/tests/test_tool_result_flow.py index 967900346..ba4938bfe 100644 --- a/integrations/adk-middleware/python/tests/test_tool_result_flow.py +++ b/integrations/adk-middleware/python/tests/test_tool_result_flow.py @@ -2,7 +2,6 @@ """Test tool result submission flow in ADKAgent.""" import pytest -import asyncio import json from unittest.mock import AsyncMock, MagicMock, patch @@ -12,6 +11,7 @@ ) from ag_ui_adk import ADKAgent +from ag_ui_adk.session_manager import SessionManager class TestToolResultFlow: @@ -45,14 +45,20 @@ def mock_adk_agent(self): @pytest.fixture def ag_ui_adk(self, mock_adk_agent): """Create ADK middleware with mocked dependencies.""" - return ADKAgent( + SessionManager.reset_instance() + agent = ADKAgent( adk_agent=mock_adk_agent, user_id="test_user", execution_timeout_seconds=60, tool_timeout_seconds=30 ) + try: + yield agent + finally: + SessionManager.reset_instance() - def test_is_tool_result_submission_with_tool_message(self, ag_ui_adk): + @pytest.mark.asyncio + async def test_is_tool_result_submission_with_tool_message(self, ag_ui_adk): """Test detection of tool result submission.""" # Input with tool message as last message input_with_tool = RunAgentInput( @@ -68,9 +74,10 @@ def test_is_tool_result_submission_with_tool_message(self, ag_ui_adk): forwarded_props={} ) - assert ag_ui_adk._is_tool_result_submission(input_with_tool) is True + assert await ag_ui_adk._is_tool_result_submission(input_with_tool) is True - def test_is_tool_result_submission_with_user_message(self, ag_ui_adk): + @pytest.mark.asyncio + async def test_is_tool_result_submission_with_user_message(self, ag_ui_adk): """Test detection when last message is not a tool result.""" # Input with user message as last message input_without_tool = RunAgentInput( @@ -86,9 +93,10 @@ def test_is_tool_result_submission_with_user_message(self, ag_ui_adk): forwarded_props={} ) - assert ag_ui_adk._is_tool_result_submission(input_without_tool) is False + assert await ag_ui_adk._is_tool_result_submission(input_without_tool) is False - def test_is_tool_result_submission_empty_messages(self, ag_ui_adk): + @pytest.mark.asyncio + async def test_is_tool_result_submission_empty_messages(self, ag_ui_adk): """Test detection with empty messages.""" empty_input = RunAgentInput( thread_id="thread_1", @@ -100,7 +108,72 @@ def test_is_tool_result_submission_empty_messages(self, ag_ui_adk): forwarded_props={} ) - assert ag_ui_adk._is_tool_result_submission(empty_input) is False + assert await ag_ui_adk._is_tool_result_submission(empty_input) is False + + @pytest.mark.asyncio + async def test_is_tool_result_submission_ignores_processed_history(self, ag_ui_adk): + """Ensure previously processed tool messages are ignored.""" + replay_input = RunAgentInput( + thread_id="thread_1", + run_id="run_1", + messages=[ + UserMessage(id="1", role="user", content="Do something"), + ToolMessage(id="2", role="tool", content='{"result": "success"}', tool_call_id="call_1") + ], + tools=[], + context=[], + state={}, + forwarded_props={} + ) + + app_name = ag_ui_adk._get_app_name(replay_input) + ag_ui_adk._session_manager.mark_messages_processed(app_name, replay_input.thread_id, ["1", "2"]) + + assert await ag_ui_adk._is_tool_result_submission(replay_input) is False + + @pytest.mark.asyncio + async def test_is_tool_result_submission_multiple_tool_messages(self, ag_ui_adk): + """Detect tool submissions when multiple unseen tool results arrive together.""" + batched_input = RunAgentInput( + thread_id="thread_1", + run_id="run_1", + messages=[ + UserMessage(id="1", role="user", content="First"), + ToolMessage(id="2", role="tool", content='{"result": "partial"}', tool_call_id="call_1"), + ToolMessage(id="3", role="tool", content='{"result": "done"}', tool_call_id="call_2") + ], + tools=[], + context=[], + state={}, + forwarded_props={} + ) + + app_name = ag_ui_adk._get_app_name(batched_input) + ag_ui_adk._session_manager.mark_messages_processed(app_name, batched_input.thread_id, ["1"]) + + assert await ag_ui_adk._is_tool_result_submission(batched_input) is True + + @pytest.mark.asyncio + async def test_is_tool_result_submission_new_user_after_tool(self, ag_ui_adk): + """Treat batched updates that end with a user message as non-tool submissions.""" + batched_input = RunAgentInput( + thread_id="thread_1", + run_id="run_1", + messages=[ + UserMessage(id="1", role="user", content="First"), + ToolMessage(id="2", role="tool", content='{"result": "intermediate"}', tool_call_id="call_1"), + UserMessage(id="3", role="user", content="Thanks!") + ], + tools=[], + context=[], + state={}, + forwarded_props={} + ) + + app_name = ag_ui_adk._get_app_name(batched_input) + ag_ui_adk._session_manager.mark_messages_processed(app_name, batched_input.thread_id, ["1"]) + + assert await ag_ui_adk._is_tool_result_submission(batched_input) is False @pytest.mark.asyncio async def test_extract_tool_results_single_tool(self, ag_ui_adk): @@ -118,7 +191,7 @@ async def test_extract_tool_results_single_tool(self, ag_ui_adk): forwarded_props={} ) - tool_results = await ag_ui_adk._extract_tool_results(input_data) + tool_results = await ag_ui_adk._extract_tool_results(input_data, input_data.messages) assert len(tool_results) == 1 assert tool_results[0]['message'].role == "tool" @@ -128,7 +201,7 @@ async def test_extract_tool_results_single_tool(self, ag_ui_adk): @pytest.mark.asyncio async def test_extract_tool_results_multiple_tools(self, ag_ui_adk): - """Test extraction of most recent tool result when multiple exist.""" + """Test extraction of all unseen tool results when multiple exist.""" input_data = RunAgentInput( thread_id="thread_1", run_id="run_1", @@ -143,12 +216,11 @@ async def test_extract_tool_results_multiple_tools(self, ag_ui_adk): forwarded_props={} ) - tool_results = await ag_ui_adk._extract_tool_results(input_data) + unseen_messages = input_data.messages[1:] + tool_results = await ag_ui_adk._extract_tool_results(input_data, unseen_messages) - # Should only extract the most recent tool result to prevent API errors - assert len(tool_results) == 1 - assert tool_results[0]['message'].tool_call_id == "call_2" - assert tool_results[0]['message'].content == '{"result": "second"}' + assert len(tool_results) == 2 + assert [result['message'].tool_call_id for result in tool_results] == ["call_1", "call_2"] @pytest.mark.asyncio async def test_extract_tool_results_mixed_messages(self, ag_ui_adk): @@ -168,9 +240,9 @@ async def test_extract_tool_results_mixed_messages(self, ag_ui_adk): forwarded_props={} ) - tool_results = await ag_ui_adk._extract_tool_results(input_data) + unseen_messages = input_data.messages[3:] + tool_results = await ag_ui_adk._extract_tool_results(input_data, unseen_messages) - # Should only extract the most recent tool message to prevent API errors assert len(tool_results) == 1 assert tool_results[0]['message'].role == "tool" assert tool_results[0]['message'].tool_call_id == "call_2" @@ -325,7 +397,7 @@ async def test_handle_tool_result_submission_invalid_json(self, ag_ui_adk): @pytest.mark.asyncio async def test_handle_tool_result_submission_multiple_results(self, ag_ui_adk): - """Test handling multiple tool results in one submission - only most recent is extracted.""" + """Test handling multiple tool results in one submission preserves all unseen results.""" thread_id = "test_thread" input_data = RunAgentInput( @@ -341,11 +413,9 @@ async def test_handle_tool_result_submission_multiple_results(self, ag_ui_adk): forwarded_props={} ) - # Should extract only the most recent tool result to prevent API errors - tool_results = await ag_ui_adk._extract_tool_results(input_data) - assert len(tool_results) == 1 - assert tool_results[0]['message'].tool_call_id == "call_2" - assert tool_results[0]['message'].content == '{"result": "second"}' + tool_results = await ag_ui_adk._extract_tool_results(input_data, input_data.messages) + assert len(tool_results) == 2 + assert [result['message'].tool_call_id for result in tool_results] == ["call_1", "call_2"] @pytest.mark.asyncio async def test_tool_result_flow_integration(self, ag_ui_adk): @@ -368,7 +438,7 @@ async def test_tool_result_flow_integration(self, ag_ui_adk): # In the all-long-running architecture, tool result inputs are processed as new executions # Mock the background execution to avoid ADK library errors - async def mock_start_new_execution(input_data): + async def mock_start_new_execution(input_data, *, tool_results=None, message_batch=None): yield RunStartedEvent( type=EventType.RUN_STARTED, thread_id=input_data.thread_id, @@ -391,6 +461,147 @@ async def mock_start_new_execution(input_data): assert events[0].type == EventType.RUN_STARTED assert events[1].type == EventType.RUN_FINISHED + @pytest.mark.asyncio + async def test_run_processes_mixed_unseen_messages(self, ag_ui_adk): + """Ensure mixed unseen tool and user messages are handled sequentially.""" + input_data = RunAgentInput( + thread_id="thread_mixed", + run_id="run_mixed", + messages=[ + ToolMessage(id="tool_1", role="tool", content='{"result": "value"}', tool_call_id="call_1"), + UserMessage(id="user_2", role="user", content="Next question"), + ], + tools=[], + context=[], + state={}, + forwarded_props={}, + ) + + start_calls = [] + + async def mock_start_new_execution(input_data, *, tool_results=None, message_batch=None): + start_calls.append((tool_results, message_batch)) + yield RunStartedEvent( + type=EventType.RUN_STARTED, + thread_id=input_data.thread_id, + run_id=input_data.run_id, + ) + yield RunFinishedEvent( + type=EventType.RUN_FINISHED, + thread_id=input_data.thread_id, + run_id=input_data.run_id, + ) + + with patch.object( + ag_ui_adk, + '_start_new_execution', + side_effect=mock_start_new_execution, + ), patch.object( + ag_ui_adk, + '_handle_tool_result_submission', + wraps=ag_ui_adk._handle_tool_result_submission, + ) as handle_mock: + events = [] + async for event in ag_ui_adk.run(input_data): + events.append(event) + + assert len(events) == 4 + assert [event.type for event in events] == [ + EventType.RUN_STARTED, + EventType.RUN_FINISHED, + EventType.RUN_STARTED, + EventType.RUN_FINISHED, + ] + + # First call should originate from tool processing with populated tool_results + assert len(start_calls) == 2 + first_tool_results, first_batch = start_calls[0] + assert first_tool_results is not None and len(first_tool_results) == 1 + assert first_tool_results[0]['message'].tool_call_id == "call_1" + assert first_batch == [input_data.messages[0]] + + second_tool_results, second_batch = start_calls[1] + assert second_tool_results is None + assert second_batch == [input_data.messages[1]] + + assert handle_mock.call_count == 1 + assert 'tool_messages' in handle_mock.call_args.kwargs + tool_messages = handle_mock.call_args.kwargs['tool_messages'] + assert len(tool_messages) == 1 + assert getattr(tool_messages[0], 'id', None) == "tool_1" + + @pytest.mark.asyncio + async def test_run_preserves_order_for_user_then_tool(self, ag_ui_adk): + """Verify user updates are handled before subsequent tool messages.""" + input_data = RunAgentInput( + thread_id="thread_order", + run_id="run_order", + messages=[ + UserMessage(id="user_1", role="user", content="Question"), + ToolMessage(id="tool_2", role="tool", content='{"result": "answer"}', tool_call_id="call_2"), + ], + tools=[], + context=[], + state={}, + forwarded_props={}, + ) + + call_sequence = [] + + async def mock_start_new_execution(input_data, *, tool_results=None, message_batch=None): + call_sequence.append(("start", tool_results, message_batch)) + yield RunStartedEvent( + type=EventType.RUN_STARTED, + thread_id=input_data.thread_id, + run_id=input_data.run_id, + ) + yield RunFinishedEvent( + type=EventType.RUN_FINISHED, + thread_id=input_data.thread_id, + run_id=input_data.run_id, + ) + + async def mock_handle_tool_result_submission(input_data, *, tool_messages=None, **kwargs): + call_sequence.append(("tool", tool_messages)) + yield RunStartedEvent( + type=EventType.RUN_STARTED, + thread_id=input_data.thread_id, + run_id=input_data.run_id, + ) + yield RunFinishedEvent( + type=EventType.RUN_FINISHED, + thread_id=input_data.thread_id, + run_id=input_data.run_id, + ) + + with patch.object( + ag_ui_adk, + '_start_new_execution', + side_effect=mock_start_new_execution, + ), patch.object( + ag_ui_adk, + '_handle_tool_result_submission', + side_effect=mock_handle_tool_result_submission, + ): + events = [] + async for event in ag_ui_adk.run(input_data): + events.append(event) + + assert [event.type for event in events] == [ + EventType.RUN_STARTED, + EventType.RUN_FINISHED, + EventType.RUN_STARTED, + EventType.RUN_FINISHED, + ] + + assert call_sequence[0][0] == "start" + assert call_sequence[0][1] is None + assert call_sequence[0][2] == [input_data.messages[0]] + + assert call_sequence[1][0] == "tool" + assert len(call_sequence[1][1]) == 1 + assert getattr(call_sequence[1][1][0], 'id', None) == "tool_2" + @pytest.mark.asyncio async def test_new_execution_routing(self, ag_ui_adk, sample_tool): """Test that non-tool messages route to new execution.""" @@ -412,7 +623,7 @@ async def test_new_execution_routing(self, ag_ui_adk, sample_tool): RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id="thread_1", run_id="run_1") ] - async def mock_start_new_execution(input_data): + async def mock_start_new_execution(input_data, *, tool_results=None, message_batch=None): for event in mock_events: yield event @@ -423,4 +634,4 @@ async def mock_start_new_execution(input_data): assert len(events) == 2 assert isinstance(events[0], RunStartedEvent) - assert isinstance(events[1], RunFinishedEvent) \ No newline at end of file + assert isinstance(events[1], RunFinishedEvent)