Skip to content

Commit ce9c39f

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Implement checkpoint and resume logic for LoopAgent
PiperOrigin-RevId: 813096880
1 parent d5c46e4 commit ce9c39f

File tree

4 files changed

+215
-41
lines changed

4 files changed

+215
-41
lines changed

src/google/adk/agents/loop_agent.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import logging
1920
from typing import Any
2021
from typing import AsyncGenerator
2122
from typing import ClassVar
@@ -24,15 +25,17 @@
2425

2526
from typing_extensions import override
2627

27-
from ..agents.invocation_context import InvocationContext
2828
from ..events.event import Event
2929
from ..utils.context_utils import Aclosing
3030
from ..utils.feature_decorator import experimental
3131
from .base_agent import BaseAgent
3232
from .base_agent import BaseAgentState
3333
from .base_agent_config import BaseAgentConfig
34+
from .invocation_context import InvocationContext
3435
from .loop_agent_config import LoopAgentConfig
3536

37+
logger = logging.getLogger('google_adk.' + __name__)
38+
3639

3740
@experimental
3841
class LoopAgentState(BaseAgentState):
@@ -69,11 +72,32 @@ async def _run_async_impl(
6972
if not self.sub_agents:
7073
return
7174

72-
times_looped = 0
73-
while not self.max_iterations or times_looped < self.max_iterations:
74-
for sub_agent in self.sub_agents:
75-
should_exit = False
76-
pause_invocation = False
75+
agent_state = self._load_agent_state(ctx, LoopAgentState)
76+
is_resuming_at_current_agent = agent_state is not None
77+
times_looped, start_index = self._get_start_state(agent_state)
78+
79+
should_exit = False
80+
pause_invocation = False
81+
while (
82+
not self.max_iterations or times_looped < self.max_iterations
83+
) and not (should_exit or pause_invocation):
84+
for i in range(start_index, len(self.sub_agents)):
85+
sub_agent = self.sub_agents[i]
86+
87+
if ctx.is_resumable and not is_resuming_at_current_agent:
88+
# If we are resuming from the current event, it means the same event
89+
# has already been logged, so we should avoid yielding it again.
90+
agent_state = LoopAgentState(
91+
current_sub_agent=sub_agent.name,
92+
times_looped=times_looped,
93+
)
94+
yield self._create_agent_state_event(ctx, agent_state=agent_state)
95+
96+
# Reset the sub-agent's state in the context to ensure that each
97+
# sub-agent starts fresh.
98+
if not is_resuming_at_current_agent:
99+
ctx.reset_agent_state(sub_agent.name)
100+
is_resuming_at_current_agent = False
77101

78102
async with Aclosing(sub_agent.run_async(ctx)) as agen:
79103
async for event in agen:
@@ -83,18 +107,42 @@ async def _run_async_impl(
83107
if ctx.should_pause_invocation(event):
84108
pause_invocation = True
85109

86-
# Indicates that the loop agent should exist after running this
87-
# sub-agent.
88-
if should_exit:
89-
return
90-
91-
# Indicates that the invocation should be paused after running this
92-
# sub-agent.
93-
if pause_invocation:
94-
return
110+
if should_exit or pause_invocation:
111+
break # break inner for loop
95112

113+
# Restart from the beginning of the loop.
114+
start_index = 0
96115
times_looped += 1
97-
return
116+
117+
# If the invocation is paused, we should not yield the end of agent event.
118+
if pause_invocation:
119+
return
120+
121+
if ctx.is_resumable:
122+
yield self._create_agent_state_event(ctx, end_of_agent=True)
123+
124+
def _get_start_state(
125+
self,
126+
agent_state: Optional[LoopAgentState],
127+
) -> tuple[int, int]:
128+
"""Computes the start state of the loop agent from the agent state."""
129+
if not agent_state:
130+
return 0, 0
131+
132+
times_looped = agent_state.times_looped
133+
start_index = 0
134+
if agent_state.current_sub_agent:
135+
try:
136+
sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents]
137+
start_index = sub_agent_names.index(agent_state.current_sub_agent)
138+
except ValueError:
139+
# A sub-agent was removed so the agent name is not found.
140+
# For now, we restart from the beginning.
141+
logger.warning(
142+
'Sub-agent %s was not found. Restarting from the beginning.',
143+
agent_state.current_sub_agent,
144+
)
145+
return times_looped, start_index
98146

99147
@override
100148
async def _run_live_impl(

tests/unittests/agents/test_loop_agent.py

Lines changed: 112 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,19 @@
1919
from google.adk.agents.base_agent import BaseAgent
2020
from google.adk.agents.invocation_context import InvocationContext
2121
from google.adk.agents.loop_agent import LoopAgent
22+
from google.adk.agents.loop_agent import LoopAgentState
23+
from google.adk.apps import ResumabilityConfig
2224
from google.adk.events.event import Event
2325
from google.adk.events.event_actions import EventActions
2426
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2527
from google.genai import types
2628
import pytest
2729
from typing_extensions import override
2830

31+
from .. import testing_utils
32+
33+
END_OF_AGENT = testing_utils.END_OF_AGENT
34+
2935

3036
class _TestingAgent(BaseAgent):
3137

@@ -72,13 +78,13 @@ async def _run_async_impl(
7278
author=self.name,
7379
invocation_id=ctx.invocation_id,
7480
content=types.Content(
75-
parts=[types.Part(text=f'I have done my job after escalation!!')]
81+
parts=[types.Part(text='I have done my job after escalation!!')]
7682
),
7783
)
7884

7985

8086
async def _create_parent_invocation_context(
81-
test_name: str, agent: BaseAgent
87+
test_name: str, agent: BaseAgent, resumable: bool = False
8288
) -> InvocationContext:
8389
session_service = InMemorySessionService()
8490
session = await session_service.create_session(
@@ -89,11 +95,13 @@ async def _create_parent_invocation_context(
8995
agent=agent,
9096
session=session,
9197
session_service=session_service,
98+
resumability_config=ResumabilityConfig(is_resumable=resumable),
9299
)
93100

94101

95102
@pytest.mark.asyncio
96-
async def test_run_async(request: pytest.FixtureRequest):
103+
@pytest.mark.parametrize('resumable', [True, False])
104+
async def test_run_async(request: pytest.FixtureRequest, resumable: bool):
97105
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
98106
loop_agent = LoopAgent(
99107
name=f'{request.function.__name__}_test_loop_agent',
@@ -103,15 +111,60 @@ async def test_run_async(request: pytest.FixtureRequest):
103111
],
104112
)
105113
parent_ctx = await _create_parent_invocation_context(
106-
request.function.__name__, loop_agent
114+
request.function.__name__, loop_agent, resumable=resumable
115+
)
116+
events = [e async for e in loop_agent.run_async(parent_ctx)]
117+
118+
simplified_events = testing_utils.simplify_resumable_app_events(events)
119+
if resumable:
120+
expected_events = [
121+
(
122+
loop_agent.name,
123+
{'current_sub_agent': agent.name, 'times_looped': 0},
124+
),
125+
(agent.name, f'Hello, async {agent.name}!'),
126+
(
127+
loop_agent.name,
128+
{'current_sub_agent': agent.name, 'times_looped': 1},
129+
),
130+
(agent.name, f'Hello, async {agent.name}!'),
131+
(loop_agent.name, END_OF_AGENT),
132+
]
133+
else:
134+
expected_events = [
135+
(agent.name, f'Hello, async {agent.name}!'),
136+
(agent.name, f'Hello, async {agent.name}!'),
137+
]
138+
assert simplified_events == expected_events
139+
140+
141+
@pytest.mark.asyncio
142+
async def test_resume_async(request: pytest.FixtureRequest):
143+
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
144+
agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
145+
loop_agent = LoopAgent(
146+
name=f'{request.function.__name__}_test_loop_agent',
147+
max_iterations=2,
148+
sub_agents=[
149+
agent_1,
150+
agent_2,
151+
],
107152
)
153+
parent_ctx = await _create_parent_invocation_context(
154+
request.function.__name__, loop_agent, resumable=True
155+
)
156+
parent_ctx.agent_states[loop_agent.name] = LoopAgentState(
157+
current_sub_agent=agent_2.name, times_looped=1
158+
).model_dump(mode='json')
159+
108160
events = [e async for e in loop_agent.run_async(parent_ctx)]
109161

110-
assert len(events) == 2
111-
assert events[0].author == agent.name
112-
assert events[1].author == agent.name
113-
assert events[0].content.parts[0].text == f'Hello, async {agent.name}!'
114-
assert events[1].content.parts[0].text == f'Hello, async {agent.name}!'
162+
simplified_events = testing_utils.simplify_resumable_app_events(events)
163+
expected_events = [
164+
(agent_2.name, f'Hello, async {agent_2.name}!'),
165+
(loop_agent.name, END_OF_AGENT),
166+
]
167+
assert simplified_events == expected_events
115168

116169

117170
@pytest.mark.asyncio
@@ -129,7 +182,10 @@ async def test_run_async_skip_if_no_sub_agent(request: pytest.FixtureRequest):
129182

130183

131184
@pytest.mark.asyncio
132-
async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
185+
@pytest.mark.parametrize('resumable', [True, False])
186+
async def test_run_async_with_escalate_action(
187+
request: pytest.FixtureRequest, resumable: bool
188+
):
133189
non_escalating_agent = _TestingAgent(
134190
name=f'{request.function.__name__}_test_non_escalating_agent'
135191
)
@@ -144,20 +200,52 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
144200
sub_agents=[non_escalating_agent, escalating_agent, ignored_agent],
145201
)
146202
parent_ctx = await _create_parent_invocation_context(
147-
request.function.__name__, loop_agent
203+
request.function.__name__, loop_agent, resumable=resumable
148204
)
149205
events = [e async for e in loop_agent.run_async(parent_ctx)]
150206

151-
# Only two events are generated because the sub escalating_agent escalates.
152-
assert len(events) == 3
153-
assert events[0].author == non_escalating_agent.name
154-
assert events[1].author == escalating_agent.name
155-
assert events[0].content.parts[0].text == (
156-
f'Hello, async {non_escalating_agent.name}!'
157-
)
158-
assert events[1].content.parts[0].text == (
159-
f'Hello, async {escalating_agent.name}!'
160-
)
161-
assert (
162-
events[2].content.parts[0].text == 'I have done my job after escalation!!'
163-
)
207+
simplified_events = testing_utils.simplify_resumable_app_events(events)
208+
209+
if resumable:
210+
expected_events = [
211+
(
212+
loop_agent.name,
213+
{
214+
'current_sub_agent': non_escalating_agent.name,
215+
'times_looped': 0,
216+
},
217+
),
218+
(
219+
non_escalating_agent.name,
220+
f'Hello, async {non_escalating_agent.name}!',
221+
),
222+
(
223+
loop_agent.name,
224+
{'current_sub_agent': escalating_agent.name, 'times_looped': 0},
225+
),
226+
(
227+
escalating_agent.name,
228+
f'Hello, async {escalating_agent.name}!',
229+
),
230+
(
231+
escalating_agent.name,
232+
'I have done my job after escalation!!',
233+
),
234+
(loop_agent.name, END_OF_AGENT),
235+
]
236+
else:
237+
expected_events = [
238+
(
239+
non_escalating_agent.name,
240+
f'Hello, async {non_escalating_agent.name}!',
241+
),
242+
(
243+
escalating_agent.name,
244+
f'Hello, async {escalating_agent.name}!',
245+
),
246+
(
247+
escalating_agent.name,
248+
'I have done my job after escalation!!',
249+
),
250+
]
251+
assert simplified_events == expected_events

tests/unittests/flows/llm_flows/test_agent_transfer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from google.adk.agents.llm_agent import Agent
1616
from google.adk.agents.loop_agent import LoopAgent
17+
from google.adk.agents.loop_agent import LoopAgentState
1718
from google.adk.agents.sequential_agent import SequentialAgent
1819
from google.adk.agents.sequential_agent import SequentialAgentState
1920
from google.adk.apps.app import App
@@ -469,12 +470,36 @@ def test_auto_to_loop(is_resumable: bool):
469470
('root_agent', transfer_call_part('sub_agent_1')),
470471
('root_agent', TRANSFER_RESPONSE_PART),
471472
# Loops.
473+
(
474+
'sub_agent_1',
475+
LoopAgentState(current_sub_agent='sub_agent_1_1').model_dump(
476+
mode='json'
477+
),
478+
),
472479
('sub_agent_1_1', 'response1'),
473480
('sub_agent_1_1', END_OF_AGENT),
481+
(
482+
'sub_agent_1',
483+
LoopAgentState(current_sub_agent='sub_agent_1_2').model_dump(
484+
mode='json'
485+
),
486+
),
474487
('sub_agent_1_2', 'response2'),
475488
('sub_agent_1_2', END_OF_AGENT),
489+
(
490+
'sub_agent_1',
491+
LoopAgentState(
492+
current_sub_agent='sub_agent_1_1', times_looped=1
493+
).model_dump(mode='json'),
494+
),
476495
('sub_agent_1_1', 'response3'),
477496
('sub_agent_1_1', END_OF_AGENT),
497+
(
498+
'sub_agent_1',
499+
LoopAgentState(
500+
current_sub_agent='sub_agent_1_2', times_looped=1
501+
).model_dump(mode='json'),
502+
),
478503
# Exits.
479504
('sub_agent_1_2', Part.from_function_call(name='exit_loop', args={})),
480505
(
@@ -484,7 +509,7 @@ def test_auto_to_loop(is_resumable: bool):
484509
),
485510
),
486511
('sub_agent_1_2', END_OF_AGENT),
487-
# Later expect the loop agent to also yield agent state events.
512+
('sub_agent_1', END_OF_AGENT),
488513
('root_agent', END_OF_AGENT),
489514
]
490515
# Same session, different invocation.

