Skip to content

Commit 88ea1ef

Browse files
committed
Simplify the run state data
1 parent 3553b23 commit 88ea1ef

File tree

18 files changed

+810
-798
lines changed

18 files changed

+810
-798
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,5 @@ cython_debug/
148148

149149
# Redis database files
150150
dump.rdb
151+
152+
tmp/
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""
2+
File-backed session example with human-in-the-loop tool approval.
3+
4+
This mirrors the JS `file-hitl.ts` sample: a session persisted on disk and tools that
5+
require approval before execution.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import asyncio
11+
import json
12+
from typing import Any
13+
14+
from agents import Agent, Runner, function_tool
15+
from agents.run_context import RunContextWrapper
16+
from agents.run_state import RunState
17+
18+
from .file_session import FileSession
19+
20+
21+
async def main() -> None:
22+
user_context = {"user_id": "101"}
23+
24+
customer_directory: dict[str, str] = {
25+
"101": (
26+
"Customer Kaz S. (tier gold) can be reached at +1-415-555-AAAA. "
27+
"Notes: Prefers SMS follow ups and values concise summaries."
28+
),
29+
"104": (
30+
"Customer Yu S. (tier platinum) can be reached at +1-415-555-BBBB. "
31+
"Notes: Recently reported sync issues. Flagged for a proactive onboarding call."
32+
),
33+
"205": (
34+
"Customer Ken S. (tier standard) can be reached at +1-415-555-CCCC. "
35+
"Notes: Interested in automation tutorials sent last week."
36+
),
37+
}
38+
39+
lookup_customer_profile = create_lookup_customer_profile_tool(directory=customer_directory)
40+
41+
instructions = (
42+
"You assist support agents. For every user turn you must call lookup_customer_profile. "
43+
"If a tool reports a transient failure, request approval and retry the same call once before "
44+
"responding. Keep responses under three sentences."
45+
)
46+
47+
agent = Agent(
48+
name="File HITL assistant",
49+
instructions=instructions,
50+
tools=[lookup_customer_profile],
51+
)
52+
53+
session = FileSession(dir="examples/memory/tmp")
54+
session_id = await session.get_session_id()
55+
print(f"Session id: {session_id}")
56+
print("Enter a message to chat with the agent. Submit an empty line to exit.")
57+
58+
saved_state = await session.load_state_json()
59+
if saved_state:
60+
print("Found saved run state. Resuming pending interruptions before new input.")
61+
try:
62+
state = await RunState.from_json(agent, saved_state, context_override=user_context)
63+
result = await Runner.run(agent, state, session=session)
64+
while result.interruptions:
65+
state = result.to_state()
66+
for interruption in result.interruptions:
67+
args = format_tool_arguments(interruption)
68+
approved = await prompt_yes_no(
69+
f"Agent {interruption.agent.name} wants to call {interruption.name} with {args or 'no arguments'}"
70+
)
71+
if approved:
72+
state.approve(interruption)
73+
print("Approved tool call.")
74+
else:
75+
state.reject(interruption)
76+
print("Rejected tool call.")
77+
result = await Runner.run(agent, state, session=session)
78+
await session.save_state_json(result.to_state().to_json())
79+
reply = result.final_output or "[No final output produced]"
80+
print(f"Assistant (resumed): {reply}\n")
81+
except Exception as exc: # noqa: BLE001
82+
print(f"Failed to resume saved state: {exc}. Starting a new session.")
83+
84+
while True:
85+
print("You: ", end="", flush=True)
86+
loop = asyncio.get_event_loop()
87+
user_message = await loop.run_in_executor(None, input)
88+
if not user_message.strip():
89+
break
90+
91+
result = await Runner.run(agent, user_message, session=session, context=user_context)
92+
while result.interruptions:
93+
state = result.to_state()
94+
for interruption in result.interruptions:
95+
args = format_tool_arguments(interruption)
96+
approved = await prompt_yes_no(
97+
f"Agent {interruption.agent.name} wants to call {interruption.name} with {args or 'no arguments'}"
98+
)
99+
if approved:
100+
state.approve(interruption)
101+
print("Approved tool call.")
102+
else:
103+
state.reject(interruption)
104+
print("Rejected tool call.")
105+
result = await Runner.run(agent, state, session=session)
106+
await session.save_state_json(result.to_state().to_json())
107+
108+
reply = result.final_output or "[No final output produced]"
109+
print(f"Assistant: {reply}\n")
110+
111+
112+
def create_lookup_customer_profile_tool(
113+
*,
114+
directory: dict[str, str],
115+
missing_customer_message: str = "No customer found for that id.",
116+
):
117+
@function_tool(
118+
name_override="lookup_customer_profile",
119+
description_override="Look up stored profile details for a customer by their internal id.",
120+
needs_approval=True,
121+
)
122+
def lookup_customer_profile(ctx: RunContextWrapper[Any]) -> str:
123+
return directory.get(ctx.context.get("user_id"), missing_customer_message)
124+
125+
return lookup_customer_profile
126+
127+
128+
def format_tool_arguments(interruption: Any) -> str:
129+
args = getattr(interruption, "arguments", None)
130+
if args is None:
131+
return ""
132+
if isinstance(args, str):
133+
return args
134+
try:
135+
return json.dumps(args)
136+
except Exception:
137+
return str(args)
138+
139+
140+
async def prompt_yes_no(question: str) -> bool:
141+
print(f"{question} (y/n): ", end="", flush=True)
142+
loop = asyncio.get_event_loop()
143+
answer = await loop.run_in_executor(None, input)
144+
normalized = answer.strip().lower()
145+
return normalized in {"y", "yes"}
146+
147+
148+
if __name__ == "__main__":
149+
asyncio.run(main())

