diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index 5cc5b654ad..f53a8ba635 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -81,7 +81,8 @@ async def _run_async_impl( yield event if ctx.should_pause_invocation(event): pause_invocation = True - + if event.actions and event.actions.escalate: + return # Skip the rest of the sub-agents if the invocation is paused. if pause_invocation: return diff --git a/src/google/adk/tools/exit_sequence_tool.py b/src/google/adk/tools/exit_sequence_tool.py new file mode 100644 index 0000000000..ac8f9e524d --- /dev/null +++ b/src/google/adk/tools/exit_sequence_tool.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .tool_context import ToolContext + + +def exit_sequence(tool_context: ToolContext): + """Exits the sequential execution of agents. + + Call this function when you encounter a terminal condition and want to + prevent subsequent agents in the sequence from executing. This allows + for early termination of a sequential agent workflow when a blocking + or final condition is reached. + + Args: + tool_context: The context of the current tool invocation. + """ + tool_context.actions.escalate = True + tool_context.actions.skip_summarization = True diff --git a/tests/unittests/agents/test_sequential_agent.py b/tests/unittests/agents/test_sequential_agent.py index 9703e0ca29..a486fb2435 100644 --- a/tests/unittests/agents/test_sequential_agent.py +++ b/tests/unittests/agents/test_sequential_agent.py @@ -22,6 +22,7 @@ from google.adk.agents.sequential_agent import SequentialAgentState from google.adk.apps import ResumabilityConfig from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.genai import types import pytest @@ -55,6 +56,42 @@ async def _run_live_impl( ) +class _TestingAgentWithEscalateAction(BaseAgent): + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content( + parts=[types.Part(text=f'Hello, async {self.name}!')] + ), + actions=EventActions(escalate=True), + ) + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content( + parts=[types.Part(text=f'I should not be seen after escalation!')] + ), + ) + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content( + parts=[types.Part(text=f'Hello, live {self.name}!')] + ), + actions=EventActions(escalate=True), + ) + + async def _create_parent_invocation_context( test_name: str, agent: BaseAgent, resumable: bool = False ) -> InvocationContext: @@ -201,3 +238,92 @@ async def test_run_live(request: pytest.FixtureRequest): assert events[1].author == agent_2.name assert events[0].content.parts[0].text == f'Hello, live {agent_1.name}!' assert events[1].content.parts[0].text == f'Hello, live {agent_2.name}!' + + +@pytest.mark.asyncio +async def test_run_async_with_escalate_action(request: pytest.FixtureRequest): + """Test that SequentialAgent exits early when escalate action is triggered.""" + escalating_agent = _TestingAgentWithEscalateAction( + name=f'{request.function.__name__}_escalating_agent' + ) + normal_agent = _TestingAgent(name=f'{request.function.__name__}_normal_agent') + sequential_agent = SequentialAgent( + name=f'{request.function.__name__}_test_agent', + sub_agents=[ + escalating_agent, + normal_agent, # This should not execute + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # Should only have 1 event from the escalating agent, normal agent should not run + assert len(events) == 1 + assert events[0].author == escalating_agent.name + assert ( + events[0].content.parts[0].text + == f'Hello, async {escalating_agent.name}!' + ) + assert events[0].actions.escalate is True + + +@pytest.mark.asyncio +async def test_run_async_escalate_action_in_middle( + request: pytest.FixtureRequest, +): + """Test that SequentialAgent exits when escalation happens in middle of sequence.""" + first_agent = _TestingAgent(name=f'{request.function.__name__}_first_agent') + escalating_agent = _TestingAgentWithEscalateAction( + name=f'{request.function.__name__}_escalating_agent' + ) + third_agent = _TestingAgent(name=f'{request.function.__name__}_third_agent') + sequential_agent = SequentialAgent( + name=f'{request.function.__name__}_test_agent', + sub_agents=[ + first_agent, + escalating_agent, + third_agent, # This should not execute + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # Should have 2 events: one from first agent, one from escalating agent + assert len(events) == 2 + assert events[0].author == first_agent.name + assert events[1].author == escalating_agent.name + assert events[1].actions.escalate is True + + # Verify third agent did not run + third_agent_events = [e for e in events if e.author == third_agent.name] + assert len(third_agent_events) == 0 + + +@pytest.mark.asyncio +async def test_run_async_no_escalate_action(request: pytest.FixtureRequest): + """Test that SequentialAgent continues normally when no escalate action.""" + agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1') + agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') + agent_3 = _TestingAgent(name=f'{request.function.__name__}_test_agent_3') + sequential_agent = SequentialAgent( + name=f'{request.function.__name__}_test_agent', + sub_agents=[ + agent_1, + agent_2, + agent_3, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # All agents should execute + assert len(events) == 3 + assert events[0].author == agent_1.name + assert events[1].author == agent_2.name + assert events[2].author == agent_3.name