Skip to content

Commit 772658f

Browse files
XinranTangcopybara-github
authored andcommitted
chore: Refactor runner run_async flow to extract out execution context setup logic
PiperOrigin-RevId: 812894540
1 parent 8e5f361 commit 772658f

File tree

1 file changed

+79
-21
lines changed

1 file changed

+79
-21
lines changed

src/google/adk/runners.py

Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -306,30 +306,12 @@ async def _run_with_trace(
306306
if not session:
307307
raise ValueError(f'Session not found: {session_id}')
308308

309-
invocation_context = self._new_invocation_context(
310-
session,
309+
invocation_context = await self._setup_context_for_new_invocation(
310+
session=session,
311311
new_message=new_message,
312312
run_config=run_config,
313+
state_delta=state_delta,
313314
)
314-
root_agent = self.agent
315-
316-
# Modify user message before execution.
317-
modified_user_message = await invocation_context.plugin_manager.run_on_user_message_callback(
318-
invocation_context=invocation_context, user_message=new_message
319-
)
320-
if modified_user_message is not None:
321-
new_message = modified_user_message
322-
323-
if new_message:
324-
await self._append_new_message_to_session(
325-
session,
326-
new_message,
327-
invocation_context,
328-
run_config.save_input_blobs_as_artifacts,
329-
state_delta,
330-
)
331-
332-
invocation_context.agent = self._find_agent_to_run(session, root_agent)
333315

334316
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
335317
async with Aclosing(ctx.agent.run_async(ctx)) as agen:
@@ -420,6 +402,7 @@ async def _exec_with_plugin(
420402

421403
async def _append_new_message_to_session(
422404
self,
405+
*,
423406
session: Session,
424407
new_message: types.Content,
425408
invocation_context: InvocationContext,
@@ -433,6 +416,7 @@ async def _append_new_message_to_session(
433416
new_message: The new message to append.
434417
invocation_context: The invocation context for the message.
435418
save_input_blobs_as_artifacts: Whether to save input blobs as artifacts.
419+
state_delta: Optional state changes to apply to the session.
436420
"""
437421
if not new_message.parts:
438422
raise ValueError('No parts in the new_message.')
@@ -661,6 +645,44 @@ def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool:
661645
agent = agent.parent_agent
662646
return True
663647

648+
async def _setup_context_for_new_invocation(
649+
self,
650+
*,
651+
session: Session,
652+
new_message: types.Content,
653+
run_config: RunConfig,
654+
state_delta: Optional[dict[str, Any]],
655+
) -> InvocationContext:
656+
"""Sets up the context for a new invocation.
657+
658+
Args:
659+
session: The session to setup the invocation context for.
660+
new_message: The new message to process and append to the session.
661+
run_config: The run config of the agent.
662+
state_delta: Optional state changes to apply to the session.
663+
664+
Returns:
665+
The invocation context for the new invocation.
666+
"""
667+
# Step 1: Create invocation context in memory.
668+
invocation_context = self._new_invocation_context(
669+
session,
670+
new_message=new_message,
671+
run_config=run_config,
672+
)
673+
# Step 2: Handle new message, by running callbacks and appending to
674+
# session.
675+
await self._handle_new_message(
676+
session=session,
677+
new_message=new_message,
678+
invocation_context=invocation_context,
679+
run_config=run_config,
680+
state_delta=state_delta,
681+
)
682+
# Step 3: Set agent to run for the invocation.
683+
invocation_context.agent = self._find_agent_to_run(session, self.agent)
684+
return invocation_context
685+
664686
def _new_invocation_context(
665687
self,
666688
session: Session,
@@ -743,6 +765,42 @@ def _new_invocation_context_for_live(
743765
run_config=run_config,
744766
)
745767

768+
async def _handle_new_message(
769+
self,
770+
*,
771+
session: Session,
772+
new_message: types.Content,
773+
invocation_context: InvocationContext,
774+
run_config: RunConfig,
775+
state_delta: Optional[dict[str, Any]],
776+
) -> None:
777+
"""Handles a new message by running callbacks and appending to session.
778+
779+
Args:
780+
session: The session of the new message.
781+
new_message: The new message to process and append to the session.
782+
invocation_context: The invocation context to use for the message
783+
handling.
784+
run_config: The run config of the agent.
785+
state_delta: Optional state changes to apply to the session.
786+
"""
787+
modified_user_message = (
788+
await invocation_context.plugin_manager.run_on_user_message_callback(
789+
invocation_context=invocation_context, user_message=new_message
790+
)
791+
)
792+
if modified_user_message is not None:
793+
new_message = modified_user_message
794+
795+
if new_message:
796+
await self._append_new_message_to_session(
797+
session=session,
798+
new_message=new_message,
799+
invocation_context=invocation_context,
800+
save_input_blobs_as_artifacts=run_config.save_input_blobs_as_artifacts,
801+
state_delta=state_delta,
802+
)
803+
746804
def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
747805
toolsets = set()
748806
if isinstance(agent, LlmAgent):

0 commit comments

Comments
 (0)