@@ -306,30 +306,12 @@ async def _run_with_trace(
306
306
if not session :
307
307
raise ValueError (f'Session not found: { session_id } ' )
308
308
309
- invocation_context = self ._new_invocation_context (
310
- session ,
309
+ invocation_context = await self ._setup_context_for_new_invocation (
310
+ session = session ,
311
311
new_message = new_message ,
312
312
run_config = run_config ,
313
+ state_delta = state_delta ,
313
314
)
314
- root_agent = self .agent
315
-
316
- # Modify user message before execution.
317
- modified_user_message = await invocation_context .plugin_manager .run_on_user_message_callback (
318
- invocation_context = invocation_context , user_message = new_message
319
- )
320
- if modified_user_message is not None :
321
- new_message = modified_user_message
322
-
323
- if new_message :
324
- await self ._append_new_message_to_session (
325
- session ,
326
- new_message ,
327
- invocation_context ,
328
- run_config .save_input_blobs_as_artifacts ,
329
- state_delta ,
330
- )
331
-
332
- invocation_context .agent = self ._find_agent_to_run (session , root_agent )
333
315
334
316
async def execute (ctx : InvocationContext ) -> AsyncGenerator [Event ]:
335
317
async with Aclosing (ctx .agent .run_async (ctx )) as agen :
@@ -420,6 +402,7 @@ async def _exec_with_plugin(
420
402
421
403
async def _append_new_message_to_session (
422
404
self ,
405
+ * ,
423
406
session : Session ,
424
407
new_message : types .Content ,
425
408
invocation_context : InvocationContext ,
@@ -433,6 +416,7 @@ async def _append_new_message_to_session(
433
416
new_message: The new message to append.
434
417
invocation_context: The invocation context for the message.
435
418
save_input_blobs_as_artifacts: Whether to save input blobs as artifacts.
419
+ state_delta: Optional state changes to apply to the session.
436
420
"""
437
421
if not new_message .parts :
438
422
raise ValueError ('No parts in the new_message.' )
@@ -661,6 +645,44 @@ def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool:
661
645
agent = agent .parent_agent
662
646
return True
663
647
648
+ async def _setup_context_for_new_invocation (
649
+ self ,
650
+ * ,
651
+ session : Session ,
652
+ new_message : types .Content ,
653
+ run_config : RunConfig ,
654
+ state_delta : Optional [dict [str , Any ]],
655
+ ) -> InvocationContext :
656
+ """Sets up the context for a new invocation.
657
+
658
+ Args:
659
+ session: The session to setup the invocation context for.
660
+ new_message: The new message to process and append to the session.
661
+ run_config: The run config of the agent.
662
+ state_delta: Optional state changes to apply to the session.
663
+
664
+ Returns:
665
+ The invocation context for the new invocation.
666
+ """
667
+ # Step 1: Create invocation context in memory.
668
+ invocation_context = self ._new_invocation_context (
669
+ session ,
670
+ new_message = new_message ,
671
+ run_config = run_config ,
672
+ )
673
+ # Step 2: Handle new message, by running callbacks and appending to
674
+ # session.
675
+ await self ._handle_new_message (
676
+ session = session ,
677
+ new_message = new_message ,
678
+ invocation_context = invocation_context ,
679
+ run_config = run_config ,
680
+ state_delta = state_delta ,
681
+ )
682
+ # Step 3: Set agent to run for the invocation.
683
+ invocation_context .agent = self ._find_agent_to_run (session , self .agent )
684
+ return invocation_context
685
+
664
686
def _new_invocation_context (
665
687
self ,
666
688
session : Session ,
@@ -743,6 +765,42 @@ def _new_invocation_context_for_live(
743
765
run_config = run_config ,
744
766
)
745
767
768
+ async def _handle_new_message (
769
+ self ,
770
+ * ,
771
+ session : Session ,
772
+ new_message : types .Content ,
773
+ invocation_context : InvocationContext ,
774
+ run_config : RunConfig ,
775
+ state_delta : Optional [dict [str , Any ]],
776
+ ) -> None :
777
+ """Handles a new message by running callbacks and appending to session.
778
+
779
+ Args:
780
+ session: The session of the new message.
781
+ new_message: The new message to process and append to the session.
782
+ invocation_context: The invocation context to use for the message
783
+ handling.
784
+ run_config: The run config of the agent.
785
+ state_delta: Optional state changes to apply to the session.
786
+ """
787
+ modified_user_message = (
788
+ await invocation_context .plugin_manager .run_on_user_message_callback (
789
+ invocation_context = invocation_context , user_message = new_message
790
+ )
791
+ )
792
+ if modified_user_message is not None :
793
+ new_message = modified_user_message
794
+
795
+ if new_message :
796
+ await self ._append_new_message_to_session (
797
+ session = session ,
798
+ new_message = new_message ,
799
+ invocation_context = invocation_context ,
800
+ save_input_blobs_as_artifacts = run_config .save_input_blobs_as_artifacts ,
801
+ state_delta = state_delta ,
802
+ )
803
+
746
804
def _collect_toolset (self , agent : BaseAgent ) -> set [BaseToolset ]:
747
805
toolsets = set ()
748
806
if isinstance (agent , LlmAgent ):
0 commit comments