Skip to content

Commit 28d44a3

Browse files
XinranTangcopybara-github
authored andcommitted
test: Make testing_utils.InMemoryRunner support ADK App and add utils for extracting event contents for testing resumability
PiperOrigin-RevId: 811933527
1 parent e172811 commit 28d44a3

File tree

1 file changed

+61
-13
lines changed

1 file changed

+61
-13
lines changed

tests/unittests/testing_utils.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
import contextlib
1717
from typing import AsyncGenerator
1818
from typing import Generator
19+
from typing import Optional
1920
from typing import Union
2021

2122
from google.adk.agents.invocation_context import InvocationContext
2223
from google.adk.agents.live_request_queue import LiveRequestQueue
2324
from google.adk.agents.llm_agent import Agent
2425
from google.adk.agents.llm_agent import LlmAgent
2526
from google.adk.agents.run_config import RunConfig
27+
from google.adk.apps.app import App
2628
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
2729
from google.adk.events.event import Event
2830
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
@@ -119,7 +121,32 @@ def append_user_content(
119121
# Extracts the contents from the events and transform them into a list of
120122
# (author, simplified_content) tuples.
121123
def simplify_events(events: list[Event]) -> list[(str, types.Part)]:
122-
return [(event.author, simplify_content(event.content)) for event in events]
124+
return [
125+
(event.author, simplify_content(event.content))
126+
for event in events
127+
if event.content
128+
]
129+
130+
131+
END_OF_AGENT = 'end_of_agent'
132+
133+
134+
# Extracts the contents from the events and transform them into a list of
135+
# (author, simplified_content OR AgentState OR "end_of_agent") tuples.
136+
#
137+
# Could be used to compare events for testing resumability.
138+
def simplify_resumable_app_events(
139+
events: list[Event],
140+
) -> list[(str, Union[types.Part, str])]:
141+
results = []
142+
for event in events:
143+
if event.content:
144+
results.append((event.author, simplify_content(event.content)))
145+
elif event.actions.end_of_agent:
146+
results.append((event.author, END_OF_AGENT))
147+
elif event.actions.agent_state is not None:
148+
results.append((event.author, event.actions.agent_state))
149+
return results
123150

124151

125152
# Simplifies the contents into a list of (author, simplified_content) tuples.
@@ -189,31 +216,52 @@ class InMemoryRunner:
189216

190217
def __init__(
191218
self,
192-
root_agent: Union[Agent, LlmAgent],
219+
root_agent: Optional[Union[Agent, LlmAgent]] = None,
193220
response_modalities: list[str] = None,
194221
plugins: list[BasePlugin] = [],
222+
app: Optional[App] = None,
195223
):
196-
self.root_agent = root_agent
197-
self.runner = Runner(
198-
app_name='test_app',
199-
agent=root_agent,
200-
artifact_service=InMemoryArtifactService(),
201-
session_service=InMemorySessionService(),
202-
memory_service=InMemoryMemoryService(),
203-
plugins=plugins,
204-
)
224+
"""Initializes the InMemoryRunner.
225+
226+
Args:
227+
root_agent: The root agent to run, won't be used if app is provided.
228+
response_modalities: The response modalities of the runner.
229+
plugins: The plugins to use in the runner, won't be used if app is
230+
provided.
231+
app: The app to use in the runner.
232+
"""
233+
if not app:
234+
self.app_name = 'test_app'
235+
self.root_agent = root_agent
236+
self.runner = Runner(
237+
app_name='test_app',
238+
agent=root_agent,
239+
artifact_service=InMemoryArtifactService(),
240+
session_service=InMemorySessionService(),
241+
memory_service=InMemoryMemoryService(),
242+
plugins=plugins,
243+
)
244+
else:
245+
self.app_name = app.name
246+
self.root_agent = app.root_agent
247+
self.runner = Runner(
248+
app=app,
249+
artifact_service=InMemoryArtifactService(),
250+
session_service=InMemorySessionService(),
251+
memory_service=InMemoryMemoryService(),
252+
)
205253
self.session_id = None
206254

207255
@property
208256
def session(self) -> Session:
209257
if not self.session_id:
210258
session = self.runner.session_service.create_session_sync(
211-
app_name='test_app', user_id='test_user'
259+
app_name=self.app_name, user_id='test_user'
212260
)
213261
self.session_id = session.id
214262
return session
215263
return self.runner.session_service.get_session_sync(
216-
app_name='test_app', user_id='test_user', session_id=self.session_id
264+
app_name=self.app_name, user_id='test_user', session_id=self.session_id
217265
)
218266

219267
def run(self, new_message: types.ContentUnion) -> list[Event]:

0 commit comments

Comments
 (0)