Skip to content

Commit f005414

Browse files
XinranTang authored and copybara-github committed
feat: Make resumable llm agents yield checkpoint events
PiperOrigin-RevId: 813001108
1 parent 609a235 commit f005414

File tree

10 files changed

+1156
-229
lines changed

10 files changed

+1156
-229
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,10 @@ def app_name(self) -> str:
242242
def user_id(self) -> str:
243243
return self.session.user_id
244244

245-
def get_events(
245+
# TODO: Move this method from invocation_context to a dedicated module.
246+
def _get_events(
246247
self,
248+
*,
247249
current_invocation: bool = False,
248250
current_branch: bool = False,
249251
) -> list[Event]:
@@ -304,6 +306,25 @@ def should_pause_invocation(self, event: Event) -> bool:
304306

305307
return False
306308

309+
# TODO: Move this method from invocation_context to a dedicated module.
310+
# TODO: Converge this method with find_matching_function_call in llm_flows.
311+
def _find_matching_function_call(
312+
self, function_response_event: Event
313+
) -> Optional[Event]:
314+
"""Finds the function call event in the current invocation that matches the function response id."""
315+
function_responses = function_response_event.get_function_responses()
316+
if not function_responses:
317+
return None
318+
function_call_id = function_responses[0].id
319+
320+
events = self._get_events(current_invocation=True)
321+
# The last event is function_response_event, so we search backwards from the
322+
# one before it.
323+
for event in reversed(events[:-1]):
324+
if any(fc.id == function_call_id for fc in event.get_function_calls()):
325+
return event
326+
return None
327+
307328

308329
def new_invocation_context_id() -> str:
309330
return "e-" + str(uuid.uuid4())

src/google/adk/agents/llm_agent.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from ..utils.context_utils import Aclosing
5555
from ..utils.feature_decorator import experimental
5656
from .base_agent import BaseAgent
57+
from .base_agent import BaseAgentState
5758
from .base_agent_config import BaseAgentConfig
5859
from .callback_context import CallbackContext
5960
from .invocation_context import InvocationContext
@@ -337,13 +338,30 @@ class LlmAgent(BaseAgent):
337338
async def _run_async_impl(
338339
self, ctx: InvocationContext
339340
) -> AsyncGenerator[Event, None]:
341+
agent_state = self._load_agent_state(ctx, BaseAgentState)
342+
343+
# If there is a sub-agent to resume, run it and then end the current
344+
# agent.
345+
if agent_state is not None and (
346+
agent_to_transfer := self._get_subagent_to_resume(ctx)
347+
):
348+
async with Aclosing(agent_to_transfer.run_async(ctx)) as agen:
349+
async for event in agen:
350+
yield event
351+
352+
yield self._create_agent_state_event(ctx, end_of_agent=True)
353+
return
354+
340355
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
341356
async for event in agen:
342357
self.__maybe_save_output_to_state(event)
343358
yield event
344359
if ctx.should_pause_invocation(event):
345360
return
346361

362+
if ctx.is_resumable:
363+
yield self._create_agent_state_event(ctx, end_of_agent=True)
364+
347365
@override
348366
async def _run_live_impl(
349367
self, ctx: InvocationContext
@@ -498,6 +516,74 @@ def _llm_flow(self) -> BaseLlmFlow:
498516
else:
499517
return AutoFlow()
500518

519+
def _get_subagent_to_resume(
520+
self, ctx: InvocationContext
521+
) -> Optional[BaseAgent]:
522+
"""Returns the sub-agent in the llm tree to resume if it exists.
523+
524+
There are 2 cases where we need to transfer to and resume a sub-agent:
525+
1. The last event is a transfer to agent response from the current agent.
526+
In this case, we need to return the agent specified in the response.
527+
528+
2. The last event's author isn't the current agent, or the user is
529+
responding to another agent's tool call.
530+
In this case, we need to return the LAST agent being transferred to
531+
from the current agent.
532+
"""
533+
events = ctx._get_events(current_invocation=True, current_branch=True)
534+
if not events:
535+
return None
536+
537+
last_event = events[-1]
538+
if last_event.author == self.name:
539+
# Last event is from current agent. Return transfer_to_agent in the event
540+
# if it exists, or None.
541+
return self.__get_transfer_to_agent_or_none(last_event, self.name)
542+
543+
# Last event is from user or another agent.
544+
if last_event.author == 'user':
545+
function_call_event = ctx._find_matching_function_call(last_event)
546+
if not function_call_event:
547+
raise ValueError(
548+
'No agent to transfer to for resuming agent from function response'
549+
f' {self.name}'
550+
)
551+
if function_call_event.author == self.name:
552+
# User is responding to a tool call from the current agent.
553+
# Current agent should continue, so no sub-agent to resume.
554+
return None
555+
556+
# Last event is from another agent, or from user for another agent's tool
557+
# call. We need to find the last agent we transferred to.
558+
for event in reversed(events):
559+
if agent := self.__get_transfer_to_agent_or_none(event, self.name):
560+
return agent
561+
562+
return None
563+
564+
def __get_agent_to_run(self, agent_name: str) -> BaseAgent:
565+
"""Find the agent to run under the root agent by name."""
566+
agent_to_run = self.root_agent.find_agent(agent_name)
567+
if not agent_to_run:
568+
raise ValueError(f'Agent {agent_name} not found in the agent tree.')
569+
return agent_to_run
570+
571+
def __get_transfer_to_agent_or_none(
572+
self, event: Event, from_agent: str
573+
) -> Optional[BaseAgent]:
574+
"""Returns the agent to run if the event is a transfer to agent response."""
575+
function_responses = event.get_function_responses()
576+
if not function_responses:
577+
return None
578+
for function_response in function_responses:
579+
if (
580+
function_response.name == 'transfer_to_agent'
581+
and event.author == from_agent
582+
and event.actions.transfer_to_agent != from_agent
583+
):
584+
return self.__get_agent_to_run(event.actions.transfer_to_agent)
585+
return None
586+
501587
def __maybe_save_output_to_state(self, event: Event):
502588
"""Saves the model output to state if needed."""
503589
# skip if the event was authored by some other agent (e.g. current agent

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,28 @@ async def _run_one_step_async(
376376
if invocation_context.end_invocation:
377377
return
378378

379+
# Resume the LLM agent based on the last event from the current branch.
380+
# 1. User content: continue the normal flow
381+
# 2. Function call: call the tool and get the response event.
382+
events = invocation_context._get_events(
383+
current_invocation=True, current_branch=True
384+
)
385+
if (
386+
invocation_context.is_resumable
387+
and events
388+
and events[-1].get_function_calls()
389+
):
390+
model_response_event = events[-1]
391+
async with Aclosing(
392+
self._postprocess_handle_function_calls_async(
393+
invocation_context, model_response_event, llm_request
394+
)
395+
) as agen:
396+
async for event in agen:
397+
event.id = Event.new_id()
398+
yield event
399+
return
400+
379401
# Calls the LLM.
380402
model_response_event = Event(
381403
id=Event.new_id(),

src/google/adk/flows/llm_flows/contents.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def _rearrange_events_for_latest_function_response(
135135
Returns:
136136
A list of events with the latest function_response rearranged.
137137
"""
138-
if not events:
138+
if len(events) < 2:
139+
# No need to process, since there is no function_call.
139140
return events
140141

141142
function_responses = events[-1].get_function_responses()

src/google/adk/runners.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,16 @@ def _find_agent_to_run(
606606
event = find_matching_function_call(session.events)
607607
if event and event.author:
608608
return root_agent.find_agent(event.author)
609-
for event in filter(lambda e: e.author != 'user', reversed(session.events)):
609+
610+
def _event_filter(event: Event) -> bool:
611+
"""Filters out user-authored events and agent state change events."""
612+
if event.author == 'user':
613+
return False
614+
if event.actions.agent_state is not None or event.actions.end_of_agent:
615+
return False
616+
return True
617+
618+
for event in filter(_event_filter, reversed(session.events)):
610619
if event.author == root_agent.name:
611620
# Found root agent.
612621
return root_agent

tests/unittests/agents/test_invocation_context.py

Lines changed: 120 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.adk.events.event import Event
2121
from google.adk.sessions.base_session_service import BaseSessionService
2222
from google.adk.sessions.session import Session
23+
from google.genai.types import Content
2324
from google.genai.types import FunctionCall
2425
from google.genai.types import Part
2526
import pytest
@@ -67,31 +68,31 @@ def test_get_events_returns_all_events_by_default(
6768
self, mock_invocation_context, mock_events
6869
):
6970
"""Tests that get_events returns all events when no filters are applied."""
70-
events = mock_invocation_context.get_events()
71+
events = mock_invocation_context._get_events()
7172
assert events == mock_events
7273

7374
def test_get_events_filters_by_current_invocation(
7475
self, mock_invocation_context, mock_events
7576
):
7677
"""Tests that get_events correctly filters by the current invocation."""
7778
event1, event2, _, _ = mock_events
78-
events = mock_invocation_context.get_events(current_invocation=True)
79+
events = mock_invocation_context._get_events(current_invocation=True)
7980
assert events == [event1, event2]
8081

8182
def test_get_events_filters_by_current_branch(
8283
self, mock_invocation_context, mock_events
8384
):
8485
"""Tests that get_events correctly filters by the current branch."""
8586
event1, _, event3, _ = mock_events
86-
events = mock_invocation_context.get_events(current_branch=True)
87+
events = mock_invocation_context._get_events(current_branch=True)
8788
assert events == [event1, event3]
8889

8990
def test_get_events_filters_by_invocation_and_branch(
9091
self, mock_invocation_context, mock_events
9192
):
9293
"""Tests that get_events filters by invocation and branch."""
9394
event1, _, _, _ = mock_events
94-
events = mock_invocation_context.get_events(
95+
events = mock_invocation_context._get_events(
9596
current_invocation=True,
9697
current_branch=True,
9798
)
@@ -100,7 +101,7 @@ def test_get_events_filters_by_invocation_and_branch(
100101
def test_get_events_with_no_events_in_session(self, mock_invocation_context):
101102
"""Tests get_events when the session has no events."""
102103
mock_invocation_context.session.events = []
103-
events = mock_invocation_context.get_events()
104+
events = mock_invocation_context._get_events()
104105
assert not events
105106

106107
def test_get_events_with_no_matching_events(self, mock_invocation_context):
@@ -109,15 +110,15 @@ def test_get_events_with_no_matching_events(self, mock_invocation_context):
109110
mock_invocation_context.branch = 'branch_C'
110111

111112
# Filter by invocation
112-
events = mock_invocation_context.get_events(current_invocation=True)
113+
events = mock_invocation_context._get_events(current_invocation=True)
113114
assert not events
114115

115116
# Filter by branch
116-
events = mock_invocation_context.get_events(current_branch=True)
117+
events = mock_invocation_context._get_events(current_branch=True)
117118
assert not events
118119

119120
# Filter by both
120-
events = mock_invocation_context.get_events(
121+
events = mock_invocation_context._get_events(
121122
current_invocation=True,
122123
current_branch=True,
123124
)
@@ -225,3 +226,114 @@ def test_is_resumable_no_config(self):
225226
"""Tests that is_resumable is False when no resumability config is set."""
226227
invocation_context = self._create_test_invocation_context(None)
227228
assert not invocation_context.is_resumable
229+
230+
231+
class TestFindMatchingFunctionCall:
232+
"""Test suite for find_matching_function_call."""
233+
234+
@pytest.fixture
235+
def test_invocation_context(self):
236+
"""Create a mock invocation context for testing."""
237+
238+
def _create_invocation_context(events):
239+
return InvocationContext(
240+
session_service=Mock(spec=BaseSessionService),
241+
agent=Mock(spec=BaseAgent, name='agent'),
242+
invocation_id='inv_1',
243+
session=Mock(spec=Session, events=events),
244+
)
245+
246+
return _create_invocation_context
247+
248+
def test_find_matching_function_call_found(self, test_invocation_context):
249+
"""Tests that a matching function call is found."""
250+
fc = Part.from_function_call(name='some_tool', args={})
251+
fc.function_call.id = 'test_function_call_id'
252+
fc_event = Event(
253+
invocation_id='inv_1',
254+
author='agent',
255+
content=testing_utils.ModelContent([fc]),
256+
)
257+
fr = Part.from_function_response(
258+
name='some_tool', response={'result': 'ok'}
259+
)
260+
fr.function_response.id = 'test_function_call_id'
261+
fr_event = Event(
262+
invocation_id='inv_1',
263+
author='agent',
264+
content=Content(role='user', parts=[fr]),
265+
)
266+
invocation_context = test_invocation_context([fc_event, fr_event])
267+
matching_fc_event = invocation_context._find_matching_function_call(
268+
fr_event
269+
)
270+
assert testing_utils.simplify_content(
271+
matching_fc_event.content
272+
) == testing_utils.simplify_content(fc_event.content)
273+
274+
def test_find_matching_function_call_not_found(self, test_invocation_context):
275+
"""Tests that no matching function call is returned if id doesn't match."""
276+
fc = Part.from_function_call(name='some_tool', args={})
277+
fc.function_call.id = 'another_function_call_id'
278+
fc_event = Event(
279+
invocation_id='inv_1',
280+
author='agent',
281+
content=testing_utils.ModelContent([fc]),
282+
)
283+
fr = Part.from_function_response(
284+
name='some_tool', response={'result': 'ok'}
285+
)
286+
fr.function_response.id = 'test_function_call_id'
287+
fr_event = Event(
288+
invocation_id='inv_1',
289+
author='agent',
290+
content=Content(role='user', parts=[fr]),
291+
)
292+
invocation_context = test_invocation_context([fc_event, fr_event])
293+
match = invocation_context._find_matching_function_call(fr_event)
294+
assert match is None
295+
296+
def test_find_matching_function_call_no_call_events(
297+
self, test_invocation_context
298+
):
299+
"""Tests that no matching function call is returned if there are no call events."""
300+
fr = Part.from_function_response(
301+
name='some_tool', response={'result': 'ok'}
302+
)
303+
fr.function_response.id = 'test_function_call_id'
304+
fr_event = Event(
305+
invocation_id='inv_1',
306+
author='agent',
307+
content=Content(role='user', parts=[fr]),
308+
)
309+
invocation_context = test_invocation_context([fr_event])
310+
match = invocation_context._find_matching_function_call(fr_event)
311+
assert match is None
312+
313+
def test_find_matching_function_call_no_response_in_event(
314+
self, test_invocation_context
315+
):
316+
"""Tests result is None if function_response_event has no function response."""
317+
fr_event_no_fr = Event(
318+
author='agent',
319+
content=Content(role='user', parts=[Part(text='user message')]),
320+
)
321+
fc = Part.from_function_call(name='some_tool', args={})
322+
fc.function_call.id = 'test_function_call_id'
323+
fc_event = Event(
324+
invocation_id='inv_1',
325+
author='agent',
326+
content=testing_utils.ModelContent([fc]),
327+
)
328+
fr = Part.from_function_response(
329+
name='some_tool', response={'result': 'ok'}
330+
)
331+
fr.function_response.id = 'test_function_call_id'
332+
fr_event = Event(
333+
invocation_id='inv_1',
334+
author='agent',
335+
content=Content(role='user', parts=[Part(text='user message')]),
336+
)
337+
invocation_context = test_invocation_context([fc_event, fr_event])
338+
match = invocation_context._find_matching_function_call(fr_event_no_fr)
339+
assert match is None

0 commit comments

Comments
 (0)