Skip to content

Commit b229d53

Browse files
I have successfully implemented an early exit mechanism for SequentialAgent similar to what exists in LoopAgent. Here's what was accomplished
1 parent 632bf8b commit b229d53

File tree

3 files changed

+169
-0
lines changed

3 files changed

+169
-0
lines changed

src/google/adk/agents/sequential_agent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,16 @@ async def _run_async_impl(
4242
self, ctx: InvocationContext
4343
) -> AsyncGenerator[Event, None]:
4444
for sub_agent in self.sub_agents:
45+
should_exit = False
4546
async with Aclosing(sub_agent.run_async(ctx)) as agen:
4647
async for event in agen:
4748
yield event
49+
if event.actions.escalate:
50+
should_exit = True
51+
break
52+
53+
if should_exit:
54+
return
4855

4956
@override
5057
async def _run_live_impl(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .tool_context import ToolContext
16+
17+
18+
def exit_sequence(tool_context: ToolContext):
19+
"""Exits the sequential execution of agents.
20+
21+
Call this function when you encounter a terminal condition and want to
22+
prevent subsequent agents in the sequence from executing. This allows
23+
for early termination of a sequential agent workflow when a blocking
24+
or final condition is reached.
25+
26+
Args:
27+
tool_context: The context of the current tool invocation.
28+
"""
29+
tool_context.actions.escalate = True
30+
tool_context.actions.skip_summarization = True

tests/unittests/agents/test_sequential_agent.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.adk.agents.invocation_context import InvocationContext
2121
from google.adk.agents.sequential_agent import SequentialAgent
2222
from google.adk.events.event import Event
23+
from google.adk.events.event_actions import EventActions
2324
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2425
from google.genai import types
2526
import pytest
@@ -53,6 +54,42 @@ async def _run_live_impl(
5354
)
5455

5556

57+
class _TestingAgentWithEscalateAction(BaseAgent):
58+
59+
@override
60+
async def _run_async_impl(
61+
self, ctx: InvocationContext
62+
) -> AsyncGenerator[Event, None]:
63+
yield Event(
64+
author=self.name,
65+
invocation_id=ctx.invocation_id,
66+
content=types.Content(
67+
parts=[types.Part(text=f'Hello, async {self.name}!')]
68+
),
69+
actions=EventActions(escalate=True),
70+
)
71+
yield Event(
72+
author=self.name,
73+
invocation_id=ctx.invocation_id,
74+
content=types.Content(
75+
parts=[types.Part(text=f'I should not be seen after escalation!')]
76+
),
77+
)
78+
79+
@override
80+
async def _run_live_impl(
81+
self, ctx: InvocationContext
82+
) -> AsyncGenerator[Event, None]:
83+
yield Event(
84+
author=self.name,
85+
invocation_id=ctx.invocation_id,
86+
content=types.Content(
87+
parts=[types.Part(text=f'Hello, live {self.name}!')]
88+
),
89+
actions=EventActions(escalate=True),
90+
)
91+
92+
5693
async def _create_parent_invocation_context(
5794
test_name: str, agent: BaseAgent
5895
) -> InvocationContext:
@@ -112,3 +149,98 @@ async def test_run_live(request: pytest.FixtureRequest):
112149
assert events[1].author == agent_2.name
113150
assert events[0].content.parts[0].text == f'Hello, live {agent_1.name}!'
114151
assert events[1].content.parts[0].text == f'Hello, live {agent_2.name}!'
152+
153+
154+
@pytest.mark.asyncio
155+
async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
156+
"""Test that SequentialAgent exits early when escalate action is triggered."""
157+
escalating_agent = _TestingAgentWithEscalateAction(
158+
name=f'{request.function.__name__}_escalating_agent'
159+
)
160+
normal_agent = _TestingAgent(name=f'{request.function.__name__}_normal_agent')
161+
sequential_agent = SequentialAgent(
162+
name=f'{request.function.__name__}_test_agent',
163+
sub_agents=[
164+
escalating_agent,
165+
normal_agent, # This should not execute
166+
],
167+
)
168+
parent_ctx = await _create_parent_invocation_context(
169+
request.function.__name__, sequential_agent
170+
)
171+
events = [e async for e in sequential_agent.run_async(parent_ctx)]
172+
173+
# Should only have 1 event from the escalating agent, normal agent should not run
174+
assert len(events) == 1
175+
assert events[0].author == escalating_agent.name
176+
assert (
177+
events[0].content.parts[0].text
178+
== f'Hello, async {escalating_agent.name}!'
179+
)
180+
assert events[0].actions.escalate is True
181+
182+
# Verify that the post-escalation event from the same agent is not yielded
183+
assert (
184+
'I should not be seen after escalation!'
185+
not in events[0].content.parts[0].text
186+
)
187+
188+
189+
@pytest.mark.asyncio
190+
async def test_run_async_escalate_action_in_middle(
191+
request: pytest.FixtureRequest,
192+
):
193+
"""Test that SequentialAgent exits when escalation happens in middle of sequence."""
194+
first_agent = _TestingAgent(name=f'{request.function.__name__}_first_agent')
195+
escalating_agent = _TestingAgentWithEscalateAction(
196+
name=f'{request.function.__name__}_escalating_agent'
197+
)
198+
third_agent = _TestingAgent(name=f'{request.function.__name__}_third_agent')
199+
sequential_agent = SequentialAgent(
200+
name=f'{request.function.__name__}_test_agent',
201+
sub_agents=[
202+
first_agent,
203+
escalating_agent,
204+
third_agent, # This should not execute
205+
],
206+
)
207+
parent_ctx = await _create_parent_invocation_context(
208+
request.function.__name__, sequential_agent
209+
)
210+
events = [e async for e in sequential_agent.run_async(parent_ctx)]
211+
212+
# Should have 2 events: one from first agent, one from escalating agent
213+
assert len(events) == 2
214+
assert events[0].author == first_agent.name
215+
assert events[1].author == escalating_agent.name
216+
assert events[1].actions.escalate is True
217+
218+
# Verify third agent did not run
219+
third_agent_events = [e for e in events if e.author == third_agent.name]
220+
assert len(third_agent_events) == 0
221+
222+
223+
@pytest.mark.asyncio
224+
async def test_run_async_no_escalate_action(request: pytest.FixtureRequest):
225+
"""Test that SequentialAgent continues normally when no escalate action."""
226+
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
227+
agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
228+
agent_3 = _TestingAgent(name=f'{request.function.__name__}_test_agent_3')
229+
sequential_agent = SequentialAgent(
230+
name=f'{request.function.__name__}_test_agent',
231+
sub_agents=[
232+
agent_1,
233+
agent_2,
234+
agent_3,
235+
],
236+
)
237+
parent_ctx = await _create_parent_invocation_context(
238+
request.function.__name__, sequential_agent
239+
)
240+
events = [e async for e in sequential_agent.run_async(parent_ctx)]
241+
242+
# All agents should execute
243+
assert len(events) == 3
244+
assert events[0].author == agent_1.name
245+
assert events[1].author == agent_2.name
246+
assert events[2].author == agent_3.name

0 commit comments

Comments
 (0)