19
19
from google .adk .agents .base_agent import BaseAgent
20
20
from google .adk .agents .invocation_context import InvocationContext
21
21
from google .adk .agents .loop_agent import LoopAgent
22
+ from google .adk .agents .loop_agent import LoopAgentState
23
+ from google .adk .apps import ResumabilityConfig
22
24
from google .adk .events .event import Event
23
25
from google .adk .events .event_actions import EventActions
24
26
from google .adk .sessions .in_memory_session_service import InMemorySessionService
25
27
from google .genai import types
26
28
import pytest
27
29
from typing_extensions import override
28
30
31
+ from .. import testing_utils
32
+
33
+ END_OF_AGENT = testing_utils .END_OF_AGENT
34
+
29
35
30
36
class _TestingAgent (BaseAgent ):
31
37
@@ -72,13 +78,13 @@ async def _run_async_impl(
72
78
author = self .name ,
73
79
invocation_id = ctx .invocation_id ,
74
80
content = types .Content (
75
- parts = [types .Part (text = f 'I have done my job after escalation!!' )]
81
+ parts = [types .Part (text = 'I have done my job after escalation!!' )]
76
82
),
77
83
)
78
84
79
85
80
86
async def _create_parent_invocation_context (
81
- test_name : str , agent : BaseAgent
87
+ test_name : str , agent : BaseAgent , resumable : bool = False
82
88
) -> InvocationContext :
83
89
session_service = InMemorySessionService ()
84
90
session = await session_service .create_session (
@@ -89,11 +95,13 @@ async def _create_parent_invocation_context(
89
95
agent = agent ,
90
96
session = session ,
91
97
session_service = session_service ,
98
+ resumability_config = ResumabilityConfig (is_resumable = resumable ),
92
99
)
93
100
94
101
95
102
@pytest .mark .asyncio
96
- async def test_run_async (request : pytest .FixtureRequest ):
103
+ @pytest .mark .parametrize ('resumable' , [True , False ])
104
+ async def test_run_async (request : pytest .FixtureRequest , resumable : bool ):
97
105
agent = _TestingAgent (name = f'{ request .function .__name__ } _test_agent' )
98
106
loop_agent = LoopAgent (
99
107
name = f'{ request .function .__name__ } _test_loop_agent' ,
@@ -103,15 +111,60 @@ async def test_run_async(request: pytest.FixtureRequest):
103
111
],
104
112
)
105
113
parent_ctx = await _create_parent_invocation_context (
106
- request .function .__name__ , loop_agent
114
+ request .function .__name__ , loop_agent , resumable = resumable
115
+ )
116
+ events = [e async for e in loop_agent .run_async (parent_ctx )]
117
+
118
+ simplified_events = testing_utils .simplify_resumable_app_events (events )
119
+ if resumable :
120
+ expected_events = [
121
+ (
122
+ loop_agent .name ,
123
+ {'current_sub_agent' : agent .name , 'times_looped' : 0 },
124
+ ),
125
+ (agent .name , f'Hello, async { agent .name } !' ),
126
+ (
127
+ loop_agent .name ,
128
+ {'current_sub_agent' : agent .name , 'times_looped' : 1 },
129
+ ),
130
+ (agent .name , f'Hello, async { agent .name } !' ),
131
+ (loop_agent .name , END_OF_AGENT ),
132
+ ]
133
+ else :
134
+ expected_events = [
135
+ (agent .name , f'Hello, async { agent .name } !' ),
136
+ (agent .name , f'Hello, async { agent .name } !' ),
137
+ ]
138
+ assert simplified_events == expected_events
139
+
140
+
141
+ @pytest .mark .asyncio
142
+ async def test_resume_async (request : pytest .FixtureRequest ):
143
+ agent_1 = _TestingAgent (name = f'{ request .function .__name__ } _test_agent_1' )
144
+ agent_2 = _TestingAgent (name = f'{ request .function .__name__ } _test_agent_2' )
145
+ loop_agent = LoopAgent (
146
+ name = f'{ request .function .__name__ } _test_loop_agent' ,
147
+ max_iterations = 2 ,
148
+ sub_agents = [
149
+ agent_1 ,
150
+ agent_2 ,
151
+ ],
107
152
)
153
+ parent_ctx = await _create_parent_invocation_context (
154
+ request .function .__name__ , loop_agent , resumable = True
155
+ )
156
+ parent_ctx .agent_states [loop_agent .name ] = LoopAgentState (
157
+ current_sub_agent = agent_2 .name , times_looped = 1
158
+ ).model_dump (mode = 'json' )
159
+
108
160
events = [e async for e in loop_agent .run_async (parent_ctx )]
109
161
110
- assert len (events ) == 2
111
- assert events [0 ].author == agent .name
112
- assert events [1 ].author == agent .name
113
- assert events [0 ].content .parts [0 ].text == f'Hello, async { agent .name } !'
114
- assert events [1 ].content .parts [0 ].text == f'Hello, async { agent .name } !'
162
+ simplified_events = testing_utils .simplify_resumable_app_events (events )
163
+ expected_events = [
164
+ (agent_2 .name , f'Hello, async { agent_2 .name } !' ),
165
+ (loop_agent .name , END_OF_AGENT ),
166
+ ]
167
+ assert simplified_events == expected_events
115
168
116
169
117
170
@pytest .mark .asyncio
@@ -129,7 +182,10 @@ async def test_run_async_skip_if_no_sub_agent(request: pytest.FixtureRequest):
129
182
130
183
131
184
@pytest .mark .asyncio
132
- async def test_run_async_with_escalate_action (request : pytest .FixtureRequest ):
185
+ @pytest .mark .parametrize ('resumable' , [True , False ])
186
+ async def test_run_async_with_escalate_action (
187
+ request : pytest .FixtureRequest , resumable : bool
188
+ ):
133
189
non_escalating_agent = _TestingAgent (
134
190
name = f'{ request .function .__name__ } _test_non_escalating_agent'
135
191
)
@@ -144,20 +200,52 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
144
200
sub_agents = [non_escalating_agent , escalating_agent , ignored_agent ],
145
201
)
146
202
parent_ctx = await _create_parent_invocation_context (
147
- request .function .__name__ , loop_agent
203
+ request .function .__name__ , loop_agent , resumable = resumable
148
204
)
149
205
events = [e async for e in loop_agent .run_async (parent_ctx )]
150
206
151
- # Only two events are generated because the sub escalating_agent escalates.
152
- assert len (events ) == 3
153
- assert events [0 ].author == non_escalating_agent .name
154
- assert events [1 ].author == escalating_agent .name
155
- assert events [0 ].content .parts [0 ].text == (
156
- f'Hello, async { non_escalating_agent .name } !'
157
- )
158
- assert events [1 ].content .parts [0 ].text == (
159
- f'Hello, async { escalating_agent .name } !'
160
- )
161
- assert (
162
- events [2 ].content .parts [0 ].text == 'I have done my job after escalation!!'
163
- )
207
+ simplified_events = testing_utils .simplify_resumable_app_events (events )
208
+
209
+ if resumable :
210
+ expected_events = [
211
+ (
212
+ loop_agent .name ,
213
+ {
214
+ 'current_sub_agent' : non_escalating_agent .name ,
215
+ 'times_looped' : 0 ,
216
+ },
217
+ ),
218
+ (
219
+ non_escalating_agent .name ,
220
+ f'Hello, async { non_escalating_agent .name } !' ,
221
+ ),
222
+ (
223
+ loop_agent .name ,
224
+ {'current_sub_agent' : escalating_agent .name , 'times_looped' : 0 },
225
+ ),
226
+ (
227
+ escalating_agent .name ,
228
+ f'Hello, async { escalating_agent .name } !' ,
229
+ ),
230
+ (
231
+ escalating_agent .name ,
232
+ 'I have done my job after escalation!!' ,
233
+ ),
234
+ (loop_agent .name , END_OF_AGENT ),
235
+ ]
236
+ else :
237
+ expected_events = [
238
+ (
239
+ non_escalating_agent .name ,
240
+ f'Hello, async { non_escalating_agent .name } !' ,
241
+ ),
242
+ (
243
+ escalating_agent .name ,
244
+ f'Hello, async { escalating_agent .name } !' ,
245
+ ),
246
+ (
247
+ escalating_agent .name ,
248
+ 'I have done my job after escalation!!' ,
249
+ ),
250
+ ]
251
+ assert simplified_events == expected_events
0 commit comments