tests/unittests/runners/test_pause_invocation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.adk.agents.invocation_context import InvocationContext
2222
from google.adk.agents.llm_agent import LlmAgent
2323
from google.adk.agents.loop_agent import LoopAgent
24+
from google.adk.agents.loop_agent import LoopAgentState
2425
from google.adk.agents.parallel_agent import ParallelAgent
2526
from google.adk.agents.sequential_agent import SequentialAgent
2627
from google.adk.agents.sequential_agent import SequentialAgentState
@@ -368,8 +369,20 @@ def test_pause_on_long_running_function_call(
368369
):
369370
"""Tests that a LoopAgent pauses on long running function call."""
370371
assert testing_utils.simplify_resumable_app_events(runner.run("test")) == [
372+
(
373+
"root_agent",
374+
LoopAgentState(current_sub_agent="sub_agent_1").model_dump(
375+
mode="json"
376+
),
377+
),
371378
("sub_agent_1", "sub agent 1 response"),
372379
("sub_agent_1", END_OF_AGENT),
380+
(
381+
"root_agent",
382+
LoopAgentState(current_sub_agent="sub_agent_2").model_dump(
383+
mode="json"
384+
),
385+
),
373386
("sub_agent_2", Part.from_function_call(name="test_tool", args={})),
374387
]
375388

0 commit comments

Comments
 (0)