Skip to content

Commit 1ee01cc

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Implement checkpoint and resume logic for SequentialAgent
PiperOrigin-RevId: 811977004
1 parent 28d44a3 commit 1ee01cc

File tree

8 files changed

+245
-72
lines changed

8 files changed

+245
-72
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -163,44 +163,41 @@ def _load_agent_state(
163163
self,
164164
ctx: InvocationContext,
165165
state_type: Type[AgentState],
166-
default_state: AgentState,
167-
) -> tuple[AgentState, bool]:
166+
) -> Optional[AgentState]:
168167
"""Loads the agent state from the invocation context, handling resumption.
169168
170169
Args:
171170
ctx: The invocation context.
172171
state_type: The type of the agent state.
173-
default_state: The default state to use if not resuming.
174172
175173
Returns:
176-
tuple[AgentState, bool]: The current state and a boolean indicating if
177-
resuming.
174+
The current state if resuming, otherwise None.
178175
"""
176+
if not ctx.is_resumable:
177+
return None
178+
179179
if self.name not in ctx.agent_states:
180-
return default_state, False
180+
return None
181181
else:
182-
return state_type.model_validate(ctx.agent_states.get(self.name)), True
182+
return state_type.model_validate(ctx.agent_states.get(self.name))
183183

184184
def _create_agent_state_event(
185185
self,
186186
ctx: InvocationContext,
187187
*,
188-
state: Optional[BaseAgentState] = None,
188+
agent_state: Optional[BaseAgentState] = None,
189189
end_of_agent: bool = False,
190190
) -> Event:
191-
"""Creates an event for agent state.
191+
"""Returns an event with agent state.
192192
193193
Args:
194194
ctx: The invocation context.
195-
state: The agent state to checkpoint.
195+
agent_state: The agent state to checkpoint.
196196
end_of_agent: Whether the agent is finished running.
197-
198-
Returns:
199-
An Event object representing the checkpoint.
200197
"""
201198
event_actions = EventActions()
202-
if state:
203-
event_actions.agent_state = state.model_dump(mode='json')
199+
if agent_state:
200+
event_actions.agent_state = agent_state.model_dump(mode='json')
204201
if end_of_agent:
205202
event_actions.end_of_agent = True
206203
return Event(

src/google/adk/agents/invocation_context.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,14 @@ class InvocationContext(BaseModel):
208208
of this invocation.
209209
"""
210210

