29
29
30
30
from .agents .active_streaming_tool import ActiveStreamingTool
31
31
from .agents .base_agent import BaseAgent
32
+ from .agents .base_agent import BaseAgentState
32
33
from .agents .context_cache_config import ContextCacheConfig
33
34
from .agents .invocation_context import InvocationContext
34
35
from .agents .invocation_context import new_invocation_context_id
@@ -272,7 +273,8 @@ async def run_async(
272
273
* ,
273
274
user_id : str ,
274
275
session_id : str ,
275
- new_message : types .Content ,
276
+ invocation_id : Optional [str ] = None ,
277
+ new_message : Optional [types .Content ] = None ,
276
278
state_delta : Optional [dict [str , Any ]] = None ,
277
279
run_config : Optional [RunConfig ] = None ,
278
280
) -> AsyncGenerator [Event , None ]:
@@ -281,6 +283,8 @@ async def run_async(
281
283
Args:
282
284
user_id: The user ID of the session.
283
285
session_id: The session ID of the session.
286
+ invocation_id: The invocation ID of the session, set this to resume an
287
+ interrupted invocation.
284
288
new_message: A new message to append to the session.
285
289
state_delta: Optional state changes to apply to the session.
286
290
run_config: The run config for the agent.
@@ -289,29 +293,57 @@ async def run_async(
289
293
The events generated by the agent.
290
294
291
295
Raises:
292
- ValueError: If the session is not found.
296
+ ValueError: If the session is not found; If both invocation_id and
297
+ new_message are None.
293
298
"""
294
299
run_config = run_config or RunConfig ()
295
300
296
- if not new_message .role :
301
+ if new_message and not new_message .role :
297
302
new_message .role = 'user'
298
303
299
304
async def _run_with_trace (
300
- new_message : types .Content ,
305
+ new_message : Optional [types .Content ] = None ,
306
+ invocation_id : Optional [str ] = None ,
301
307
) -> AsyncGenerator [Event , None ]:
302
308
with tracer .start_as_current_span ('invocation' ):
303
309
session = await self .session_service .get_session (
304
310
app_name = self .app_name , user_id = user_id , session_id = session_id
305
311
)
306
312
if not session :
307
313
raise ValueError (f'Session not found: { session_id } ' )
308
-
309
- invocation_context = await self ._setup_context_for_new_invocation (
310
- session = session ,
311
- new_message = new_message ,
312
- run_config = run_config ,
313
- state_delta = state_delta ,
314
- )
314
+ if not invocation_id and not new_message :
315
+ raise ValueError ('Both invocation_id and new_message are None.' )
316
+
317
+ if invocation_id :
318
+ if (
319
+ not self .resumability_config
320
+ or not self .resumability_config .is_resumable
321
+ ):
322
+ raise ValueError (
323
+ f'invocation_id: { invocation_id } is provided but the app is not'
324
+ ' resumable.'
325
+ )
326
+ invocation_context = await self ._setup_context_for_resumed_invocation (
327
+ session = session ,
328
+ new_message = new_message ,
329
+ invocation_id = invocation_id ,
330
+ run_config = run_config ,
331
+ state_delta = state_delta ,
332
+ )
333
+ if invocation_context .end_of_agents .get (self .agent .name ):
334
+ # Directly return if the root agent has already ended.
335
+ # TODO: Handle the case where the invocation-to-resume started from
336
+ # a sub_agent:
337
+ # invocation1: root_agent -> sub_agent1
338
+ # invocation2: sub_agent1 [paused][resume]
339
+ return
340
+ else :
341
+ invocation_context = await self ._setup_context_for_new_invocation (
342
+ session = session ,
343
+ new_message = new_message , # new_message is not None.
344
+ run_config = run_config ,
345
+ state_delta = state_delta ,
346
+ )
315
347
316
348
async def execute (ctx : InvocationContext ) -> AsyncGenerator [Event ]:
317
349
async with Aclosing (ctx .agent .run_async (ctx )) as agen :
@@ -329,7 +361,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
329
361
async for event in agen :
330
362
yield event
331
363
332
- async with Aclosing (_run_with_trace (new_message )) as agen :
364
+ async with Aclosing (_run_with_trace (new_message , invocation_id )) as agen :
333
365
async for event in agen :
334
366
yield event
335
367
@@ -462,6 +494,11 @@ async def _append_new_message_to_session(
462
494
author = 'user' ,
463
495
content = new_message ,
464
496
)
497
+ # If new_message is a function response, find the matching function call
498
+ # and use its branch as the new event's branch.
499
+ if function_call := invocation_context ._find_matching_function_call (event ):
500
+ event .branch = function_call .branch
501
+
465
502
await self .session_service .append_event (session = session , event = event )
466
503
467
504
async def run_live (
@@ -692,10 +729,82 @@ async def _setup_context_for_new_invocation(
692
729
invocation_context .agent = self ._find_agent_to_run (session , self .agent )
693
730
return invocation_context
694
731
732
+ async def _setup_context_for_resumed_invocation (
733
+ self ,
734
+ * ,
735
+ session : Session ,
736
+ new_message : Optional [types .Content ],
737
+ invocation_id : Optional [str ],
738
+ run_config : RunConfig ,
739
+ state_delta : Optional [dict [str , Any ]],
740
+ ) -> InvocationContext :
741
+ """Sets up the context for a resumed invocation.
742
+
743
+ Args:
744
+ session: The session to setup the invocation context for.
745
+ new_message: The new message to process and append to the session.
746
+ invocation_id: The invocation id to resume.
747
+ run_config: The run config of the agent.
748
+ state_delta: Optional state changes to apply to the session.
749
+
750
+ Returns:
751
+ The invocation context for the resumed invocation.
752
+
753
+ Raises:
754
+ ValueError: If the session has no events to resume; If no user message is
755
+ available for resuming the invocation; Or if the app is not resumable.
756
+ """
757
+ if not session .events :
758
+ raise ValueError (f'Session { session .id } has no events to resume.' )
759
+
760
+ # Step 1: Maybe retrive a previous user message for the invocation.
761
+ user_message = new_message or self ._find_user_message_for_invocation (
762
+ session .events , invocation_id
763
+ )
764
+ if not user_message :
765
+ raise ValueError (
766
+ f'No user message available for resuming invocation: { invocation_id } '
767
+ )
768
+ # Step 2: Create invocation context.
769
+ invocation_context = self ._new_invocation_context (
770
+ session ,
771
+ new_message = user_message ,
772
+ run_config = run_config ,
773
+ invocation_id = invocation_id ,
774
+ )
775
+ # Step 3: Maybe handle new message.
776
+ if new_message :
777
+ await self ._handle_new_message (
778
+ session = session ,
779
+ new_message = user_message ,
780
+ invocation_context = invocation_context ,
781
+ run_config = run_config ,
782
+ state_delta = state_delta ,
783
+ )
784
+ # Step 4: Populate agent states for the current invocation.
785
+ invocation_context .populate_invocation_agent_states ()
786
+ return invocation_context
787
+
788
+ def _find_user_message_for_invocation (
789
+ self , events : list [Event ], invocation_id : str
790
+ ) -> Optional [types .Content ]:
791
+ """Finds the user message that started a specific invocation."""
792
+ for event in events :
793
+ if (
794
+ event .invocation_id == invocation_id
795
+ and event .author == 'user'
796
+ and event .content
797
+ and event .content .parts
798
+ and event .content .parts [0 ].text
799
+ ):
800
+ return event .content
801
+ return None
802
+
695
803
def _new_invocation_context (
696
804
self ,
697
805
session : Session ,
698
806
* ,
807
+ invocation_id : Optional [str ] = None ,
699
808
new_message : Optional [types .Content ] = None ,
700
809
live_request_queue : Optional [LiveRequestQueue ] = None ,
701
810
run_config : Optional [RunConfig ] = None ,
@@ -704,6 +813,7 @@ def _new_invocation_context(
704
813
705
814
Args:
706
815
session: The session for the context.
816
+ invocation_id: The invocation id for the context.
707
817
new_message: The new message for the context.
708
818
live_request_queue: The live request queue for the context.
709
819
run_config: The run config for the context.
@@ -712,7 +822,7 @@ def _new_invocation_context(
712
822
The new invocation context.
713
823
"""
714
824
run_config = run_config or RunConfig ()
715
- invocation_id = new_invocation_context_id ()
825
+ invocation_id = invocation_id or new_invocation_context_id ()
716
826
717
827
if run_config .support_cfc and isinstance (self .agent , LlmAgent ):
718
828
model_name = self .agent .canonical_model .model
0 commit comments