Skip to content

Commit b2b80e7

Browse files
XinranTangcopybara-github
authored andcommitted
feat: Pause invocations on long running function calls for resumable apps
PiperOrigin-RevId: 811518771
1 parent dd1ffad commit b2b80e7

File tree

8 files changed

+719
-3
lines changed

8 files changed

+719
-3
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,45 @@ def get_events(
260260
results = [event for event in results if event.branch == self.branch]
261261
return results
262262

263+
def should_pause_invocation(self, event: Event) -> bool:
264+
"""Returns whether to pause the invocation right after this event.
265+
266+
"Pausing" an invocation is different from "ending" an invocation. A paused
267+
invocation can be resumed later, while an ended invocation cannot.
268+
269+
Pausing the current agent's run will also pause all the agents that
270+
depend on its execution, i.e. the subsequent agents in a workflow, and the
271+
current agent's ancestors, etc.
272+
273+
Note that parallel sibling agents won't be affected, but their common
274+
ancestors will be paused after all the non-blocking sub-agents finished
275+
running.
276+
277+
Should meet all following conditions to pause an invocation:
278+
1. The app is resumable.
279+
2. The current event has a long running function call.
280+
281+
Args:
282+
event: The current event.
283+
284+
Returns:
285+
Whether to pause the invocation right after this event.
286+
"""
287+
if (
288+
not self.resumability_config
289+
or not self.resumability_config.is_resumable
290+
):
291+
return False
292+
293+
if not event.long_running_tool_ids or not event.get_function_calls():
294+
return False
295+
296+
for fc in event.get_function_calls():
297+
if fc.id in event.long_running_tool_ids:
298+
return True
299+
300+
return False
301+
263302

264303
def new_invocation_context_id() -> str:
265304
return "e-" + str(uuid.uuid4())

src/google/adk/agents/llm_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ async def _run_async_impl(
341341
async for event in agen:
342342
self.__maybe_save_output_to_state(event)
343343
yield event
344+
if ctx.should_pause_invocation(event):
345+
return
344346

345347
@override
346348
async def _run_live_impl(

src/google/adk/agents/loop_agent.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,26 @@ async def _run_async_impl(
7070
while not self.max_iterations or times_looped < self.max_iterations:
7171
for sub_agent in self.sub_agents:
7272
should_exit = False
73+
pause_invocation = False
74+
7375
async with Aclosing(sub_agent.run_async(ctx)) as agen:
7476
async for event in agen:
7577
yield event
7678
if event.actions.escalate:
7779
should_exit = True
80+
if ctx.should_pause_invocation(event):
81+
pause_invocation = True
7882

83+
# Indicates that the loop agent should exist after running this
84+
# sub-agent.
7985
if should_exit:
8086
return
8187

88+
# Indicates that the invocation should be paused after running this
89+
# sub-agent.
90+
if pause_invocation:
91+
return
92+
8293
times_looped += 1
8394
return
8495

src/google/adk/agents/parallel_agent.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,26 @@ async def _run_async_impl(
181181
)
182182
for sub_agent in self.sub_agents
183183
]
184+
185+
pause_invocation = False
184186
try:
185187
# TODO remove if once Python <3.11 is no longer supported.
186188
if sys.version_info >= (3, 11):
187189
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
188190
async for event in agen:
189191
yield event
192+
if ctx.should_pause_invocation(event):
193+
pause_invocation = True
190194
else:
191195
async with Aclosing(_merge_agent_run_pre_3_11(agent_runs)) as agen:
192196
async for event in agen:
193197
yield event
198+
if ctx.should_pause_invocation(event):
199+
pause_invocation = True
200+
201+
if pause_invocation:
202+
return
203+
194204
finally:
195205
for sub_agent_run in agent_runs:
196206
await sub_agent_run.aclose()

src/google/adk/agents/sequential_agent.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,18 @@ async def _run_async_impl(
5252
self, ctx: InvocationContext
5353
) -> AsyncGenerator[Event, None]:
5454
for sub_agent in self.sub_agents:
55+
pause_invocation = False
56+
5557
async with Aclosing(sub_agent.run_async(ctx)) as agen:
5658
async for event in agen:
5759
yield event
60+
if ctx.should_pause_invocation(event):
61+
pause_invocation = True
62+
63+
# Indicates the invocation should pause when receiving signal from
64+
# the current sub_agent.
65+
if pause_invocation:
66+
return
5867

5968
@override
6069
async def _run_live_impl(

tests/unittests/agents/test_invocation_context.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@
1616

1717
from google.adk.agents.base_agent import BaseAgent
1818
from google.adk.agents.invocation_context import InvocationContext
19+
from google.adk.apps import ResumabilityConfig
1920
from google.adk.events.event import Event
2021
from google.adk.sessions.base_session_service import BaseSessionService
2122
from google.adk.sessions.session import Session
23+
from google.genai.types import FunctionCall
24+
from google.genai.types import Part
2225
import pytest
2326

27+
from .. import testing_utils
28+
2429

2530
class TestInvocationContext:
2631
"""Test suite for InvocationContext."""
@@ -117,3 +122,87 @@ def test_get_events_with_no_matching_events(self, mock_invocation_context):
117122
current_branch=True,
118123
)
119124
assert not events
125+
126+
127+
class TestInvocationContextWithAppResumablity:
128+
"""Test suite for InvocationContext regarding app resumability."""
129+
130+
@pytest.fixture
131+
def long_running_function_call(self) -> FunctionCall:
132+
"""A long running function call."""
133+
return FunctionCall(
134+
id='tool_call_id_1',
135+
name='long_running_function_call',
136+
args={},
137+
)
138+
139+
@pytest.fixture
140+
def event_to_pause(self, long_running_function_call) -> Event:
141+
"""An event with a long running function call."""
142+
return Event(
143+
invocation_id='inv_1',
144+
author='agent',
145+
content=testing_utils.ModelContent(
146+
[Part(function_call=long_running_function_call)]
147+
),
148+
long_running_tool_ids=[long_running_function_call.id],
149+
)
150+
151+
def _create_test_invocation_context(
152+
self, resumability_config
153+
) -> InvocationContext:
154+
"""Create a mock invocation context for testing."""
155+
ctx = InvocationContext(
156+
session_service=Mock(spec=BaseSessionService),
157+
agent=Mock(spec=BaseAgent),
158+
invocation_id='inv_1',
159+
session=Mock(spec=Session),
160+
resumability_config=resumability_config,
161+
)
162+
return ctx
163+
164+
def test_should_pause_invocation_with_resumable_app(self, event_to_pause):
165+
"""Tests should_pause_invocation with a resumable app."""
166+
mock_invocation_context = self._create_test_invocation_context(
167+
ResumabilityConfig(is_resumable=True)
168+
)
169+
170+
assert mock_invocation_context.should_pause_invocation(event_to_pause)
171+
172+
def test_should_not_pause_invocation_with_non_resumable_app(
173+
self, event_to_pause
174+
):
175+
"""Tests should_pause_invocation with a non-resumable app."""
176+
invocation_context = self._create_test_invocation_context(
177+
ResumabilityConfig(is_resumable=False)
178+
)
179+
180+
assert not invocation_context.should_pause_invocation(event_to_pause)
181+
182+
def test_should_not_pause_invocation_with_no_long_running_tool_ids(
183+
self, event_to_pause
184+
):
185+
"""Tests should_pause_invocation with no long running tools."""
186+
invocation_context = self._create_test_invocation_context(
187+
ResumabilityConfig(is_resumable=True)
188+
)
189+
nonpausable_event = event_to_pause.model_copy(
190+
update={'long_running_tool_ids': []}
191+
)
192+
193+
assert not invocation_context.should_pause_invocation(nonpausable_event)
194+
195+
def test_should_not_pause_invocation_with_no_function_calls(
196+
self, event_to_pause
197+
):
198+
"""Tests should_pause_invocation with a non-model event."""
199+
mock_invocation_context = self._create_test_invocation_context(
200+
ResumabilityConfig(is_resumable=True)
201+
)
202+
nonpausable_event = event_to_pause.model_copy(
203+
update={'content': testing_utils.UserContent('test text part')}
204+
)
205+
206+
assert not mock_invocation_context.should_pause_invocation(
207+
nonpausable_event
208+
)

0 commit comments

Comments
 (0)