2 changes: 2 additions & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
viz = ["graphviz>=0.17"]
litellm = ["litellm>=1.67.4.post1, <2"]
realtime = ["websockets>=15.0, <16"]
sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]

[dependency-groups]
dev = [
@@ -63,6 +64,7 @@ dev = [
"mkdocs-static-i18n>=1.3.0",
"eval-type-backport>=0.2.2",
"fastapi >= 0.110.0, <1",
"aiosqlite>=0.21.0",
]

[tool.uv.workspace]
15 changes: 15 additions & 0 deletions src/agents/extensions/memory/__init__.py
@@ -0,0 +1,15 @@

"""Session memory backends living in the extensions namespace.

This package contains optional, production-grade session implementations that
introduce extra third-party dependencies (database drivers, ORMs, etc.). They
conform to the :class:`agents.memory.session.Session` protocol, so they can be
used as drop-in replacements for :class:`agents.memory.session.SQLiteSession`.
"""
from __future__ import annotations

from .sqlalchemy_session import SQLAlchemySession # noqa: F401

__all__: list[str] = [
"SQLAlchemySession",
]
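
For reviewers, a minimal sketch (not part of the diff) of the drop-in pattern the docstring above describes: an application that already owns an ``AsyncEngine`` and manages its schema through migrations hands the session to the runner. The ``Agent`` and ``Runner`` imports, the DSN, and ``result.final_output`` are assumptions used for illustration, not something this PR defines.

import asyncio

from sqlalchemy.ext.asyncio import create_async_engine

from agents import Agent, Runner  # assumed public SDK exports
from agents.extensions.memory import SQLAlchemySession


async def main() -> None:
    # Placeholder DSN; the engine is typically created once at application startup.
    engine = create_async_engine("postgresql+asyncpg://app:secret@db.example.com/agents")

    session = SQLAlchemySession(
        "user-123",
        engine=engine,
        create_tables=False,  # schema is assumed to be owned by migrations
    )

    agent = Agent(name="Assistant", instructions="Reply concisely.")
    result = await Runner.run(agent, "Hello", session=session)
    print(result.final_output)  # assumed attribute of the run result


asyncio.run(main())
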
288 changes: 288 additions & 0 deletions src/agents/extensions/memory/sqlalchemy_session.py
@@ -0,0 +1,288 @@
"""SQLAlchemy-powered Session backend.

Usage::

from agents.extensions.memory import SQLAlchemySession

    # Create from a SQLAlchemy URL (here Postgres, via the asyncpg driver named in the URL)
session = SQLAlchemySession.from_url(
session_id="user-123",
url="postgresql+asyncpg://app:[email protected]/agents",
)

# Or pass an existing AsyncEngine that your application already manages
session = SQLAlchemySession(
session_id="user-123",
engine=my_async_engine,
)

await Runner.run(agent, "Hello", session=session)
"""

from __future__ import annotations

import asyncio
import json
from typing import Any

from sqlalchemy import (
    TIMESTAMP,
    Column,
    ForeignKey,
    Index,
    Integer,
    MetaData,
    String,
    Table,
    Text,
    delete,
    insert,
    select,
    text as sql_text,
    update,
)
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine

from ...items import TResponseInputItem
from ...memory.session import SessionABC


class SQLAlchemySession(SessionABC):
"""SQLAlchemy implementation of :pyclass:`agents.memory.session.Session`."""

_metadata: MetaData
_sessions: Table
_messages: Table

def __init__(
self,
session_id: str,
*,
engine: AsyncEngine,
create_tables: bool = True,
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
): # noqa: D401 – short description on the class-level docstring
"""Create a new session.

Parameters
----------
session_id
Unique identifier for the conversation.
engine
A pre-configured SQLAlchemy *async* engine. The engine **must** be
created with an async driver (``postgresql+asyncpg://``,
``mysql+aiomysql://`` or ``sqlite+aiosqlite://``).
create_tables
Whether to automatically create the required tables & indexes. Set
this to *False* if your migrations take care of schema management.
sessions_table, messages_table
Override default table names if needed.
"""
self.session_id = session_id
self._engine = engine
self._lock = asyncio.Lock()

self._metadata = MetaData()
self._sessions = Table(
sessions_table,
self._metadata,
Column("session_id", String, primary_key=True),
Column(
"created_at",
TIMESTAMP(timezone=False),
server_default=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
Column(
"updated_at",
TIMESTAMP(timezone=False),
server_default=sql_text("CURRENT_TIMESTAMP"),
onupdate=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
)

self._messages = Table(
messages_table,
self._metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column(
"session_id",
String,
ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"),
nullable=False,
),
Column("message_data", Text, nullable=False),
Column(
"created_at",
TIMESTAMP(timezone=False),
server_default=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
sqlite_autoincrement=True,
)

        # Composite index for efficient retrieval of a session's messages in time order.
        # Declaring it against the table columns registers it on the shared MetaData,
        # so the create_all() call in _ensure_tables() emits it along with the tables.
Index(
f"idx_{messages_table}_session_time",
self._messages.c.session_id,
self._messages.c.created_at,
)

# Async session factory
self._session_factory = async_sessionmaker(
self._engine, expire_on_commit=False
)

self._create_tables = create_tables

# ---------------------------------------------------------------------
# Convenience constructors
# ---------------------------------------------------------------------
@classmethod
def from_url(
cls,
session_id: str,
*,
url: str,
engine_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> SQLAlchemySession:
"""Create a session from a database URL string.

Parameters
----------
session_id
Conversation ID.
url
Any SQLAlchemy async URL – e.g. ``"postgresql+asyncpg://user:pass@host/db"``.
engine_kwargs
            Additional kwargs forwarded to :func:`sqlalchemy.ext.asyncio.create_async_engine`.
kwargs
Forwarded to the main constructor (``create_tables``, custom table names, …).
"""
engine_kwargs = engine_kwargs or {}
engine = create_async_engine(url, **engine_kwargs)
return cls(session_id, engine=engine, **kwargs)

# ------------------------------------------------------------------
# Session protocol implementation
# ------------------------------------------------------------------
async def _ensure_tables(self) -> None:
"""Ensure tables are created before any database operations."""
if self._create_tables:
async with self._engine.begin() as conn:
await conn.run_sync(self._metadata.create_all)
self._create_tables = False # Only create once

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
await self._ensure_tables()
async with self._session_factory() as sess:
if limit is None:
stmt = (
select(self._messages.c.message_data)
.where(self._messages.c.session_id == self.session_id)
                    # The autoincrement id breaks ties when CURRENT_TIMESTAMP values
                    # collide for rows inserted in the same batch.
                    .order_by(self._messages.c.created_at.asc(), self._messages.c.id.asc())
)
else:
stmt = (
select(self._messages.c.message_data)
.where(self._messages.c.session_id == self.session_id)
                    .order_by(self._messages.c.created_at.desc(), self._messages.c.id.desc())
.limit(limit)
)

result = await sess.execute(stmt)
rows: list[str] = [row[0] for row in result.all()]

if limit is not None:
rows.reverse() # chronological order

items: list[TResponseInputItem] = []
for raw in rows:
try:
items.append(json.loads(raw))
except json.JSONDecodeError:
# Skip corrupted rows
continue
return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
if not items:
return

await self._ensure_tables()
payload = [
{
"session_id": self.session_id,
"message_data": json.dumps(item, separators=(",", ":")),
}
for item in items
]

async with self._session_factory() as sess:
async with sess.begin():
                # Ensure the parent session row exists. A portable select-then-insert is
                # used here rather than a dialect-specific upsert.
existing = await sess.execute(
select(self._sessions.c.session_id).where(
self._sessions.c.session_id == self.session_id
)
)
                if existing.scalar_one_or_none() is None:
                    # The parent row is missing, so create it before inserting messages.
await sess.execute(
insert(self._sessions).values({"session_id": self.session_id})
)

# Insert messages in bulk
await sess.execute(insert(self._messages), payload)

# Touch updated_at column
await sess.execute(
update(self._sessions)
.where(self._sessions.c.session_id == self.session_id)
.values(updated_at=sql_text("CURRENT_TIMESTAMP"))
)

async def pop_item(self) -> TResponseInputItem | None:
await self._ensure_tables()
async with self._session_factory() as sess:
async with sess.begin():
                # Portable across dialects: locate the newest row's id first, fetch its
                # payload, then delete it by primary key.
                subq = (
                    select(self._messages.c.id)
                    .where(self._messages.c.session_id == self.session_id)
                    .order_by(self._messages.c.created_at.desc(), self._messages.c.id.desc())
                    .limit(1)
                )
res = await sess.execute(subq)
row_id = res.scalar_one_or_none()
if row_id is None:
return None
# Fetch data before deleting
res_data = await sess.execute(
select(self._messages.c.message_data).where(self._messages.c.id == row_id)
)
row = res_data.scalar_one_or_none()
await sess.execute(delete(self._messages).where(self._messages.c.id == row_id))

if row is None:
return None
try:
return json.loads(row) # type: ignore[no-any-return]
except json.JSONDecodeError:
return None

    async def clear_session(self) -> None:
        """Delete all stored messages for this session, then remove the session row itself."""
await self._ensure_tables()
async with self._session_factory() as sess:
async with sess.begin():
await sess.execute(
delete(self._messages).where(self._messages.c.session_id == self.session_id)
)
await sess.execute(
delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
)
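
As a quick, self-contained exercise of the Session protocol methods implemented above, the sketch below runs entirely against in-memory SQLite via the aiosqlite driver added to the dev group. It assumes the aiosqlite dialect's default pooling keeps the in-memory database on a single connection, and the plain dicts stand in for ``TResponseInputItem`` payloads.

import asyncio

from agents.extensions.memory import SQLAlchemySession


async def main() -> None:
    # In-memory SQLite keeps the demo self-contained; any async driver URL works here.
    session = SQLAlchemySession.from_url(
        "demo-session",
        url="sqlite+aiosqlite:///:memory:",
    )

    # Each item is serialized to JSON and stored in the messages table.
    await session.add_items(
        [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi! How can I help?"},
        ]
    )

    print(await session.get_items())         # full history, oldest first
    print(await session.get_items(limit=1))  # newest item only, still in chronological order
    print(await session.pop_item())          # removes and returns the newest item

    await session.clear_session()            # deletes the messages and the session row
    print(await session.get_items())         # []


asyncio.run(main())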