|
16 | 16 | import contextlib
|
17 | 17 | from typing import AsyncGenerator
|
18 | 18 | from typing import Generator
|
| 19 | +from typing import Optional |
19 | 20 | from typing import Union
|
20 | 21 |
|
21 | 22 | from google.adk.agents.invocation_context import InvocationContext
|
22 | 23 | from google.adk.agents.live_request_queue import LiveRequestQueue
|
23 | 24 | from google.adk.agents.llm_agent import Agent
|
24 | 25 | from google.adk.agents.llm_agent import LlmAgent
|
25 | 26 | from google.adk.agents.run_config import RunConfig
|
| 27 | +from google.adk.apps.app import App |
26 | 28 | from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
27 | 29 | from google.adk.events.event import Event
|
28 | 30 | from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
|
@@ -119,7 +121,32 @@ def append_user_content(
|
119 | 121 | # Extracts the contents from the events and transform them into a list of
|
120 | 122 | # (author, simplified_content) tuples.
|
121 | 123 | def simplify_events(events: list[Event]) -> list[(str, types.Part)]:
|
122 |
| - return [(event.author, simplify_content(event.content)) for event in events] |
| 124 | + return [ |
| 125 | + (event.author, simplify_content(event.content)) |
| 126 | + for event in events |
| 127 | + if event.content |
| 128 | + ] |
| 129 | + |
| 130 | + |
| 131 | +END_OF_AGENT = 'end_of_agent' |
| 132 | + |
| 133 | + |
| 134 | +# Extracts the contents from the events and transform them into a list of |
| 135 | +# (author, simplified_content OR AgentState OR "end_of_agent") tuples. |
| 136 | +# |
| 137 | +# Could be used to compare events for testing resumability. |
| 138 | +def simplify_resumable_app_events( |
| 139 | + events: list[Event], |
| 140 | +) -> list[(str, Union[types.Part, str])]: |
| 141 | + results = [] |
| 142 | + for event in events: |
| 143 | + if event.content: |
| 144 | + results.append((event.author, simplify_content(event.content))) |
| 145 | + elif event.actions.end_of_agent: |
| 146 | + results.append((event.author, END_OF_AGENT)) |
| 147 | + elif event.actions.agent_state is not None: |
| 148 | + results.append((event.author, event.actions.agent_state)) |
| 149 | + return results |
123 | 150 |
|
124 | 151 |
|
125 | 152 | # Simplifies the contents into a list of (author, simplified_content) tuples.
|
@@ -189,31 +216,52 @@ class InMemoryRunner:
|
189 | 216 |
|
190 | 217 | def __init__(
|
191 | 218 | self,
|
192 |
| - root_agent: Union[Agent, LlmAgent], |
| 219 | + root_agent: Optional[Union[Agent, LlmAgent]] = None, |
193 | 220 | response_modalities: list[str] = None,
|
194 | 221 | plugins: list[BasePlugin] = [],
|
| 222 | + app: Optional[App] = None, |
195 | 223 | ):
|
196 |
| - self.root_agent = root_agent |
197 |
| - self.runner = Runner( |
198 |
| - app_name='test_app', |
199 |
| - agent=root_agent, |
200 |
| - artifact_service=InMemoryArtifactService(), |
201 |
| - session_service=InMemorySessionService(), |
202 |
| - memory_service=InMemoryMemoryService(), |
203 |
| - plugins=plugins, |
204 |
| - ) |
| 224 | + """Initializes the InMemoryRunner. |
| 225 | +
|
| 226 | + Args: |
| 227 | + root_agent: The root agent to run, won't be used if app is provided. |
| 228 | + response_modalities: The response modalities of the runner. |
| 229 | + plugins: The plugins to use in the runner, won't be used if app is |
| 230 | + provided. |
| 231 | + app: The app to use in the runner. |
| 232 | + """ |
| 233 | + if not app: |
| 234 | + self.app_name = 'test_app' |
| 235 | + self.root_agent = root_agent |
| 236 | + self.runner = Runner( |
| 237 | + app_name='test_app', |
| 238 | + agent=root_agent, |
| 239 | + artifact_service=InMemoryArtifactService(), |
| 240 | + session_service=InMemorySessionService(), |
| 241 | + memory_service=InMemoryMemoryService(), |
| 242 | + plugins=plugins, |
| 243 | + ) |
| 244 | + else: |
| 245 | + self.app_name = app.name |
| 246 | + self.root_agent = app.root_agent |
| 247 | + self.runner = Runner( |
| 248 | + app=app, |
| 249 | + artifact_service=InMemoryArtifactService(), |
| 250 | + session_service=InMemorySessionService(), |
| 251 | + memory_service=InMemoryMemoryService(), |
| 252 | + ) |
205 | 253 | self.session_id = None
|
206 | 254 |
|
207 | 255 | @property
|
208 | 256 | def session(self) -> Session:
|
209 | 257 | if not self.session_id:
|
210 | 258 | session = self.runner.session_service.create_session_sync(
|
211 |
| - app_name='test_app', user_id='test_user' |
| 259 | + app_name=self.app_name, user_id='test_user' |
212 | 260 | )
|
213 | 261 | self.session_id = session.id
|
214 | 262 | return session
|
215 | 263 | return self.runner.session_service.get_session_sync(
|
216 |
| - app_name='test_app', user_id='test_user', session_id=self.session_id |
| 264 | + app_name=self.app_name, user_id='test_user', session_id=self.session_id |
217 | 265 | )
|
218 | 266 |
|
219 | 267 | def run(self, new_message: types.ContentUnion) -> list[Event]:
|
|
0 commit comments