18
18
from typing import AsyncGenerator
19
19
20
20
from google .adk .agents .base_agent import BaseAgent
21
+ from google .adk .agents .base_agent import BaseAgentState
21
22
from google .adk .agents .invocation_context import InvocationContext
22
23
from google .adk .agents .parallel_agent import ParallelAgent
23
24
from google .adk .agents .sequential_agent import SequentialAgent
25
+ from google .adk .agents .sequential_agent import SequentialAgentState
26
+ from google .adk .apps .app import ResumabilityConfig
24
27
from google .adk .events .event import Event
25
28
from google .adk .sessions .in_memory_session_service import InMemorySessionService
26
29
from google .genai import types
@@ -52,7 +55,7 @@ async def _run_async_impl(
52
55
53
56
54
57
async def _create_parent_invocation_context (
55
- test_name : str , agent : BaseAgent
58
+ test_name : str , agent : BaseAgent , is_resumable : bool = False
56
59
) -> InvocationContext :
57
60
session_service = InMemorySessionService ()
58
61
session = await session_service .create_session (
@@ -63,11 +66,13 @@ async def _create_parent_invocation_context(
63
66
agent = agent ,
64
67
session = session ,
65
68
session_service = session_service ,
69
+ resumability_config = ResumabilityConfig (is_resumable = is_resumable ),
66
70
)
67
71
68
72
69
73
@pytest .mark .asyncio
70
- async def test_run_async (request : pytest .FixtureRequest ):
74
+ @pytest .mark .parametrize ('is_resumable' , [True , False ])
75
+ async def test_run_async (request : pytest .FixtureRequest , is_resumable : bool ):
71
76
agent1 = _TestingAgent (
72
77
name = f'{ request .function .__name__ } _test_agent_1' ,
73
78
delay = 0.5 ,
@@ -81,23 +86,43 @@ async def test_run_async(request: pytest.FixtureRequest):
81
86
],
82
87
)
83
88
parent_ctx = await _create_parent_invocation_context (
84
- request .function .__name__ , parallel_agent
89
+ request .function .__name__ , parallel_agent , is_resumable = is_resumable
85
90
)
86
91
events = [e async for e in parallel_agent .run_async (parent_ctx )]
87
92
88
- assert len (events ) == 2
89
- # agent2 generates an event first, then agent1. Because they run in parallel
90
- # and agent1 has a delay.
91
- assert events [0 ].author == agent2 .name
92
- assert events [1 ].author == agent1 .name
93
- assert events [0 ].branch .endswith (f'{ parallel_agent .name } .{ agent2 .name } ' )
94
- assert events [1 ].branch .endswith (f'{ parallel_agent .name } .{ agent1 .name } ' )
95
- assert events [0 ].content .parts [0 ].text == f'Hello, async { agent2 .name } !'
96
- assert events [1 ].content .parts [0 ].text == f'Hello, async { agent1 .name } !'
93
+ if is_resumable :
94
+ assert len (events ) == 4
95
+
96
+ assert events [0 ].author == parallel_agent .name
97
+ assert not events [0 ].actions .end_of_agent
98
+
99
+ # agent2 generates an event first, then agent1. Because they run in parallel
100
+ # and agent1 has a delay.
101
+ assert events [1 ].author == agent2 .name
102
+ assert events [2 ].author == agent1 .name
103
+ assert events [1 ].branch == f'{ parallel_agent .name } .{ agent2 .name } '
104
+ assert events [2 ].branch == f'{ parallel_agent .name } .{ agent1 .name } '
105
+ assert events [1 ].content .parts [0 ].text == f'Hello, async { agent2 .name } !'
106
+ assert events [2 ].content .parts [0 ].text == f'Hello, async { agent1 .name } !'
107
+
108
+ assert events [3 ].author == parallel_agent .name
109
+ assert events [3 ].actions .end_of_agent
110
+ else :
111
+ assert len (events ) == 2
112
+
113
+ assert events [0 ].author == agent2 .name
114
+ assert events [1 ].author == agent1 .name
115
+ assert events [0 ].branch == f'{ parallel_agent .name } .{ agent2 .name } '
116
+ assert events [1 ].branch == f'{ parallel_agent .name } .{ agent1 .name } '
117
+ assert events [0 ].content .parts [0 ].text == f'Hello, async { agent2 .name } !'
118
+ assert events [1 ].content .parts [0 ].text == f'Hello, async { agent1 .name } !'
97
119
98
120
99
121
@pytest .mark .asyncio
100
- async def test_run_async_branches (request : pytest .FixtureRequest ):
122
+ @pytest .mark .parametrize ('is_resumable' , [True , False ])
123
+ async def test_run_async_branches (
124
+ request : pytest .FixtureRequest , is_resumable : bool
125
+ ):
101
126
agent1 = _TestingAgent (
102
127
name = f'{ request .function .__name__ } _test_agent_1' ,
103
128
delay = 0.5 ,
@@ -116,28 +141,124 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
116
141
],
117
142
)
118
143
parent_ctx = await _create_parent_invocation_context (
119
- request .function .__name__ , parallel_agent
144
+ request .function .__name__ , parallel_agent , is_resumable = is_resumable
120
145
)
121
146
events = [e async for e in parallel_agent .run_async (parent_ctx )]
122
147
123
- assert len (events ) == 3
124
- assert (
125
- events [0 ].author == agent2 .name
126
- and events [0 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
148
+ if is_resumable :
149
+ assert len (events ) == 8
150
+
151
+ # 1. parallel agent checkpoint
152
+ assert events [0 ].author == parallel_agent .name
153
+ assert not events [0 ].actions .end_of_agent
154
+
155
+ # 2. sequential agent checkpoint
156
+ assert events [1 ].author == sequential_agent .name
157
+ assert not events [1 ].actions .end_of_agent
158
+ assert events [1 ].actions .agent_state ['current_sub_agent' ] == agent2 .name
159
+ assert events [1 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
160
+
161
+ # 3. agent 2 event
162
+ assert events [2 ].author == agent2 .name
163
+ assert events [2 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
164
+
165
+ # 4. sequential agent checkpoint
166
+ assert events [3 ].author == sequential_agent .name
167
+ assert not events [3 ].actions .end_of_agent
168
+ assert events [3 ].actions .agent_state ['current_sub_agent' ] == agent3 .name
169
+ assert events [3 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
170
+
171
+ # 5. agent 3 event
172
+ assert events [4 ].author == agent3 .name
173
+ assert events [4 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
174
+
175
+ # 6. sequential agent checkpoint (end)
176
+ assert events [5 ].author == sequential_agent .name
177
+ assert events [5 ].actions .end_of_agent
178
+ assert events [5 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
179
+
180
+ # Descendants of the same sub-agent should have the same branch.
181
+ assert events [1 ].branch == events [2 ].branch
182
+ assert events [2 ].branch == events [3 ].branch
183
+ assert events [3 ].branch == events [4 ].branch
184
+ assert events [4 ].branch == events [5 ].branch
185
+
186
+ # 7. agent 1 event
187
+ assert events [6 ].author == agent1 .name
188
+ assert events [6 ].branch == f'{ parallel_agent .name } .{ agent1 .name } '
189
+
190
+ # Sub-agents should have different branches.
191
+ assert events [6 ].branch != events [1 ].branch
192
+
193
+ # 8. parallel agent checkpoint (end)
194
+ assert events [7 ].author == parallel_agent .name
195
+ assert events [7 ].actions .end_of_agent
196
+ else :
197
+ assert len (events ) == 3
198
+
199
+ # 1. agent 2 event
200
+ assert events [0 ].author == agent2 .name
201
+ assert events [0 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
202
+
203
+ # 2. agent 3 event
204
+ assert events [1 ].author == agent3 .name
205
+ assert events [1 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
206
+
207
+ # 3. agent 1 event
208
+ assert events [2 ].author == agent1 .name
209
+ assert events [2 ].branch == f'{ parallel_agent .name } .{ agent1 .name } '
210
+
211
+
212
+ @pytest .mark .asyncio
213
+ async def test_resume_async_branches (request : pytest .FixtureRequest ):
214
+ agent1 = _TestingAgent (
215
+ name = f'{ request .function .__name__ } _test_agent_1' , delay = 0.5
216
+ )
217
+ agent2 = _TestingAgent (name = f'{ request .function .__name__ } _test_agent_2' )
218
+ agent3 = _TestingAgent (name = f'{ request .function .__name__ } _test_agent_3' )
219
+ sequential_agent = SequentialAgent (
220
+ name = f'{ request .function .__name__ } _test_sequential_agent' ,
221
+ sub_agents = [agent2 , agent3 ],
222
+ )
223
+ parallel_agent = ParallelAgent (
224
+ name = f'{ request .function .__name__ } _test_parallel_agent' ,
225
+ sub_agents = [
226
+ sequential_agent ,
227
+ agent1 ,
228
+ ],
127
229
)
128
- assert (
129
- events [1 ].author == agent3 .name
130
- and events [0 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
230
+ parent_ctx = await _create_parent_invocation_context (
231
+ request .function .__name__ , parallel_agent , is_resumable = True
131
232
)
132
- # Descendants of the same sub-agent should have the same branch.
133
- assert events [0 ].branch == events [1 ].branch
134
- assert (
135
- events [2 ].author == agent1 .name
136
- and events [2 ].branch == f'{ parallel_agent .name } .{ agent1 .name } '
233
+ parent_ctx .agent_states [parallel_agent .name ] = BaseAgentState ().model_dump (
234
+ mode = 'json'
137
235
)
138
- # Sub-agents should have different branches.
139
- assert events [2 ].branch != events [1 ].branch
140
- assert events [2 ].branch != events [0 ].branch
236
+ parent_ctx .agent_states [sequential_agent .name ] = SequentialAgentState (
237
+ current_sub_agent = agent3 .name
238
+ ).model_dump (mode = 'json' )
239
+
240
+ events = [e async for e in parallel_agent .run_async (parent_ctx )]
241
+
242
+ assert len (events ) == 4
243
+
244
+ # The sequential agent resumes from agent3.
245
+ # 1. Agent 3 event
246
+ assert events [0 ].author == agent3 .name
247
+ assert events [0 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
248
+
249
+ # 2. Sequential agent checkpoint (end)
250
+ assert events [1 ].author == sequential_agent .name
251
+ assert events [1 ].actions .end_of_agent
252
+ assert events [1 ].branch == f'{ parallel_agent .name } .{ sequential_agent .name } '
253
+
254
+ # Agent 1 runs in parallel but has a delay.
255
+ # 3. Agent 1 event
256
+ assert events [2 ].author == agent1 .name
257
+ assert events [2 ].branch == f'{ parallel_agent .name } .{ agent1 .name } '
258
+
259
+ # 4. Parallel agent checkpoint (end)
260
+ assert events [3 ].author == parallel_agent .name
261
+ assert events [3 ].actions .end_of_agent
141
262
142
263
143
264
class _TestingAgentWithMultipleEvents (_TestingAgent ):
0 commit comments