
Commit 717bb1e

Feat: add agent run cache (#10)
* add initial cache functionality
* add caching
* remove resume=True in example
* clean up agent cache attributes
1 parent 70a5388 commit 717bb1e

File tree

5 files changed: +704 -14 lines changed


docs/guides/caching.md

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Caching and Resumption

Stirrup automatically caches agent state on interruptions, allowing you to resume long-running tasks.

## Enabling Resume

Pass `resume=True` to `session()`:

```python
from stirrup import Agent
from stirrup.clients.chat_completions_client import ChatCompletionsClient
from stirrup.tools import DEFAULT_TOOLS

client = ChatCompletionsClient(model="gpt-5")
agent = Agent(client=client, name="researcher", tools=DEFAULT_TOOLS, max_turns=50)

async with agent.session(output_dir="./output", resume=True) as session:
    await session.run("Analyze all datasets in the data folder")
```

## How It Works

1. **On interruption** (Ctrl+C, error, or max turns): Stirrup saves conversation state and execution environment files to `~/.cache/stirrup/<task_hash>/` (a sketch of the underlying pattern follows the example below)

2. **On the next run with `resume=True`**: If a cache exists for the same prompt, the agent restores that state and continues from the last turn

3. **On successful completion**: The cache is automatically cleared (configurable via `clear_cache_on_success`)

```
# First run (interrupted at turn 15)
$ python my_agent.py
^C
Cached state for task abc123...

# Second run (resumes from turn 15)
$ python my_agent.py
Resuming from cached state at turn 15
```
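
Aside (illustrative, not part of `caching.md`): step 1 works because the session installs a SIGINT handler that turns Ctrl+C into a `KeyboardInterrupt`, so the context manager's exit hook still runs and can persist state before cleanup. A minimal sketch of that pattern in plain Python; `Session` and the `print` stand-in are hypothetical, not Stirrup APIs:

```python
import signal


def _raise_keyboard_interrupt(signum, frame):
    # Turn SIGINT into an exception so context managers unwind normally.
    raise KeyboardInterrupt("interrupted - state will be cached")


class Session:
    def __enter__(self):
        # Remember the previous handler so it can be restored on exit.
        self._original_sigint = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, _raise_keyboard_interrupt)
        return self

    def __exit__(self, exc_type, exc, tb):
        try:
            if exc_type is not None:
                print("caching state before exit")  # stand-in for CacheManager.save_state(...)
        finally:
            signal.signal(signal.SIGINT, self._original_sigint)
        return False  # propagate the interrupt to the caller
```

Per the `agent.py` diff further down, Stirrup installs this handler only in the root agent's session (`current_depth == 0`), so nested sub-agents do not fight over the process-wide handler.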

## What Gets Cached

- Conversation messages and history (see the `CacheState` sketch below)
- Current turn number
- Tool metadata
- All files in the execution environment
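
Aside (illustrative, not part of `caching.md`): the conversation part of this list corresponds to the `CacheState` snapshot taken at the top of each turn in `Agent.run()` (see the `agent.py` diff below). A rough sketch of its shape; the field names come from that diff, but the exact types in `stirrup.core.cache` are assumptions:

```python
# Sketch only - mirrors the CacheState fields used in agent.py below.
from dataclasses import dataclass
from typing import Any


@dataclass
class CacheStateSketch:
    msgs: list[Any]                     # current conversation messages
    full_msg_history: list[list[Any]]   # per-turn message groups
    turn: int                           # turn number to resume from
    run_metadata: dict[str, list[Any]]  # tool metadata collected so far
    task_hash: str                      # hash of the initial prompt
    agent_name: str                     # agent that produced the snapshot
```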
## Preserving Caches on Success

By default, caches are cleared on successful completion. To preserve them for inspection or debugging:

```python
async with agent.session(
    resume=True,
    clear_cache_on_success=False,  # Keep cache after success
) as session:
    await session.run("Analyze the data")
```

## Managing Caches

```python
from stirrup.core.cache import CacheManager

cache_manager = CacheManager()

# List all caches
for task_hash in cache_manager.list_caches():
    info = cache_manager.get_cache_info(task_hash)
    print(f"{task_hash}: turn {info['turn']}")

# Clear a specific cache
cache_manager.clear_cache("abc123def456")
```
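
Combining the calls above also gives a simple bulk cleanup. A sketch, assuming `list_caches()` and `clear_cache()` behave exactly as shown in the snippet above:

```python
# Sketch: remove every cached task using the CacheManager API shown above.
from stirrup.core.cache import CacheManager

cache_manager = CacheManager()
for task_hash in cache_manager.list_caches():
    cache_manager.clear_cache(task_hash)
```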

## Notes

- The cache key is computed from the initial prompt, so the same prompt maps to the same cache (illustrated below)
- Caches are stored locally in `~/.cache/stirrup/`
- Caches are automatically cleared on successful completion (by default)
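
Aside (illustrative): because the key is a hash of the initial prompt, identical prompts resolve to the same cache entry. A small illustration, assuming `compute_task_hash` (imported in the `agent.py` diff below) accepts the same string prompt that `run()` does and hashes it deterministically:

```python
from stirrup.core.cache import compute_task_hash

key_a = compute_task_hash("Analyze all datasets in the data folder")
key_b = compute_task_hash("Analyze all datasets in the data folder")
key_c = compute_task_hash("Summarize the data folder")

assert key_a == key_b  # same prompt -> same cache entry
assert key_a != key_c  # different prompt -> different cache entry
```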

mkdocs.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -80,6 +80,7 @@ nav:
 - Creating Tools: guides/tools.md
 - Tool Providers: guides/tool-providers.md
 - Code Execution: guides/code-execution.md
+- Caching: guides/caching.md
 - Skills: guides/skills.md
 - Sub-Agents: guides/sub-agents.md
 - MCP Integration: guides/mcp.md
```

src/stirrup/core/agent.py

Lines changed: 140 additions & 14 deletions
```diff
@@ -5,6 +5,7 @@
 import json
 import logging
 import re
+import signal
 from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
 from itertools import chain, takewhile
@@ -20,6 +21,7 @@
     CONTEXT_SUMMARIZATION_CUTOFF,
     FINISH_TOOL_NAME,
 )
+from stirrup.core.cache import CacheManager, CacheState, compute_task_hash
 from stirrup.core.models import (
     AssistantMessage,
     ChatMessage,
@@ -226,13 +228,19 @@ def __init__(
         self._pending_output_dir: Path | None = None
         self._pending_input_files: str | Path | list[str | Path] | None = None
         self._pending_skills_dir: Path | None = None
+        self._resume: bool = False
+        self._clear_cache_on_success: bool = True

         # Instance-scoped state (populated during __aenter__, isolated per agent instance)
         self._active_tools: dict[str, Tool] = {}
         self._last_finish_params: Any = None  # FinishParams type parameter
         self._last_run_metadata: dict[str, list[Any]] = {}
         self._transferred_paths: list[str] = []  # Paths transferred to parent (for subagents)

+        # Cache state for resumption (set during run(), used in __aexit__ for caching on interrupt)
+        self._current_task_hash: str | None = None
+        self._current_run_state: CacheState | None = None
+
     @property
     def name(self) -> str:
         """The name of this agent."""
@@ -263,6 +271,8 @@ def session(
         output_dir: Path | str | None = None,
         input_files: str | Path | list[str | Path] | None = None,
         skills_dir: Path | str | None = None,
+        resume: bool = False,
+        clear_cache_on_success: bool = True,
     ) -> Self:
         """Configure a session and return self for use as async context manager.

@@ -278,6 +288,13 @@
             skills_dir: Directory containing skill definitions to load and make available
                 to the agent. Skills are uploaded to the execution environment
                 and their metadata is included in the system prompt.
+            resume: If True, attempt to resume from cached state if available.
+                The cache is identified by hashing the init_msgs passed to run().
+                Cached state includes message history, current turn, and execution
+                environment files from a previous interrupted run.
+            clear_cache_on_success: If True (default), automatically clear the cache
+                when the agent completes successfully. Set to False
+                to preserve caches for inspection or debugging.

         Returns:
             Self, for use with `async with agent.session(...) as session:`
@@ -294,8 +311,18 @@
         self._pending_output_dir = Path(output_dir) if output_dir else None
         self._pending_input_files = input_files
         self._pending_skills_dir = Path(skills_dir) if skills_dir else None
+        self._resume = resume
+        self._clear_cache_on_success = clear_cache_on_success
         return self

+    def _handle_interrupt(self, _signum: int, _frame: object) -> None:
+        """Handle SIGINT to ensure caching before exit.
+
+        Converts the signal to a KeyboardInterrupt exception so that __aexit__
+        is properly called and can cache the state before cleanup.
+        """
+        raise KeyboardInterrupt("Agent interrupted - state will be cached")
+
     def _resolve_input_files(self, input_files: str | Path | list[str | Path]) -> list[Path]:
         """Resolve input file paths, expanding globs and normalizing to Path objects.

@@ -632,6 +659,11 @@ async def __aenter__(self) -> Self:
             # depth is already set (0 for main agent, passed in for sub-agents)
             self._logger.__enter__()

+            # Set up signal handler for graceful caching on interrupt (root agent only)
+            if current_depth == 0:
+                self._original_sigint = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, self._handle_interrupt)
+
             return self

         except Exception:
@@ -653,6 +685,47 @@ async def __aexit__(
         state = _SESSION_STATE.get()

         try:
+            # Cache state on non-success exit (only at root level)
+            should_cache = (
+                state.depth == 0
+                and (exc_type is not None or self._last_finish_params is None)
+                and self._current_task_hash is not None
+                and self._current_run_state is not None
+            )
+
+            logger.debug(
+                "[%s __aexit__] Cache decision: should_cache=%s, depth=%d, exc_type=%s, "
+                "finish_params=%s, task_hash=%s, run_state=%s",
+                self._name,
+                should_cache,
+                state.depth,
+                exc_type,
+                self._last_finish_params is not None,
+                self._current_task_hash,
+                self._current_run_state is not None,
+            )
+
+            if should_cache:
+                cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+
+                exec_env_dir = state.exec_env.temp_dir if state.exec_env else None
+
+                # Explicit checks to keep type checker happy - should_cache condition guarantees these
+                if self._current_task_hash is None or self._current_run_state is None:
+                    raise ValueError("Cache state is unexpectedly None after should_cache check")
+
+                # Temporarily block SIGINT during cache save to prevent interruption
+                original_handler = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, signal.SIG_IGN)
+                try:
+                    cache_manager.save_state(
+                        self._current_task_hash,
+                        self._current_run_state,
+                        exec_env_dir,
+                    )
+                finally:
+                    signal.signal(signal.SIGINT, original_handler)
+                self._logger.info(f"Cached state for task {self._current_task_hash}")
             # Save files from finish_params.paths based on depth
             if state.output_dir and self._last_finish_params and state.exec_env:
                 paths = getattr(self._last_finish_params, "paths", None)
```
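
Aside (not part of the diff): the `SIG_IGN` swap above exists so a second Ctrl+C cannot corrupt the cache write; the previous handler is restored in the `finally` block. A generic, self-contained sketch of that pattern, with the helper name `run_uninterruptible` chosen purely for illustration:

```python
import signal
from typing import Callable


def run_uninterruptible(save_fn: Callable[[], None]) -> None:
    """Run save_fn while ignoring SIGINT, then restore the previous handler."""
    original_handler = signal.getsignal(signal.SIGINT)
    signal.signal(signal.SIGINT, signal.SIG_IGN)  # a Ctrl+C during the save is ignored entirely
    try:
        save_fn()
    finally:
        signal.signal(signal.SIGINT, original_handler)
```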
```diff
@@ -707,6 +780,11 @@ async def __aexit__(
                 state.depth,
             )
         finally:
+            # Restore original signal handler (root agent only)
+            if hasattr(self, "_original_sigint"):
+                signal.signal(signal.SIGINT, self._original_sigint)
+                del self._original_sigint
+
             # Exit logger context
             self._logger.finish_params = self._last_finish_params
             self._logger.run_metadata = self._last_run_metadata
@@ -870,23 +948,59 @@ async def run(
             ])

         """
-        msgs: list[ChatMessage] = []

-        # Build the complete system prompt (base + input files + user instructions)
-        full_system_prompt = self._build_system_prompt()
-        msgs.append(SystemMessage(content=full_system_prompt))
+        # Compute task hash for caching/resume
+        task_hash = compute_task_hash(init_msgs)
+        self._current_task_hash = task_hash
+
+        # Initialize cache manager
+        cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+        start_turn = 0
+        resumed = False
+
+        # Try to resume from cache if requested
+        if self._resume:
+            state = _SESSION_STATE.get()
+            cached = cache_manager.load_state(task_hash)
+            if cached:
+                # Restore files to exec env
+                if state.exec_env and state.exec_env.temp_dir:
+                    cache_manager.restore_files(task_hash, state.exec_env.temp_dir)
+
+                # Restore state
+                msgs = cached.msgs
+                full_msg_history = cached.full_msg_history
+                run_metadata = cached.run_metadata
+                start_turn = cached.turn
+                resumed = True
+                self._logger.info(f"Resuming from cached state at turn {start_turn}")
+            else:
+                self._logger.info(f"No cache found for task {task_hash}, starting fresh")

-        if isinstance(init_msgs, str):
-            msgs.append(UserMessage(content=init_msgs))
-        else:
-            msgs.extend(init_msgs)
+        if not resumed:
+            msgs: list[ChatMessage] = []
+
+            # Build the complete system prompt (base + input files + user instructions)
+            full_system_prompt = self._build_system_prompt()
+            msgs.append(SystemMessage(content=full_system_prompt))
+
+            if isinstance(init_msgs, str):
+                msgs.append(UserMessage(content=init_msgs))
+            else:
+                msgs.extend(init_msgs)
+
+            # Local metadata storage - isolated per run() invocation for thread safety
+            run_metadata: dict[str, list[Any]] = {}
+
+            full_msg_history: list[list[ChatMessage]] = []

         # Set logger depth if provided (for sub-agent runs)
         if depth is not None:
             self._logger.depth = depth

-        # Log the task at run start
-        self._logger.task_message(msgs[-1].content)
+        # Log the task at run start (only if not resuming)
+        if not resumed:
+            self._logger.task_message(msgs[-1].content)

         # Show warnings (top-level only, if logger supports it)
         if self._logger.depth == 0 and isinstance(self._logger, AgentLogger):
@@ -897,9 +1011,6 @@
         # Use logger callback if available and not overridden
         step_callback = self._logger.on_step

-        # Local metadata storage - isolated per run() invocation for thread safety
-        run_metadata: dict[str, list[Any]] = {}
-
         full_msg_history: list[list[ChatMessage]] = []
         finish_params: FinishParams | None = None

@@ -908,7 +1019,16 @@
         total_input_tokens = 0
         total_output_tokens = 0

-        for i in range(self._max_turns):
+        for i in range(start_turn, self._max_turns):
+            # Capture current state for potential caching (before any async work)
+            self._current_run_state = CacheState(
+                msgs=list(msgs),
+                full_msg_history=[list(group) for group in full_msg_history],
+                turn=i,
+                run_metadata=dict(run_metadata),
+                task_hash=task_hash,
+                agent_name=self._name,
+            )
             if self._max_turns - i <= 30 and i != 0:
                 num_turns_remaining_msg = _num_turns_remaining_msg(self._max_turns - i)
                 msgs.append(num_turns_remaining_msg)
@@ -976,6 +1096,12 @@
         self._last_finish_params = finish_params
         self._last_run_metadata = run_metadata

+        # Clear cache on successful completion (finish_params is set)
+        if finish_params is not None and cache_manager.clear_on_success:
+            cache_manager.clear_cache(task_hash)
+            self._current_task_hash = None
+            self._current_run_state = None
+
         return finish_params, full_msg_history, run_metadata

     def to_tool(
```

0 commit comments
