Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/google/adk/agents/sequential_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ async def _run_async_impl(
yield event
if ctx.should_pause_invocation(event):
pause_invocation = True

if event.actions and event.actions.escalate:
return
# Skip the rest of the sub-agents if the invocation is paused.
if pause_invocation:
return
Expand Down
30 changes: 30 additions & 0 deletions src/google/adk/tools/exit_sequence_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .tool_context import ToolContext


def exit_sequence(tool_context: ToolContext):
"""Exits the sequential execution of agents.

Call this function when you encounter a terminal condition and want to
prevent subsequent agents in the sequence from executing. This allows
for early termination of a sequential agent workflow when a blocking
or final condition is reached.

Args:
tool_context: The context of the current tool invocation.
"""
tool_context.actions.escalate = True
tool_context.actions.skip_summarization = True
126 changes: 126 additions & 0 deletions tests/unittests/agents/test_sequential_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.adk.agents.sequential_agent import SequentialAgentState
from google.adk.apps import ResumabilityConfig
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
import pytest
Expand Down Expand Up @@ -55,6 +56,42 @@ async def _run_live_impl(
)


class _TestingAgentWithEscalateAction(BaseAgent):

@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
author=self.name,
invocation_id=ctx.invocation_id,
content=types.Content(
parts=[types.Part(text=f'Hello, async {self.name}!')]
),
actions=EventActions(escalate=True),
)
yield Event(
author=self.name,
invocation_id=ctx.invocation_id,
content=types.Content(
parts=[types.Part(text=f'I should not be seen after escalation!')]
),
)

@override
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
author=self.name,
invocation_id=ctx.invocation_id,
content=types.Content(
parts=[types.Part(text=f'Hello, live {self.name}!')]
),
actions=EventActions(escalate=True),
)


async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent, resumable: bool = False
) -> InvocationContext:
Expand Down Expand Up @@ -201,3 +238,92 @@ async def test_run_live(request: pytest.FixtureRequest):
assert events[1].author == agent_2.name
assert events[0].content.parts[0].text == f'Hello, live {agent_1.name}!'
assert events[1].content.parts[0].text == f'Hello, live {agent_2.name}!'


@pytest.mark.asyncio
async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
"""Test that SequentialAgent exits early when escalate action is triggered."""
escalating_agent = _TestingAgentWithEscalateAction(
name=f'{request.function.__name__}_escalating_agent'
)
normal_agent = _TestingAgent(name=f'{request.function.__name__}_normal_agent')
sequential_agent = SequentialAgent(
name=f'{request.function.__name__}_test_agent',
sub_agents=[
escalating_agent,
normal_agent, # This should not execute
],
)
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent
)
events = [e async for e in sequential_agent.run_async(parent_ctx)]

# Should only have 1 event from the escalating agent, normal agent should not run
assert len(events) == 1
assert events[0].author == escalating_agent.name
assert (
events[0].content.parts[0].text
== f'Hello, async {escalating_agent.name}!'
)
assert events[0].actions.escalate is True


@pytest.mark.asyncio
async def test_run_async_escalate_action_in_middle(
request: pytest.FixtureRequest,
):
"""Test that SequentialAgent exits when escalation happens in middle of sequence."""
first_agent = _TestingAgent(name=f'{request.function.__name__}_first_agent')
escalating_agent = _TestingAgentWithEscalateAction(
name=f'{request.function.__name__}_escalating_agent'
)
third_agent = _TestingAgent(name=f'{request.function.__name__}_third_agent')
sequential_agent = SequentialAgent(
name=f'{request.function.__name__}_test_agent',
sub_agents=[
first_agent,
escalating_agent,
third_agent, # This should not execute
],
)
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent
)
events = [e async for e in sequential_agent.run_async(parent_ctx)]

# Should have 2 events: one from first agent, one from escalating agent
assert len(events) == 2
assert events[0].author == first_agent.name
assert events[1].author == escalating_agent.name
assert events[1].actions.escalate is True

# Verify third agent did not run
third_agent_events = [e for e in events if e.author == third_agent.name]
assert len(third_agent_events) == 0


@pytest.mark.asyncio
async def test_run_async_no_escalate_action(request: pytest.FixtureRequest):
"""Test that SequentialAgent continues normally when no escalate action."""
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
agent_3 = _TestingAgent(name=f'{request.function.__name__}_test_agent_3')
sequential_agent = SequentialAgent(
name=f'{request.function.__name__}_test_agent',
sub_agents=[
agent_1,
agent_2,
agent_3,
],
)
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent
)
events = [e async for e in sequential_agent.run_async(parent_ctx)]

# All agents should execute
assert len(events) == 3
assert events[0].author == agent_1.name
assert events[1].author == agent_2.name
assert events[2].author == agent_3.name