Skip to content

Commit 50a7f30

Browse files
I have successfully implemented an early exit mechanism for SequentialAgent similar to what exists in LoopAgent. Here's what was accomplished
1 parent 943abec commit 50a7f30

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .tool_context import ToolContext
16+
17+
18+
def exit_sequence(tool_context: ToolContext):
19+
"""Exits the sequential execution of agents.
20+
21+
Call this function when you encounter a terminal condition and want to
22+
prevent subsequent agents in the sequence from executing. This allows
23+
for early termination of a sequential agent workflow when a blocking
24+
or final condition is reached.
25+
26+
Args:
27+
tool_context: The context of the current tool invocation.
28+
"""
29+
tool_context.actions.escalate = True
30+
tool_context.actions.skip_summarization = True

tests/unittests/agents/test_sequential_agent.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from google.adk.agents.sequential_agent import SequentialAgentState
2323
from google.adk.apps import ResumabilityConfig
2424
from google.adk.events.event import Event
25+
from google.adk.events.event_actions import EventActions
2526
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2627
from google.genai import types
2728
import pytest
@@ -55,6 +56,42 @@ async def _run_live_impl(
5556
)
5657

5758

59+
class _TestingAgentWithEscalateAction(BaseAgent):
60+
61+
@override
62+
async def _run_async_impl(
63+
self, ctx: InvocationContext
64+
) -> AsyncGenerator[Event, None]:
65+
yield Event(
66+
author=self.name,
67+
invocation_id=ctx.invocation_id,
68+
content=types.Content(
69+
parts=[types.Part(text=f'Hello, async {self.name}!')]
70+
),
71+
actions=EventActions(escalate=True),
72+
)
73+
yield Event(
74+
author=self.name,
75+
invocation_id=ctx.invocation_id,
76+
content=types.Content(
77+
parts=[types.Part(text=f'I should not be seen after escalation!')]
78+
),
79+
)
80+
81+
@override
82+
async def _run_live_impl(
83+
self, ctx: InvocationContext
84+
) -> AsyncGenerator[Event, None]:
85+
yield Event(
86+
author=self.name,
87+
invocation_id=ctx.invocation_id,
88+
content=types.Content(
89+
parts=[types.Part(text=f'Hello, live {self.name}!')]
90+
),
91+
actions=EventActions(escalate=True),
92+
)
93+
94+
5895
async def _create_parent_invocation_context(
5996
test_name: str, agent: BaseAgent, resumable: bool = False
6097
) -> InvocationContext:
@@ -201,3 +238,98 @@ async def test_run_live(request: pytest.FixtureRequest):
201238
assert events[1].author == agent_2.name
202239
assert events[0].content.parts[0].text == f'Hello, live {agent_1.name}!'
203240
assert events[1].content.parts[0].text == f'Hello, live {agent_2.name}!'
241+
242+
243+
@pytest.mark.asyncio
244+
async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
245+
"""Test that SequentialAgent exits early when escalate action is triggered."""
246+
escalating_agent = _TestingAgentWithEscalateAction(
247+
name=f'{request.function.__name__}_escalating_agent'
248+
)
249+
normal_agent = _TestingAgent(name=f'{request.function.__name__}_normal_agent')
250+
sequential_agent = SequentialAgent(
251+
name=f'{request.function.__name__}_test_agent',
252+
sub_agents=[
253+
escalating_agent,
254+
normal_agent, # This should not execute
255+
],
256+
)
257+
parent_ctx = await _create_parent_invocation_context(
258+
request.function.__name__, sequential_agent
259+
)
260+
events = [e async for e in sequential_agent.run_async(parent_ctx)]
261+
262+
# Should only have 1 event from the escalating agent, normal agent should not run
263+
assert len(events) == 1
264+
assert events[0].author == escalating_agent.name
265+
assert (
266+
events[0].content.parts[0].text
267+
== f'Hello, async {escalating_agent.name}!'
268+
)
269+
assert events[0].actions.escalate is True
270+
271+
# Verify that the post-escalation event from the same agent is not yielded
272+
assert (
273+
'I should not be seen after escalation!'
274+
not in events[0].content.parts[0].text
275+
)
276+
277+
278+
@pytest.mark.asyncio
279+
async def test_run_async_escalate_action_in_middle(
280+
request: pytest.FixtureRequest,
281+
):
282+
"""Test that SequentialAgent exits when escalation happens in middle of sequence."""
283+
first_agent = _TestingAgent(name=f'{request.function.__name__}_first_agent')
284+
escalating_agent = _TestingAgentWithEscalateAction(
285+
name=f'{request.function.__name__}_escalating_agent'
286+
)
287+
third_agent = _TestingAgent(name=f'{request.function.__name__}_third_agent')
288+
sequential_agent = SequentialAgent(
289+
name=f'{request.function.__name__}_test_agent',
290+
sub_agents=[
291+
first_agent,
292+
escalating_agent,
293+
third_agent, # This should not execute
294+
],
295+
)
296+
parent_ctx = await _create_parent_invocation_context(
297+
request.function.__name__, sequential_agent
298+
)
299+
events = [e async for e in sequential_agent.run_async(parent_ctx)]
300+
301+
# Should have 2 events: one from first agent, one from escalating agent
302+
assert len(events) == 2
303+
assert events[0].author == first_agent.name
304+
assert events[1].author == escalating_agent.name
305+
assert events[1].actions.escalate is True
306+
307+
# Verify third agent did not run
308+
third_agent_events = [e for e in events if e.author == third_agent.name]
309+
assert len(third_agent_events) == 0
310+
311+
312+
@pytest.mark.asyncio
313+
async def test_run_async_no_escalate_action(request: pytest.FixtureRequest):
314+
"""Test that SequentialAgent continues normally when no escalate action."""
315+
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
316+
agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
317+
agent_3 = _TestingAgent(name=f'{request.function.__name__}_test_agent_3')
318+
sequential_agent = SequentialAgent(
319+
name=f'{request.function.__name__}_test_agent',
320+
sub_agents=[
321+
agent_1,
322+
agent_2,
323+
agent_3,
324+
],
325+
)
326+
parent_ctx = await _create_parent_invocation_context(
327+
request.function.__name__, sequential_agent
328+
)
329+
events = [e async for e in sequential_agent.run_async(parent_ctx)]
330+
331+
# All agents should execute
332+
assert len(events) == 3
333+
assert events[0].author == agent_1.name
334+
assert events[1].author == agent_2.name
335+
assert events[2].author == agent_3.name

0 commit comments

Comments
 (0)