examples/memory/file_session.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
Simple file-backed session implementation for examples.
3+
4+
Persists conversation history as JSON on disk so runs can resume across processes.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import asyncio
10+
import json
11+
from datetime import datetime
12+
from pathlib import Path
13+
from typing import Any
14+
from uuid import uuid4
15+
16+
from agents.memory.session import Session
17+
18+
19+
class FileSession(Session):
20+
"""Persist session items to a JSON file on disk."""
21+
22+
def __init__(self, *, dir: str | Path | None = None, session_id: str | None = None) -> None:
23+
self._dir = Path(dir) if dir is not None else Path.cwd() / ".agents-sessions"
24+
self.session_id = session_id or ""
25+
# Ensure the directory exists up front so subsequent file operations do not race.
26+
self._dir.mkdir(parents=True, exist_ok=True)
27+
28+
async def _ensure_session_id(self) -> str:
29+
if not self.session_id:
30+
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
31+
# Prefix with wall-clock time so recent sessions are easy to spot on disk.
32+
self.session_id = f"{timestamp}-{uuid4().hex[:12]}"
33+
await asyncio.to_thread(self._dir.mkdir, parents=True, exist_ok=True)
34+
file_path = self._items_path(self.session_id)
35+
if not file_path.exists():
36+
await asyncio.to_thread(file_path.write_text, "[]", encoding="utf-8")
37+
return self.session_id
38+
39+
async def get_session_id(self) -> str:
40+
"""Return the session id, creating one if needed."""
41+
return await self._ensure_session_id()
42+
43+
async def get_items(self, limit: int | None = None) -> list[Any]:
44+
session_id = await self._ensure_session_id()
45+
items = await self._read_items(session_id)
46+
if limit is not None and limit >= 0:
47+
return items[-limit:]
48+
return items
49+
50+
async def add_items(self, items: list[Any]) -> None:
51+
if not items:
52+
return
53+
session_id = await self._ensure_session_id()
54+
current = await self._read_items(session_id)
55+
# Deep-copy via JSON to avoid persisting live references that might mutate later.
56+
cloned = json.loads(json.dumps(items))
57+
await self._write_items(session_id, current + cloned)
58+
59+
async def pop_item(self) -> Any | None:
60+
session_id = await self._ensure_session_id()
61+
items = await self._read_items(session_id)
62+
if not items:
63+
return None
64+
popped = items.pop()
65+
await self._write_items(session_id, items)
66+
return popped
67+
68+
async def clear_session(self) -> None:
69+
if not self.session_id:
70+
return
71+
file_path = self._items_path(self.session_id)
72+
state_path = self._state_path(self.session_id)
73+
try:
74+
await asyncio.to_thread(file_path.unlink)
75+
except FileNotFoundError:
76+
pass
77+
try:
78+
await asyncio.to_thread(state_path.unlink)
79+
except FileNotFoundError:
80+
pass
81+
self.session_id = ""
82+
83+
def _items_path(self, session_id: str) -> Path:
84+
return self._dir / f"{session_id}.json"
85+
86+
def _state_path(self, session_id: str) -> Path:
87+
return self._dir / f"{session_id}-state.json"
88+
89+
async def _read_items(self, session_id: str) -> list[Any]:
90+
file_path = self._items_path(session_id)
91+
try:
92+
data = await asyncio.to_thread(file_path.read_text, "utf-8")
93+
parsed = json.loads(data)
94+
return parsed if isinstance(parsed, list) else []
95+
except FileNotFoundError:
96+
return []
97+
98+
async def _write_items(self, session_id: str, items: list[Any]) -> None:
99+
file_path = self._items_path(session_id)
100+
payload = json.dumps(items, indent=2, ensure_ascii=False)
101+
await asyncio.to_thread(self._dir.mkdir, parents=True, exist_ok=True)
102+
await asyncio.to_thread(file_path.write_text, payload, encoding="utf-8")
103+
104+
async def load_state_json(self) -> dict[str, Any] | None:
105+
"""Load a previously saved RunState JSON payload, if present."""
106+
session_id = await self._ensure_session_id()
107+
state_path = self._state_path(session_id)
108+
try:
109+
data = await asyncio.to_thread(state_path.read_text, "utf-8")
110+
parsed = json.loads(data)
111+
return parsed if isinstance(parsed, dict) else None
112+
except FileNotFoundError:
113+
return None
114+
115+
async def save_state_json(self, state: dict[str, Any]) -> None:
116+
"""Persist the serialized RunState JSON payload alongside session items."""
117+
session_id = await self._ensure_session_id()
118+
state_path = self._state_path(session_id)
119+
payload = json.dumps(state, indent=2, ensure_ascii=False)
120+
await asyncio.to_thread(self._dir.mkdir, parents=True, exist_ok=True)
121+
await asyncio.to_thread(state_path.write_text, payload, encoding="utf-8")

0 commit comments

Comments
 (0)