211+
@property
212+
def is_resumable(self) -> bool:
213+
"""Returns whether the current invocation is resumable."""
214+
return (
215+
self.resumability_config is not None
216+
and self.resumability_config.is_resumable
217+
)
218+
211219
def reset_agent_state(self, agent_name: str) -> None:
212220
"""Resets the state of an agent, allowing it to be re-run."""
213221
self.agent_states.pop(agent_name, None)
@@ -284,10 +292,7 @@ def should_pause_invocation(self, event: Event) -> bool:
284292
Returns:
285293
Whether to pause the invocation right after this event.
286294
"""
287-
if (
288-
not self.resumability_config
289-
or not self.resumability_config.is_resumable
290-
):
295+
if not self.is_resumable:
291296
return False
292297

293298
if not event.long_running_tool_ids or not event.get_function_calls():

src/google/adk/agents/sequential_agent.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import logging
1920
from typing import AsyncGenerator
2021
from typing import ClassVar
2122
from typing import Type
@@ -32,6 +33,8 @@
3233
from .llm_agent import LlmAgent
3334
from .sequential_agent_config import SequentialAgentConfig
3435

36+
logger = logging.getLogger('google_adk.' + __name__)
37+
3538

3639
@experimental
3740
class SequentialAgentState(BaseAgentState):
@@ -51,24 +54,69 @@ class SequentialAgent(BaseAgent):
5154
async def _run_async_impl(
5255
self, ctx: InvocationContext
5356
) -> AsyncGenerator[Event, None]:
54-
# Skip if there is no sub-agent.
5557
if not self.sub_agents:
5658
return
5759

58-
for sub_agent in self.sub_agents:
59-
pause_invocation = False
60+
# Initialize or resume the execution state from the agent state.
61+
agent_state = self._load_agent_state(ctx, SequentialAgentState)
62+
start_index = self._get_start_index(agent_state)
63+
64+
pause_invocation = False
65+
resuming_sub_agent = agent_state is not None
66+
for i in range(start_index, len(self.sub_agents)):
67+
sub_agent = self.sub_agents[i]
68+
if not resuming_sub_agent:
69+
# If we are resuming from the current event, it means the same event has
70+
# already been logged, so we should avoid yielding it again.
71+
if ctx.is_resumable:
72+
agent_state = SequentialAgentState(current_sub_agent=sub_agent.name)
73+
yield self._create_agent_state_event(ctx, agent_state=agent_state)
74+
75+
# Reset the sub-agent's state in the context to ensure that each
76+
# sub-agent starts fresh.
77+
ctx.reset_agent_state(sub_agent.name)
6078

6179
async with Aclosing(sub_agent.run_async(ctx)) as agen:
6280
async for event in agen:
6381
yield event
6482
if ctx.should_pause_invocation(event):
6583
pause_invocation = True
6684

67-
# Indicates the invocation should pause when receiving signal from
68-
# the current sub_agent.
85+
# Skip the rest of the sub-agents if the invocation is paused.
6986
if pause_invocation:
7087
return
7188

89+
# Reset the flag for the next sub-agent.
90+
resuming_sub_agent = False
91+
92+
if ctx.is_resumable:
93+
yield self._create_agent_state_event(ctx, end_of_agent=True)
94+
95+
def _get_start_index(
96+
self,
97+
agent_state: SequentialAgentState,
98+
) -> int:
99+
"""Calculates the start index for the sub-agent loop."""
100+
if not agent_state:
101+
return 0
102+
103+
if not agent_state.current_sub_agent:
104+
# This means the process was finished.
105+
return len(self.sub_agents)
106+
107+
try:
108+
sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents]
109+
return sub_agent_names.index(agent_state.current_sub_agent)
110+
except ValueError:
111+
# A sub-agent was removed so the agent name is not found.
112+
# For now, we restart from the beginning.
113+
logger.warning(
114+
'Sub-agent %s was removed so the agent name is not found. Restarting'
115+
' from the beginning.',
116+
agent_state.current_sub_agent,
117+
)
118+
return 0
119+
72120
@override
73121
async def _run_live_impl(
74122
self, ctx: InvocationContext

src/google/adk/flows/llm_flows/contents.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,8 @@ def _get_current_turn_contents(
350350
# Find the latest event that starts the current turn and process from there
351351
for i in range(len(events) - 1, -1, -1):
352352
event = events[i]
353+
if not event.content:
354+
continue
353355
if event.author == 'user' or _is_other_agent_reply(agent_name, event):
354356
return _get_contents(current_branch, events[i:], agent_name)
355357

tests/unittests/agents/test_base_agent.py

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.adk.agents.base_agent import BaseAgentState
2727
from google.adk.agents.callback_context import CallbackContext
2828
from google.adk.agents.invocation_context import InvocationContext
29+
from google.adk.apps.app import ResumabilityConfig
2930
from google.adk.events.event import Event
3031
from google.adk.plugins.base_plugin import BasePlugin
3132
from google.adk.plugins.plugin_manager import PluginManager
@@ -733,39 +734,6 @@ async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
733734
[e async for e in agent.run_live(parent_ctx)]
734735

735736

736-
@pytest.mark.asyncio
737-
async def test_create_agent_state_event(request: pytest.FixtureRequest):
738-
# Arrange
739-
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
740-
ctx = await _create_parent_invocation_context(
741-
request.function.__name__, agent, branch='test_branch'
742-
)
743-
state = BaseAgentState()
744-
745-
# Act
746-
event = agent._create_agent_state_event(ctx, state=state)
747-
748-
# Assert
749-
assert event.invocation_id == ctx.invocation_id
750-
assert event.author == agent.name
751-
assert event.branch == 'test_branch'
752-
assert event.actions is not None
753-
assert event.actions.agent_state is not None
754-
assert event.actions.agent_state == state.model_dump(mode='json')
755-
assert not event.actions.end_of_agent
756-
757-
# Act
758-
event = agent._create_agent_state_event(ctx, end_of_agent=True)
759-
760-
# Assert
761-
assert event.invocation_id == ctx.invocation_id
762-
assert event.author == agent.name
763-
assert event.branch == 'test_branch'
764-
assert event.actions is not None
765-
assert event.actions.end_of_agent
766-
assert event.actions.agent_state is None
767-
768-
769737
def test_set_parent_agent_for_sub_agents(request: pytest.FixtureRequest):
770738
sub_agents: list[BaseAgent] = [
771739
_TestingAgent(name=f'{request.function.__name__}_sub_agent_1'),
@@ -895,7 +863,7 @@ class _TestAgentState(BaseAgentState):
895863

896864

897865
@pytest.mark.asyncio
898-
async def test_load_agent_state_no_resume():
866+
async def test_load_agent_state_not_resumable():
899867
agent = BaseAgent(name='test_agent')
900868
session_service = InMemorySessionService()
901869
session = await session_service.create_session(
@@ -907,14 +875,15 @@ async def test_load_agent_state_no_resume():
907875
session=session,
908876
session_service=session_service,
909877
)
910-
default_state = _TestAgentState(test_field='default')
911878

912-
state, is_resuming = agent._load_agent_state(
913-
ctx, _TestAgentState, default_state
914-
)
879+
# Test case 1: resumability_config is None
880+
state = agent._load_agent_state(ctx, _TestAgentState)
881+
assert state is None
915882

916-
assert not is_resuming
917-
assert state == default_state
883+
# Test case 2: is_resumable is False
884+
ctx.resumability_config = ResumabilityConfig(is_resumable=False)
885+
state = agent._load_agent_state(ctx, _TestAgentState)
886+
assert state is None
918887

919888

920889
@pytest.mark.asyncio
@@ -929,13 +898,70 @@ async def test_load_agent_state_with_resume():
929898
agent=agent,
930899
session=session,
931900
session_service=session_service,
901+
resumability_config=ResumabilityConfig(is_resumable=True),
932902
)
903+
904+
# Test case 1: agent state not in context
905+
state = agent._load_agent_state(ctx, _TestAgentState)
906+
assert state is None
907+
908+
# Test case 2: agent state in context
933909
persisted_state = _TestAgentState(test_field='resumed')
934910
ctx.agent_states[agent.name] = persisted_state.model_dump(mode='json')
935911

936-
state, is_resuming = agent._load_agent_state(
937-
ctx, _TestAgentState, _TestAgentState()
912+
state = agent._load_agent_state(ctx, _TestAgentState)
913+
assert state == persisted_state
914+
915+
916+
@pytest.mark.asyncio
917+
async def test_create_agent_state_event():
918+
agent = BaseAgent(name='test_agent')
919+
session_service = InMemorySessionService()
920+
session = await session_service.create_session(
921+
app_name='test_app', user_id='test_user'
922+
)
923+
ctx = InvocationContext(
924+
invocation_id='test_invocation',
925+
agent=agent,
926+
session=session,
927+
session_service=session_service,
938928
)
939929

940-
assert is_resuming
941-
assert state == persisted_state
930+
ctx.branch = 'test_branch'
931+
932+
# Test case 1: with state
933+
state = _TestAgentState(test_field='checkpoint')
934+
event = agent._create_agent_state_event(ctx, agent_state=state)
935+
assert event is not None
936+
assert event.invocation_id == ctx.invocation_id
937+
assert event.author == agent.name
938+
assert event.branch == 'test_branch'
939+
assert event.actions is not None
940+
assert event.actions.agent_state is not None
941+
assert event.actions.agent_state == state.model_dump(mode='json')
942+
assert not event.actions.end_of_agent
943+
944+
# Test case 2: with end_of_agent
945+
event = agent._create_agent_state_event(ctx, end_of_agent=True)
946+
assert event is not None
947+
assert event.invocation_id == ctx.invocation_id
948+
assert event.author == agent.name
949+
assert event.branch == 'test_branch'
950+
assert event.actions is not None
951+
assert event.actions.end_of_agent
952+
assert event.actions.agent_state is None
953+
954+
# Test case 3: with both state and end_of_agent
955+
state = _TestAgentState(test_field='checkpoint')
956+
event = agent._create_agent_state_event(
957+
ctx, agent_state=state, end_of_agent=True
958+
)
959+
assert event is not None
960+
assert event.actions.agent_state == state.model_dump(mode='json')
961+
assert event.actions.end_of_agent
962+
963+
# Test case 4: with neither
964+
event = agent._create_agent_state_event(ctx)
965+
assert event is not None
966+
assert event.actions.agent_state is None
967+
assert not event.actions.end_of_agent

tests/unittests/agents/test_invocation_context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,22 @@ def test_should_not_pause_invocation_with_no_function_calls(
206206
assert not mock_invocation_context.should_pause_invocation(
207207
nonpausable_event
208208
)
209+
210+
def test_is_resumable_true(self):
211+
"""Tests that is_resumable is True when resumability is enabled."""
212+
invocation_context = self._create_test_invocation_context(
213+
ResumabilityConfig(is_resumable=True)
214+
)
215+
assert invocation_context.is_resumable
216+
217+
def test_is_resumable_false(self):
218+
"""Tests that is_resumable is False when resumability is disabled."""
219+
invocation_context = self._create_test_invocation_context(
220+
ResumabilityConfig(is_resumable=False)
221+
)
222+
assert not invocation_context.is_resumable
223+
224+
def test_is_resumable_no_config(self):
225+
"""Tests that is_resumable is False when no resumability config is set."""
226+
invocation_context = self._create_test_invocation_context(None)
227+
assert not invocation_context.is_resumable

tests/unittests/agents/test_llm_agent_include_contents.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,10 @@ async def test_include_contents_none_sequential_agents():
219219
runner = testing_utils.InMemoryRunner(sequential_agent)
220220
events = runner.run("Original user request")
221221

222-
assert len(events) == 2
223-
assert events[0].author == "agent1"
224-
assert events[1].author == "agent2"
222+
simplified_events = [event for event in events if event.content]
223+
assert len(simplified_events) == 2
224+
assert simplified_events[0].author == "agent1"
225+
assert simplified_events[1].author == "agent2"
225226

226227
# Agent1 sees original user request
227228
agent1_contents = testing_utils.simplify_contents(

0 commit comments

Comments
 (0)