Skip to content

Commit 8b08175

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Add core checkpointing primitive for base agent
PiperOrigin-RevId: 811458903
1 parent b5a65fb commit 8b08175

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from typing_extensions import TypeAlias
3939

4040
from ..events.event import Event
41+
from ..events.event_actions import EventActions
4142
from ..telemetry import tracing
4243
from ..telemetry.tracing import tracer
4344
from ..utils.context_utils import Aclosing
@@ -75,6 +76,9 @@ class BaseAgentState(BaseModel):
7576
)
7677

7778

79+
AgentState = TypeVar('AgentState', bound=BaseAgentState)
80+
81+
7882
class BaseAgent(BaseModel):
7983
"""Base class for all agents in Agent Development Kit."""
8084

@@ -155,6 +159,57 @@ class MyAgent(BaseAgent):
155159
response and appended to event history as agent response.
156160
"""
157161

162+
def _load_agent_state(
163+
self,
164+
ctx: InvocationContext,
165+
state_type: Type[AgentState],
166+
default_state: AgentState,
167+
) -> tuple[AgentState, bool]:
168+
"""Loads the agent state from the invocation context, handling resumption.
169+
170+
Args:
171+
ctx: The invocation context.
172+
state_type: The type of the agent state.
173+
default_state: The default state to use if not resuming.
174+
175+
Returns:
176+
tuple[AgentState, bool]: The current state and a boolean indicating if
177+
resuming.
178+
"""
179+
if self.name not in ctx.agent_states:
180+
return default_state, False
181+
else:
182+
return state_type.model_validate(ctx.agent_states.get(self.name)), True
183+
184+
def _create_agent_state_event(
185+
self,
186+
ctx: InvocationContext,
187+
*,
188+
state: Optional[BaseAgentState] = None,
189+
end_of_agent: bool = False,
190+
) -> Event:
191+
"""Creates an event for agent state.
192+
193+
Args:
194+
ctx: The invocation context.
195+
state: The agent state to checkpoint.
196+
end_of_agent: Whether the agent is finished running.
197+
198+
Returns:
199+
An Event object representing the checkpoint.
200+
"""
201+
event_actions = EventActions()
202+
if state:
203+
event_actions.agent_state = state.model_dump(mode='json')
204+
if end_of_agent:
205+
event_actions.end_of_agent = True
206+
return Event(
207+
invocation_id=ctx.invocation_id,
208+
author=self.name,
209+
branch=ctx.branch,
210+
actions=event_actions,
211+
)
212+
158213
def clone(
159214
self: SelfAgent, update: Mapping[str, Any] | None = None
160215
) -> SelfAgent:

src/google/adk/agents/invocation_context.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Any
1718
from typing import Optional
1819
import uuid
1920

@@ -162,6 +163,12 @@ class InvocationContext(BaseModel):
162163
session: Session
163164
"""The current session of this invocation context. Readonly."""
164165

166+
agent_states: dict[str, dict[str, Any]] = Field(default_factory=dict)
167+
"""The state of the agent for this invocation."""
168+
169+
end_of_agents: dict[str, bool] = Field(default_factory=dict)
170+
"""The end of agent status for each agent in this invocation."""
171+
165172
end_invocation: bool = False
166173
"""Whether to end this invocation.
167174
@@ -201,6 +208,11 @@ class InvocationContext(BaseModel):
201208
of this invocation.
202209
"""
203210

211+
def reset_agent_state(self, agent_name: str) -> None:
212+
"""Resets the state of an agent, allowing it to be re-run."""
213+
self.agent_states.pop(agent_name, None)
214+
self.end_of_agents.pop(agent_name, None)
215+
204216
def increment_llm_call_count(
205217
self,
206218
):

tests/unittests/agents/test_base_agent.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from unittest import mock
2424

2525
from google.adk.agents.base_agent import BaseAgent
26+
from google.adk.agents.base_agent import BaseAgentState
2627
from google.adk.agents.callback_context import CallbackContext
2728
from google.adk.agents.invocation_context import InvocationContext
2829
from google.adk.events.event import Event
@@ -732,6 +733,39 @@ async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
732733
[e async for e in agent.run_live(parent_ctx)]
733734

734735

736+
@pytest.mark.asyncio
737+
async def test_create_agent_state_event(request: pytest.FixtureRequest):
738+
# Arrange
739+
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
740+
ctx = await _create_parent_invocation_context(
741+
request.function.__name__, agent, branch='test_branch'
742+
)
743+
state = BaseAgentState()
744+
745+
# Act
746+
event = agent._create_agent_state_event(ctx, state=state)
747+
748+
# Assert
749+
assert event.invocation_id == ctx.invocation_id
750+
assert event.author == agent.name
751+
assert event.branch == 'test_branch'
752+
assert event.actions is not None
753+
assert event.actions.agent_state is not None
754+
assert event.actions.agent_state == state.model_dump(mode='json')
755+
assert not event.actions.end_of_agent
756+
757+
# Act
758+
event = agent._create_agent_state_event(ctx, end_of_agent=True)
759+
760+
# Assert
761+
assert event.invocation_id == ctx.invocation_id
762+
assert event.author == agent.name
763+
assert event.branch == 'test_branch'
764+
assert event.actions is not None
765+
assert event.actions.end_of_agent
766+
assert event.actions.agent_state is None
767+
768+
735769
def test_set_parent_agent_for_sub_agents(request: pytest.FixtureRequest):
736770
sub_agents: list[BaseAgent] = [
737771
_TestingAgent(name=f'{request.function.__name__}_sub_agent_1'),
@@ -854,3 +888,54 @@ def test_set_parent_agent_for_sub_agent_twice(
854888

855889
if __name__ == '__main__':
856890
pytest.main([__file__])
891+
892+
893+
class _TestAgentState(BaseAgentState):
894+
test_field: str = ''
895+
896+
897+
@pytest.mark.asyncio
898+
async def test_load_agent_state_no_resume():
899+
agent = BaseAgent(name='test_agent')
900+
session_service = InMemorySessionService()
901+
session = await session_service.create_session(
902+
app_name='test_app', user_id='test_user'
903+
)
904+
ctx = InvocationContext(
905+
invocation_id='test_invocation',
906+
agent=agent,
907+
session=session,
908+
session_service=session_service,
909+
)
910+
default_state = _TestAgentState(test_field='default')
911+
912+
state, is_resuming = agent._load_agent_state(
913+
ctx, _TestAgentState, default_state
914+
)
915+
916+
assert not is_resuming
917+
assert state == default_state
918+
919+
920+
@pytest.mark.asyncio
921+
async def test_load_agent_state_with_resume():
922+
agent = BaseAgent(name='test_agent')
923+
session_service = InMemorySessionService()
924+
session = await session_service.create_session(
925+
app_name='test_app', user_id='test_user'
926+
)
927+
ctx = InvocationContext(
928+
invocation_id='test_invocation',
929+
agent=agent,
930+
session=session,
931+
session_service=session_service,
932+
)
933+
persisted_state = _TestAgentState(test_field='resumed')
934+
ctx.agent_states[agent.name] = persisted_state.model_dump(mode='json')
935+
936+
state, is_resuming = agent._load_agent_state(
937+
ctx, _TestAgentState, _TestAgentState()
938+
)
939+
940+
assert is_resuming
941+
assert state == persisted_state

0 commit comments

Comments
 (0)