Skip to content

Commit e534fa0

Browse files
committed
Add SQLAlchemy session backend for conversation history management
1 parent 818344c commit e534fa0

File tree

6 files changed

+2419
-1828
lines changed

6 files changed

+2419
-1828
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
3838
viz = ["graphviz>=0.17"]
3939
litellm = ["litellm>=1.67.4.post1, <2"]
4040
realtime = ["websockets>=15.0, <16"]
41+
sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]
4142

4243
[dependency-groups]
4344
dev = [
@@ -63,6 +64,7 @@ dev = [
6364
"mkdocs-static-i18n>=1.3.0",
6465
"eval-type-backport>=0.2.2",
6566
"fastapi >= 0.110.0, <1",
67+
"aiosqlite>=0.21.0",
6668
]
6769

6870
[tool.uv.workspace]

src/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
RunErrorDetails,
2424
UserError,
2525
)
26+
from .extensions.memory import SQLAlchemySession
2627
from .guardrail import (
2728
GuardrailFunctionOutput,
2829
InputGuardrail,
@@ -222,6 +223,7 @@ def enable_verbose_stdout_logging():
222223
"AgentHooks",
223224
"Session",
224225
"SQLiteSession",
226+
"SQLAlchemySession",
225227
"RunContextWrapper",
226228
"TContext",
227229
"RunErrorDetails",
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
"""Session memory backends living in the extensions namespace.
3+
4+
This package contains optional, production-grade session implementations that
5+
introduce extra third-party dependencies (database drivers, ORMs, etc.). They
6+
conform to the :class:`agents.memory.session.Session` protocol so they can be
7+
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
8+
"""
9+
from __future__ import annotations
10+
11+
from .sqlalchemy_session import SQLAlchemySession # noqa: F401
12+
13+
__all__: list[str] = [
14+
"SQLAlchemySession",
15+
]
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
"""SQLAlchemy-powered Session backend.
2+
3+
Usage::
4+
5+
from agents.extensions.memory import SQLAlchemySession
6+
7+
# Create from SQLAlchemy URL (uses asyncpg driver under the hood for Postgres)
8+
session = SQLAlchemySession.from_url(
9+
session_id="user-123",
10+
url="postgresql+asyncpg://app:[email protected]/agents",
11+
)
12+
13+
# Or pass an existing AsyncEngine that your application already manages
14+
session = SQLAlchemySession(
15+
session_id="user-123",
16+
engine=my_async_engine,
17+
)
18+
19+
await Runner.run(agent, "Hello", session=session)
20+
"""
21+
22+
from __future__ import annotations
23+
24+
import asyncio
25+
import json
26+
from typing import Any
27+
28+
from sqlalchemy import ( # type: ignore
29+
TIMESTAMP,
30+
Column,
31+
ForeignKey,
32+
Integer,
33+
MetaData,
34+
String,
35+
Table,
36+
Text,
37+
delete,
38+
insert,
39+
select,
40+
text as sql_text,
41+
update,
42+
)
43+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine # type: ignore
44+
from sqlalchemy.orm import sessionmaker # type: ignore
45+
46+
from ...items import TResponseInputItem # type: ignore
47+
from ...memory.session import SessionABC
48+
49+
50+
class SQLAlchemySession(SessionABC):
51+
"""SQLAlchemy implementation of :pyclass:`agents.memory.session.Session`."""
52+
53+
_metadata: MetaData
54+
_sessions: Table
55+
_messages: Table
56+
57+
def __init__(
58+
self,
59+
session_id: str,
60+
*,
61+
engine: AsyncEngine,
62+
create_tables: bool = True,
63+
sessions_table: str = "agent_sessions",
64+
messages_table: str = "agent_messages",
65+
): # noqa: D401 – short description on the class-level docstring
66+
"""Create a new session.
67+
68+
Parameters
69+
----------
70+
session_id
71+
Unique identifier for the conversation.
72+
engine
73+
A pre-configured SQLAlchemy *async* engine. The engine **must** be
74+
created with an async driver (``postgresql+asyncpg://``,
75+
``mysql+aiomysql://`` or ``sqlite+aiosqlite://``).
76+
create_tables
77+
Whether to automatically create the required tables & indexes. Set
78+
this to *False* if your migrations take care of schema management.
79+
sessions_table, messages_table
80+
Override default table names if needed.
81+
"""
82+
self.session_id = session_id
83+
self._engine = engine
84+
self._lock = asyncio.Lock()
85+
86+
self._metadata = MetaData()
87+
self._sessions = Table(
88+
sessions_table,
89+
self._metadata,
90+
Column("session_id", String, primary_key=True),
91+
Column(
92+
"created_at",
93+
TIMESTAMP(timezone=False),
94+
server_default=sql_text("CURRENT_TIMESTAMP"),
95+
nullable=False,
96+
),
97+
Column(
98+
"updated_at",
99+
TIMESTAMP(timezone=False),
100+
server_default=sql_text("CURRENT_TIMESTAMP"),
101+
onupdate=sql_text("CURRENT_TIMESTAMP"),
102+
nullable=False,
103+
),
104+
)
105+
106+
self._messages = Table(
107+
messages_table,
108+
self._metadata,
109+
Column("id", Integer, primary_key=True, autoincrement=True),
110+
Column(
111+
"session_id",
112+
String,
113+
ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"),
114+
nullable=False,
115+
),
116+
Column("message_data", Text, nullable=False),
117+
Column(
118+
"created_at",
119+
TIMESTAMP(timezone=False),
120+
server_default=sql_text("CURRENT_TIMESTAMP"),
121+
nullable=False,
122+
),
123+
sqlite_autoincrement=True,
124+
)
125+
126+
# Index for efficient retrieval of messages per session ordered by time
127+
from sqlalchemy import Index # type: ignore
128+
129+
Index(
130+
f"idx_{messages_table}_session_time",
131+
self._messages.c.session_id,
132+
self._messages.c.created_at,
133+
)
134+
135+
# Async session factory
136+
self._session_factory: sessionmaker[AsyncSession] = sessionmaker(
137+
self._engine, expire_on_commit=False, class_=AsyncSession
138+
)
139+
140+
self._create_tables = create_tables
141+
142+
# ---------------------------------------------------------------------
143+
# Convenience constructors
144+
# ---------------------------------------------------------------------
145+
@classmethod
146+
def from_url(
147+
cls,
148+
session_id: str,
149+
*,
150+
url: str,
151+
engine_kwargs: dict[str, Any] | None = None,
152+
**kwargs: Any,
153+
) -> SQLAlchemySession:
154+
"""Create a session from a database URL string.
155+
156+
Parameters
157+
----------
158+
session_id
159+
Conversation ID.
160+
url
161+
Any SQLAlchemy async URL – e.g. ``"postgresql+asyncpg://user:pass@host/db"``.
162+
engine_kwargs
163+
Additional kwargs forwarded to :pyfunc:`sqlalchemy.ext.asyncio.create_async_engine`.
164+
kwargs
165+
Forwarded to the main constructor (``create_tables``, custom table names, …).
166+
"""
167+
engine_kwargs = engine_kwargs or {}
168+
engine = create_async_engine(url, **engine_kwargs)
169+
return cls(session_id, engine=engine, **kwargs)
170+
171+
# ------------------------------------------------------------------
172+
# Session protocol implementation
173+
# ------------------------------------------------------------------
174+
async def _ensure_tables(self) -> None:
175+
"""Ensure tables are created before any database operations."""
176+
if self._create_tables:
177+
async with self._engine.begin() as conn:
178+
await conn.run_sync(self._metadata.create_all)
179+
self._create_tables = False # Only create once
180+
181+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
182+
await self._ensure_tables()
183+
async with self._session_factory() as sess:
184+
if limit is None:
185+
stmt = (
186+
select(self._messages.c.message_data)
187+
.where(self._messages.c.session_id == self.session_id)
188+
.order_by(self._messages.c.created_at.asc())
189+
)
190+
else:
191+
stmt = (
192+
select(self._messages.c.message_data)
193+
.where(self._messages.c.session_id == self.session_id)
194+
.order_by(self._messages.c.created_at.desc())
195+
.limit(limit)
196+
)
197+
198+
result = await sess.execute(stmt)
199+
rows: list[str] = [row[0] for row in result.all()]
200+
201+
if limit is not None:
202+
rows.reverse() # chronological order
203+
204+
items: list[TResponseInputItem] = []
205+
for raw in rows:
206+
try:
207+
items.append(json.loads(raw))
208+
except json.JSONDecodeError:
209+
# Skip corrupted rows
210+
continue
211+
return items
212+
213+
async def add_items(self, items: list[TResponseInputItem]) -> None:
214+
if not items:
215+
return
216+
217+
await self._ensure_tables()
218+
payload = [
219+
{
220+
"session_id": self.session_id,
221+
"message_data": json.dumps(item, separators=(",", ":")),
222+
}
223+
for item in items
224+
]
225+
226+
async with self._session_factory() as sess:
227+
async with sess.begin():
228+
# Ensure the parent session row exists - use merge for cross-DB compatibility
229+
# Check if session exists
230+
existing = await sess.execute(
231+
select(self._sessions.c.session_id).where(
232+
self._sessions.c.session_id == self.session_id
233+
)
234+
)
235+
if not existing.scalar_one_or_none():
236+
# Session doesn't exist, create it
237+
await sess.execute(
238+
insert(self._sessions).values({"session_id": self.session_id})
239+
)
240+
241+
# Insert messages in bulk
242+
await sess.execute(insert(self._messages), payload)
243+
244+
# Touch updated_at column
245+
await sess.execute(
246+
update(self._sessions)
247+
.where(self._sessions.c.session_id == self.session_id)
248+
.values(updated_at=sql_text("CURRENT_TIMESTAMP"))
249+
)
250+
251+
async def pop_item(self) -> TResponseInputItem | None:
252+
await self._ensure_tables()
253+
async with self._session_factory() as sess:
254+
async with sess.begin():
255+
# First try dialects that support DELETE … RETURNING for atomicity
256+
try:
257+
stmt = (
258+
delete(self._messages)
259+
.where(self._messages.c.session_id == self.session_id)
260+
.order_by(self._messages.c.created_at.desc())
261+
.limit(1)
262+
.returning(self._messages.c.message_data)
263+
)
264+
result = await sess.execute(stmt)
265+
row = result.scalar_one_or_none()
266+
except Exception: # pragma: no cover – fallback path
267+
# Fallback for dialects that don't support ORDER BY in DELETE
268+
subq = (
269+
select(self._messages.c.id)
270+
.where(self._messages.c.session_id == self.session_id)
271+
.order_by(self._messages.c.created_at.desc())
272+
.limit(1)
273+
)
274+
res = await sess.execute(subq)
275+
row_id = res.scalar_one_or_none()
276+
if row_id is None:
277+
return None
278+
# Fetch data before deleting
279+
res_data = await sess.execute(
280+
select(self._messages.c.message_data).where(self._messages.c.id == row_id)
281+
)
282+
row = res_data.scalar_one_or_none()
283+
await sess.execute(delete(self._messages).where(self._messages.c.id == row_id))
284+
285+
if row is None:
286+
return None
287+
try:
288+
return json.loads(row)
289+
except json.JSONDecodeError:
290+
return None
291+
292+
async def clear_session(self) -> None: # noqa: D401 – imperative mood is fine
293+
await self._ensure_tables()
294+
async with self._session_factory() as sess:
295+
async with sess.begin():
296+
await sess.execute(
297+
delete(self._messages).where(self._messages.c.session_id == self.session_id)
298+
)
299+
await sess.execute(
300+
delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
301+
)

0 commit comments

Comments
 (0)