Skip to content

Commit 2f1040f

Browse files
DeanChensj authored and copybara-github committed
feat: Implement checkpoint and resume logic for ParallelAgent
PiperOrigin-RevId: 812658378
1 parent 943abec commit 2f1040f

File tree

2 files changed

+171
-35
lines changed

2 files changed

+171
-35
lines changed

src/google/adk/agents/parallel_agent.py

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from ..events.event import Event
2727
from ..utils.context_utils import Aclosing
2828
from .base_agent import BaseAgent
29+
from .base_agent import BaseAgentState
2930
from .base_agent_config import BaseAgentConfig
3031
from .invocation_context import InvocationContext
3132
from .parallel_agent_config import ParallelAgentConfig
@@ -178,12 +179,22 @@ async def _run_async_impl(
178179
if not self.sub_agents:
179180
return
180181

181-
agent_runs = [
182-
sub_agent.run_async(
183-
_create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
184-
)
185-
for sub_agent in self.sub_agents
186-
]
182+
agent_state = self._load_agent_state(ctx, BaseAgentState)
183+
if ctx.is_resumable and agent_state is None:
184+
yield self._create_agent_state_event(ctx, agent_state=BaseAgentState())
185+
186+
agent_runs = []
187+
# Prepare and collect async generators for each sub-agent.
188+
for sub_agent in self.sub_agents:
189+
if agent_state is None:
190+
# Reset sub-agent state to make sure each sub-agent starts fresh.
191+
ctx.reset_agent_state(sub_agent.name)
192+
193+
sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
194+
195+
# Only include sub-agents that haven't finished in a previous run.
196+
if not sub_agent_ctx.end_of_agents.get(sub_agent.name):
197+
agent_runs.append(sub_agent.run_async(sub_agent_ctx))
187198

188199
pause_invocation = False
189200
try:
@@ -203,6 +214,10 @@ async def _run_async_impl(
203214
if pause_invocation:
204215
return
205216

217+
# Once all sub-agents are done, mark the ParallelAgent as final.
218+
if ctx.is_resumable:
219+
yield self._create_agent_state_event(ctx, end_of_agent=True)
220+
206221
finally:
207222
for sub_agent_run in agent_runs:
208223
await sub_agent_run.aclose()

tests/unittests/agents/test_parallel_agent.py

Lines changed: 150 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,12 @@
1818
from typing import AsyncGenerator
1919

2020
from google.adk.agents.base_agent import BaseAgent
21+
from google.adk.agents.base_agent import BaseAgentState
2122
from google.adk.agents.invocation_context import InvocationContext
2223
from google.adk.agents.parallel_agent import ParallelAgent
2324
from google.adk.agents.sequential_agent import SequentialAgent
25+
from google.adk.agents.sequential_agent import SequentialAgentState
26+
from google.adk.apps.app import ResumabilityConfig
2427
from google.adk.events.event import Event
2528
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2629
from google.genai import types
@@ -52,7 +55,7 @@ async def _run_async_impl(
5255

5356

5457
async def _create_parent_invocation_context(
55-
test_name: str, agent: BaseAgent
58+
test_name: str, agent: BaseAgent, is_resumable: bool = False
5659
) -> InvocationContext:
5760
session_service = InMemorySessionService()
5861
session = await session_service.create_session(
@@ -63,11 +66,13 @@ async def _create_parent_invocation_context(
6366
agent=agent,
6467
session=session,
6568
session_service=session_service,
69+
resumability_config=ResumabilityConfig(is_resumable=is_resumable),
6670
)
6771

6872

6973
@pytest.mark.asyncio
70-
async def test_run_async(request: pytest.FixtureRequest):
74+
@pytest.mark.parametrize('is_resumable', [True, False])
75+
async def test_run_async(request: pytest.FixtureRequest, is_resumable: bool):
7176
agent1 = _TestingAgent(
7277
name=f'{request.function.__name__}_test_agent_1',
7378
delay=0.5,
@@ -81,23 +86,43 @@ async def test_run_async(request: pytest.FixtureRequest):
8186
],
8287
)
8388
parent_ctx = await _create_parent_invocation_context(
84-
request.function.__name__, parallel_agent
89+
request.function.__name__, parallel_agent, is_resumable=is_resumable
8590
)
8691
events = [e async for e in parallel_agent.run_async(parent_ctx)]
8792

88-
assert len(events) == 2
89-
# agent2 generates an event first, then agent1. Because they run in parallel
90-
# and agent1 has a delay.
91-
assert events[0].author == agent2.name
92-
assert events[1].author == agent1.name
93-
assert events[0].branch.endswith(f'{parallel_agent.name}.{agent2.name}')
94-
assert events[1].branch.endswith(f'{parallel_agent.name}.{agent1.name}')
95-
assert events[0].content.parts[0].text == f'Hello, async {agent2.name}!'
96-
assert events[1].content.parts[0].text == f'Hello, async {agent1.name}!'
93+
if is_resumable:
94+
assert len(events) == 4
95+
96+
assert events[0].author == parallel_agent.name
97+
assert not events[0].actions.end_of_agent
98+
99+
# agent2 generates an event first, then agent1. Because they run in parallel
100+
# and agent1 has a delay.
101+
assert events[1].author == agent2.name
102+
assert events[2].author == agent1.name
103+
assert events[1].branch == f'{parallel_agent.name}.{agent2.name}'
104+
assert events[2].branch == f'{parallel_agent.name}.{agent1.name}'
105+
assert events[1].content.parts[0].text == f'Hello, async {agent2.name}!'
106+
assert events[2].content.parts[0].text == f'Hello, async {agent1.name}!'
107+
108+
assert events[3].author == parallel_agent.name
109+
assert events[3].actions.end_of_agent
110+
else:
111+
assert len(events) == 2
112+
113+
assert events[0].author == agent2.name
114+
assert events[1].author == agent1.name
115+
assert events[0].branch == f'{parallel_agent.name}.{agent2.name}'
116+
assert events[1].branch == f'{parallel_agent.name}.{agent1.name}'
117+
assert events[0].content.parts[0].text == f'Hello, async {agent2.name}!'
118+
assert events[1].content.parts[0].text == f'Hello, async {agent1.name}!'
97119

98120

99121
@pytest.mark.asyncio
100-
async def test_run_async_branches(request: pytest.FixtureRequest):
122+
@pytest.mark.parametrize('is_resumable', [True, False])
123+
async def test_run_async_branches(
124+
request: pytest.FixtureRequest, is_resumable: bool
125+
):
101126
agent1 = _TestingAgent(
102127
name=f'{request.function.__name__}_test_agent_1',
103128
delay=0.5,
@@ -116,28 +141,124 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
116141
],
117142
)
118143
parent_ctx = await _create_parent_invocation_context(
119-
request.function.__name__, parallel_agent
144+
request.function.__name__, parallel_agent, is_resumable=is_resumable
120145
)
121146
events = [e async for e in parallel_agent.run_async(parent_ctx)]
122147

123-
assert len(events) == 3
124-
assert (
125-
events[0].author == agent2.name
126-
and events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}'
148+
if is_resumable:
149+
assert len(events) == 8
150+
151+
# 1. parallel agent checkpoint
152+
assert events[0].author == parallel_agent.name
153+
assert not events[0].actions.end_of_agent
154+
155+
# 2. sequential agent checkpoint
156+
assert events[1].author == sequential_agent.name
157+
assert not events[1].actions.end_of_agent
158+
assert events[1].actions.agent_state['current_sub_agent'] == agent2.name
159+
assert events[1].branch == f'{parallel_agent.name}.{sequential_agent.name}'
160+
161+
# 3. agent 2 event
162+
assert events[2].author == agent2.name
163+
assert events[2].branch == f'{parallel_agent.name}.{sequential_agent.name}'
164+
165+
# 4. sequential agent checkpoint
166+
assert events[3].author == sequential_agent.name
167+
assert not events[3].actions.end_of_agent
168+
assert events[3].actions.agent_state['current_sub_agent'] == agent3.name
169+
assert events[3].branch == f'{parallel_agent.name}.{sequential_agent.name}'
170+
171+
# 5. agent 3 event
172+
assert events[4].author == agent3.name
173+
assert events[4].branch == f'{parallel_agent.name}.{sequential_agent.name}'
174+
175+
# 6. sequential agent checkpoint (end)
176+
assert events[5].author == sequential_agent.name
177+
assert events[5].actions.end_of_agent
178+
assert events[5].branch == f'{parallel_agent.name}.{sequential_agent.name}'
179+
180+
# Descendants of the same sub-agent should have the same branch.
181+
assert events[1].branch == events[2].branch
182+
assert events[2].branch == events[3].branch
183+
assert events[3].branch == events[4].branch
184+
assert events[4].branch == events[5].branch
185+
186+
# 7. agent 1 event
187+
assert events[6].author == agent1.name
188+
assert events[6].branch == f'{parallel_agent.name}.{agent1.name}'
189+
190+
# Sub-agents should have different branches.
191+
assert events[6].branch != events[1].branch
192+
193+
# 8. parallel agent checkpoint (end)
194+
assert events[7].author == parallel_agent.name
195+
assert events[7].actions.end_of_agent
196+
else:
197+
assert len(events) == 3
198+
199+
# 1. agent 2 event
200+
assert events[0].author == agent2.name
201+
assert events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}'
202+
203+
# 2. agent 3 event
204+
assert events[1].author == agent3.name
205+
assert events[1].branch == f'{parallel_agent.name}.{sequential_agent.name}'
206+
207+
# 3. agent 1 event
208+
assert events[2].author == agent1.name
209+
assert events[2].branch == f'{parallel_agent.name}.{agent1.name}'
210+
211+
212+
@pytest.mark.asyncio
213+
async def test_resume_async_branches(request: pytest.FixtureRequest):
214+
agent1 = _TestingAgent(
215+
name=f'{request.function.__name__}_test_agent_1', delay=0.5
216+
)
217+
agent2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
218+
agent3 = _TestingAgent(name=f'{request.function.__name__}_test_agent_3')
219+
sequential_agent = SequentialAgent(
220+
name=f'{request.function.__name__}_test_sequential_agent',
221+
sub_agents=[agent2, agent3],
222+
)
223+
parallel_agent = ParallelAgent(
224+
name=f'{request.function.__name__}_test_parallel_agent',
225+
sub_agents=[
226+
sequential_agent,
227+
agent1,
228+
],
127229
)
128-
assert (
129-
events[1].author == agent3.name
130-
and events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}'
230+
parent_ctx = await _create_parent_invocation_context(
231+
request.function.__name__, parallel_agent, is_resumable=True
131232
)
132-
# Descendants of the same sub-agent should have the same branch.
133-
assert events[0].branch == events[1].branch
134-
assert (
135-
events[2].author == agent1.name
136-
and events[2].branch == f'{parallel_agent.name}.{agent1.name}'
233+
parent_ctx.agent_states[parallel_agent.name] = BaseAgentState().model_dump(
234+
mode='json'
137235
)
138-
# Sub-agents should have different branches.
139-
assert events[2].branch != events[1].branch
140-
assert events[2].branch != events[0].branch
236+
parent_ctx.agent_states[sequential_agent.name] = SequentialAgentState(
237+
current_sub_agent=agent3.name
238+
).model_dump(mode='json')
239+
240+
events = [e async for e in parallel_agent.run_async(parent_ctx)]
241+
242+
assert len(events) == 4
243+
244+
# The sequential agent resumes from agent3.
245+
# 1. Agent 3 event
246+
assert events[0].author == agent3.name
247+
assert events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}'
248+
249+
# 2. Sequential agent checkpoint (end)
250+
assert events[1].author == sequential_agent.name
251+
assert events[1].actions.end_of_agent
252+
assert events[1].branch == f'{parallel_agent.name}.{sequential_agent.name}'
253+
254+
# Agent 1 runs in parallel but has a delay.
255+
# 3. Agent 1 event
256+
assert events[2].author == agent1.name
257+
assert events[2].branch == f'{parallel_agent.name}.{agent1.name}'
258+
259+
# 4. Parallel agent checkpoint (end)
260+
assert events[3].author == parallel_agent.name
261+
assert events[3].actions.end_of_agent
141262

142263

143264
class _TestingAgentWithMultipleEvents(_TestingAgent):

0 commit comments

Comments
 (0)