Skip to content

Commit fbf7576

Browse files
XinranTangcopybara-github
authored andcommitted
feat: Modify runner to support resuming an invocation (optionally with a function response)
PiperOrigin-RevId: 813008406
1 parent f005414 commit fbf7576

File tree

5 files changed

+474
-18
lines changed

5 files changed

+474
-18
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ..sessions.session import Session
3535
from .active_streaming_tool import ActiveStreamingTool
3636
from .base_agent import BaseAgent
37+
from .base_agent import BaseAgentState
3738
from .context_cache_config import ContextCacheConfig
3839
from .live_request_queue import LiveRequestQueue
3940
from .run_config import RunConfig
@@ -221,6 +222,37 @@ def reset_agent_state(self, agent_name: str) -> None:
221222
self.agent_states.pop(agent_name, None)
222223
self.end_of_agents.pop(agent_name, None)
223224

225+
def populate_invocation_agent_states(self) -> None:
226+
"""Populates agent states for the current invocation if it is resumable.
227+
228+
For history events that contain agent state information, set the
229+
agent_state and end_of_agent of the agent that generated the event.
230+
231+
For non-workflow agents, also set an initial agent_state if it has
232+
already generated some contents.
233+
"""
234+
if not self.is_resumable:
235+
return
236+
for event in self._get_events(current_invocation=True):
237+
if event.actions.end_of_agent:
238+
self.end_of_agents[event.author] = True
239+
# Delete agent_state when it is end
240+
self.agent_states.pop(event.author, None)
241+
elif event.actions.agent_state is not None:
242+
self.agent_states[event.author] = event.actions.agent_state
243+
# Invalidate the end_of_agent flag
244+
self.end_of_agents[event.author] = False
245+
elif (
246+
event.author != "user"
247+
and event.content
248+
and not self.agent_states.get(event.author)
249+
):
250+
# If the agent has generated some contents but its agent_state is not
251+
# set, set its agent_state to an empty agent_state.
252+
self.agent_states[event.author] = BaseAgentState()
253+
# Invalidate the end_of_agent flag
254+
self.end_of_agents[event.author] = False
255+
224256
def increment_llm_call_count(
225257
self,
226258
):

src/google/adk/runners.py

Lines changed: 123 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from .agents.active_streaming_tool import ActiveStreamingTool
3131
from .agents.base_agent import BaseAgent
32+
from .agents.base_agent import BaseAgentState
3233
from .agents.context_cache_config import ContextCacheConfig
3334
from .agents.invocation_context import InvocationContext
3435
from .agents.invocation_context import new_invocation_context_id
@@ -272,7 +273,8 @@ async def run_async(
272273
*,
273274
user_id: str,
274275
session_id: str,
275-
new_message: types.Content,
276+
invocation_id: Optional[str] = None,
277+
new_message: Optional[types.Content] = None,
276278
state_delta: Optional[dict[str, Any]] = None,
277279
run_config: Optional[RunConfig] = None,
278280
) -> AsyncGenerator[Event, None]:
@@ -281,6 +283,8 @@ async def run_async(
281283
Args:
282284
user_id: The user ID of the session.
283285
session_id: The session ID of the session.
286+
invocation_id: The invocation ID of the session, set this to resume an
287+
interrupted invocation.
284288
new_message: A new message to append to the session.
285289
state_delta: Optional state changes to apply to the session.
286290
run_config: The run config for the agent.
@@ -289,29 +293,57 @@ async def run_async(
289293
The events generated by the agent.
290294
291295
Raises:
292-
ValueError: If the session is not found.
296+
ValueError: If the session is not found; If both invocation_id and
297+
new_message are None.
293298
"""
294299
run_config = run_config or RunConfig()
295300

296-
if not new_message.role:
301+
if new_message and not new_message.role:
297302
new_message.role = 'user'
298303

299304
async def _run_with_trace(
300-
new_message: types.Content,
305+
new_message: Optional[types.Content] = None,
306+
invocation_id: Optional[str] = None,
301307
) -> AsyncGenerator[Event, None]:
302308
with tracer.start_as_current_span('invocation'):
303309
session = await self.session_service.get_session(
304310
app_name=self.app_name, user_id=user_id, session_id=session_id
305311
)
306312
if not session:
307313
raise ValueError(f'Session not found: {session_id}')
308-
309-
invocation_context = await self._setup_context_for_new_invocation(
310-
session=session,
311-
new_message=new_message,
312-
run_config=run_config,
313-
state_delta=state_delta,
314-
)
314+
if not invocation_id and not new_message:
315+
raise ValueError('Both invocation_id and new_message are None.')
316+
317+
if invocation_id:
318+
if (
319+
not self.resumability_config
320+
or not self.resumability_config.is_resumable
321+
):
322+
raise ValueError(
323+
f'invocation_id: {invocation_id} is provided but the app is not'
324+
' resumable.'
325+
)
326+
invocation_context = await self._setup_context_for_resumed_invocation(
327+
session=session,
328+
new_message=new_message,
329+
invocation_id=invocation_id,
330+
run_config=run_config,
331+
state_delta=state_delta,
332+
)
333+
if invocation_context.end_of_agents.get(self.agent.name):
334+
# Directly return if the root agent has already ended.
335+
# TODO: Handle the case where the invocation-to-resume started from
336+
# a sub_agent:
337+
# invocation1: root_agent -> sub_agent1
338+
# invocation2: sub_agent1 [paused][resume]
339+
return
340+
else:
341+
invocation_context = await self._setup_context_for_new_invocation(
342+
session=session,
343+
new_message=new_message, # new_message is not None.
344+
run_config=run_config,
345+
state_delta=state_delta,
346+
)
315347

316348
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
317349
async with Aclosing(ctx.agent.run_async(ctx)) as agen:
@@ -329,7 +361,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
329361
async for event in agen:
330362
yield event
331363

332-
async with Aclosing(_run_with_trace(new_message)) as agen:
364+
async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
333365
async for event in agen:
334366
yield event
335367

@@ -462,6 +494,11 @@ async def _append_new_message_to_session(
462494
author='user',
463495
content=new_message,
464496
)
497+
# If new_message is a function response, find the matching function call
498+
# and use its branch as the new event's branch.
499+
if function_call := invocation_context._find_matching_function_call(event):
500+
event.branch = function_call.branch
501+
465502
await self.session_service.append_event(session=session, event=event)
466503

467504
async def run_live(
@@ -692,10 +729,82 @@ async def _setup_context_for_new_invocation(
692729
invocation_context.agent = self._find_agent_to_run(session, self.agent)
693730
return invocation_context
694731

732+
async def _setup_context_for_resumed_invocation(
733+
self,
734+
*,
735+
session: Session,
736+
new_message: Optional[types.Content],
737+
invocation_id: Optional[str],
738+
run_config: RunConfig,
739+
state_delta: Optional[dict[str, Any]],
740+
) -> InvocationContext:
741+
"""Sets up the context for a resumed invocation.
742+
743+
Args:
744+
session: The session to setup the invocation context for.
745+
new_message: The new message to process and append to the session.
746+
invocation_id: The invocation id to resume.
747+
run_config: The run config of the agent.
748+
state_delta: Optional state changes to apply to the session.
749+
750+
Returns:
751+
The invocation context for the resumed invocation.
752+
753+
Raises:
754+
ValueError: If the session has no events to resume; If no user message is
755+
available for resuming the invocation; Or if the app is not resumable.
756+
"""
757+
if not session.events:
758+
raise ValueError(f'Session {session.id} has no events to resume.')
759+
760+
# Step 1: Maybe retrive a previous user message for the invocation.
761+
user_message = new_message or self._find_user_message_for_invocation(
762+
session.events, invocation_id
763+
)
764+
if not user_message:
765+
raise ValueError(
766+
f'No user message available for resuming invocation: {invocation_id}'
767+
)
768+
# Step 2: Create invocation context.
769+
invocation_context = self._new_invocation_context(
770+
session,
771+
new_message=user_message,
772+
run_config=run_config,
773+
invocation_id=invocation_id,
774+
)
775+
# Step 3: Maybe handle new message.
776+
if new_message:
777+
await self._handle_new_message(
778+
session=session,
779+
new_message=user_message,
780+
invocation_context=invocation_context,
781+
run_config=run_config,
782+
state_delta=state_delta,
783+
)
784+
# Step 4: Populate agent states for the current invocation.
785+
invocation_context.populate_invocation_agent_states()
786+
return invocation_context
787+
788+
def _find_user_message_for_invocation(
789+
self, events: list[Event], invocation_id: str
790+
) -> Optional[types.Content]:
791+
"""Finds the user message that started a specific invocation."""
792+
for event in events:
793+
if (
794+
event.invocation_id == invocation_id
795+
and event.author == 'user'
796+
and event.content
797+
and event.content.parts
798+
and event.content.parts[0].text
799+
):
800+
return event.content
801+
return None
802+
695803
def _new_invocation_context(
696804
self,
697805
session: Session,
698806
*,
807+
invocation_id: Optional[str] = None,
699808
new_message: Optional[types.Content] = None,
700809
live_request_queue: Optional[LiveRequestQueue] = None,
701810
run_config: Optional[RunConfig] = None,
@@ -704,6 +813,7 @@ def _new_invocation_context(
704813
705814
Args:
706815
session: The session for the context.
816+
invocation_id: The invocation id for the context.
707817
new_message: The new message for the context.
708818
live_request_queue: The live request queue for the context.
709819
run_config: The run config for the context.
@@ -712,7 +822,7 @@ def _new_invocation_context(
712822
The new invocation context.
713823
"""
714824
run_config = run_config or RunConfig()
715-
invocation_id = new_invocation_context_id()
825+
invocation_id = invocation_id or new_invocation_context_id()
716826

717827
if run_config.support_cfc and isinstance(self.agent, LlmAgent):
718828
model_name = self.agent.canonical_model.model

tests/unittests/agents/test_invocation_context.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from unittest.mock import Mock
1616

1717
from google.adk.agents.base_agent import BaseAgent
18+
from google.adk.agents.base_agent import BaseAgentState
1819
from google.adk.agents.invocation_context import InvocationContext
1920
from google.adk.apps import ResumabilityConfig
2021
from google.adk.events.event import Event
22+
from google.adk.events.event_actions import EventActions
2123
from google.adk.sessions.base_session_service import BaseSessionService
2224
from google.adk.sessions.session import Session
2325
from google.genai.types import Content
@@ -227,6 +229,124 @@ def test_is_resumable_no_config(self):
227229
invocation_context = self._create_test_invocation_context(None)
228230
assert not invocation_context.is_resumable
229231

232+
def test_populate_invocation_agent_states_not_resumable(self):
233+
"""Tests that populate_invocation_agent_states does nothing if not resumable."""
234+
invocation_context = self._create_test_invocation_context(
235+
ResumabilityConfig(is_resumable=False)
236+
)
237+
event = Event(
238+
invocation_id='inv_1',
239+
author='agent1',
240+
actions=EventActions(end_of_agent=True, agent_state=None),
241+
)
242+
invocation_context.session.events = [event]
243+
invocation_context.populate_invocation_agent_states()
244+
assert not invocation_context.agent_states
245+
assert not invocation_context.end_of_agents
246+
247+
def test_populate_invocation_agent_states_end_of_agent(self):
248+
"""Tests that populate_invocation_agent_states handles end_of_agent."""
249+
invocation_context = self._create_test_invocation_context(
250+
ResumabilityConfig(is_resumable=True)
251+
)
252+
event = Event(
253+
invocation_id='inv_1',
254+
author='agent1',
255+
actions=EventActions(end_of_agent=True, agent_state=None),
256+
)
257+
invocation_context.session.events = [event]
258+
invocation_context.populate_invocation_agent_states()
259+
assert not invocation_context.agent_states
260+
assert invocation_context.end_of_agents == {'agent1': True}
261+
262+
def test_populate_invocation_agent_states_with_agent_state(self):
263+
"""Tests that populate_invocation_agent_states handles agent_state."""
264+
invocation_context = self._create_test_invocation_context(
265+
ResumabilityConfig(is_resumable=True)
266+
)
267+
event = Event(
268+
invocation_id='inv_1',
269+
author='agent1',
270+
actions=EventActions(
271+
end_of_agent=False,
272+
agent_state=BaseAgentState().model_dump(mode='json'),
273+
),
274+
)
275+
invocation_context.session.events = [event]
276+
invocation_context.populate_invocation_agent_states()
277+
assert invocation_context.agent_states == {'agent1': {}}
278+
assert invocation_context.end_of_agents == {'agent1': False}
279+
280+
def test_populate_invocation_agent_states_with_agent_state_and_end_of_agent(
281+
self,
282+
):
283+
"""Tests that populate_invocation_agent_states handles agent_state and end_of_agent."""
284+
invocation_context = self._create_test_invocation_context(
285+
ResumabilityConfig(is_resumable=True)
286+
)
287+
event = Event(
288+
invocation_id='inv_1',
289+
author='agent1',
290+
actions=EventActions(
291+
end_of_agent=True,
292+
agent_state=BaseAgentState().model_dump(mode='json'),
293+
),
294+
)
295+
invocation_context.session.events = [event]
296+
invocation_context.populate_invocation_agent_states()
297+
# When both agent_state and end_of_agent are set, agent_state should be
298+
# cleared, as end_of_agent is of a higher priority.
299+
assert not invocation_context.agent_states
300+
assert invocation_context.end_of_agents == {'agent1': True}
301+
302+
def test_populate_invocation_agent_states_with_content_no_state(self):
303+
"""Tests that populate_invocation_agent_states creates default state."""
304+
invocation_context = self._create_test_invocation_context(
305+
ResumabilityConfig(is_resumable=True)
306+
)
307+
event = Event(
308+
invocation_id='inv_1',
309+
author='agent1',
310+
actions=EventActions(end_of_agent=False, agent_state=None),
311+
content=Content(role='model', parts=[Part(text='hi')]),
312+
)
313+
invocation_context.session.events = [event]
314+
invocation_context.populate_invocation_agent_states()
315+
assert invocation_context.agent_states == {'agent1': BaseAgentState()}
316+
assert invocation_context.end_of_agents == {'agent1': False}
317+
318+
def test_populate_invocation_agent_states_user_message_event(self):
319+
"""Tests that populate_invocation_agent_states ignores user message events for default state."""
320+
invocation_context = self._create_test_invocation_context(
321+
ResumabilityConfig(is_resumable=True)
322+
)
323+
event = Event(
324+
invocation_id='inv_1',
325+
author='user',
326+
actions=EventActions(end_of_agent=False, agent_state=None),
327+
content=Content(role='user', parts=[Part(text='hi')]),
328+
)
329+
invocation_context.session.events = [event]
330+
invocation_context.populate_invocation_agent_states()
331+
assert not invocation_context.agent_states
332+
assert not invocation_context.end_of_agents
333+
334+
def test_populate_invocation_agent_states_no_content(self):
335+
"""Tests that populate_invocation_agent_states ignores events with no content if no state."""
336+
invocation_context = self._create_test_invocation_context(
337+
ResumabilityConfig(is_resumable=True)
338+
)
339+
event = Event(
340+
invocation_id='inv_1',
341+
author='agent1',
342+
actions=EventActions(end_of_agent=None, agent_state=None),
343+
content=None,
344+
)
345+
invocation_context.session.events = [event]
346+
invocation_context.populate_invocation_agent_states()
347+
assert not invocation_context.agent_states
348+
assert not invocation_context.end_of_agents
349+
230350

231351
class TestFindMatchingFunctionCall:
232352
"""Test suite for find_matching_function_call."""

0 commit comments

Comments
 (0)