-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Add SQLAlchemy session backend for conversation history management #1357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 8 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
e534fa0
Add SQLAlchemy session backend for conversation history management
habema f669093
temporary fixes to pass typecheck and old_versions CI tests
habema 78614fd
linter fixes
habema ea1177c
Merge remote-tracking branch 'origin/main' into feat/sqlalchemy-sessions
habema 48b675c
fix mypy errors
habema 94b6ce9
more mypy fixes
habema 18d38b2
even more mypy and lint fixes
habema 83dcdd8
Merge branch 'main' into feat/sqlalchemy-sessions
habema a1bf963
Update SQLAlchemySession to default create_tables to False
habema af6c9e1
add docs
habema c4edf74
add example
habema 38f270f
fix imports in example
habema 55167e0
code review: move index to table initialization
habema f6579e8
code review: explanatory comment for order/reverse behaviour
habema 3f9f31a
code review: move serialization to separate method
habema 43f7a40
code review: move deserialization to separate method
habema 43f1fd6
remove docs to move to another PR
habema 26d44ed
fix mypy
habema d8853a9
code review: make serialization/deserialization async
habema File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
|
|
||
| """Session memory backends living in the extensions namespace. | ||
|
|
||
| This package contains optional, production-grade session implementations that | ||
| introduce extra third-party dependencies (database drivers, ORMs, etc.). They | ||
| conform to the :class:`agents.memory.session.Session` protocol so they can be | ||
| used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`. | ||
| """ | ||
| from __future__ import annotations | ||
|
|
||
| from .sqlalchemy_session import SQLAlchemySession # noqa: F401 | ||
|
|
||
| __all__: list[str] = [ | ||
| "SQLAlchemySession", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,288 @@ | ||
| """SQLAlchemy-powered Session backend. | ||
|
|
||
| Usage:: | ||
|
|
||
| from agents.extensions.memory import SQLAlchemySession | ||
|
|
||
| # Create from SQLAlchemy URL (uses asyncpg driver under the hood for Postgres) | ||
| session = SQLAlchemySession.from_url( | ||
| session_id="user-123", | ||
| url="postgresql+asyncpg://app:[email protected]/agents", | ||
| ) | ||
|
|
||
| # Or pass an existing AsyncEngine that your application already manages | ||
| session = SQLAlchemySession( | ||
| session_id="user-123", | ||
| engine=my_async_engine, | ||
| ) | ||
|
|
||
| await Runner.run(agent, "Hello", session=session) | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import json | ||
| from typing import Any | ||
|
|
||
| from sqlalchemy import ( | ||
| TIMESTAMP, | ||
| Column, | ||
| ForeignKey, | ||
| Integer, | ||
| MetaData, | ||
| String, | ||
| Table, | ||
| Text, | ||
| delete, | ||
| insert, | ||
| select, | ||
| text as sql_text, | ||
| update, | ||
| ) | ||
| from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine | ||
|
|
||
| from ...items import TResponseInputItem | ||
| from ...memory.session import SessionABC | ||
|
|
||
|
|
||
| class SQLAlchemySession(SessionABC): | ||
| """SQLAlchemy implementation of :pyclass:`agents.memory.session.Session`.""" | ||
|
|
||
| _metadata: MetaData | ||
| _sessions: Table | ||
| _messages: Table | ||
|
|
||
| def __init__( | ||
| self, | ||
| session_id: str, | ||
| *, | ||
| engine: AsyncEngine, | ||
| create_tables: bool = True, | ||
| sessions_table: str = "agent_sessions", | ||
| messages_table: str = "agent_messages", | ||
| ): # noqa: D401 – short description on the class-level docstring | ||
| """Create a new session. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| session_id | ||
| Unique identifier for the conversation. | ||
| engine | ||
| A pre-configured SQLAlchemy *async* engine. The engine **must** be | ||
| created with an async driver (``postgresql+asyncpg://``, | ||
| ``mysql+aiomysql://`` or ``sqlite+aiosqlite://``). | ||
| create_tables | ||
| Whether to automatically create the required tables & indexes. Set | ||
| this to *False* if your migrations take care of schema management. | ||
| sessions_table, messages_table | ||
| Override default table names if needed. | ||
| """ | ||
| self.session_id = session_id | ||
| self._engine = engine | ||
| self._lock = asyncio.Lock() | ||
|
|
||
| self._metadata = MetaData() | ||
| self._sessions = Table( | ||
| sessions_table, | ||
| self._metadata, | ||
| Column("session_id", String, primary_key=True), | ||
| Column( | ||
| "created_at", | ||
| TIMESTAMP(timezone=False), | ||
| server_default=sql_text("CURRENT_TIMESTAMP"), | ||
| nullable=False, | ||
| ), | ||
| Column( | ||
| "updated_at", | ||
| TIMESTAMP(timezone=False), | ||
| server_default=sql_text("CURRENT_TIMESTAMP"), | ||
| onupdate=sql_text("CURRENT_TIMESTAMP"), | ||
| nullable=False, | ||
| ), | ||
| ) | ||
|
|
||
| self._messages = Table( | ||
| messages_table, | ||
| self._metadata, | ||
| Column("id", Integer, primary_key=True, autoincrement=True), | ||
| Column( | ||
| "session_id", | ||
| String, | ||
| ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"), | ||
| nullable=False, | ||
| ), | ||
| Column("message_data", Text, nullable=False), | ||
| Column( | ||
| "created_at", | ||
| TIMESTAMP(timezone=False), | ||
| server_default=sql_text("CURRENT_TIMESTAMP"), | ||
| nullable=False, | ||
| ), | ||
| sqlite_autoincrement=True, | ||
| ) | ||
|
|
||
| # Index for efficient retrieval of messages per session ordered by time | ||
| from sqlalchemy import Index | ||
|
|
||
| Index( | ||
| f"idx_{messages_table}_session_time", | ||
| self._messages.c.session_id, | ||
| self._messages.c.created_at, | ||
| ) | ||
|
|
||
| # Async session factory | ||
| self._session_factory = async_sessionmaker( | ||
| self._engine, expire_on_commit=False | ||
| ) | ||
|
|
||
| self._create_tables = create_tables | ||
|
|
||
| # --------------------------------------------------------------------- | ||
| # Convenience constructors | ||
| # --------------------------------------------------------------------- | ||
| @classmethod | ||
| def from_url( | ||
| cls, | ||
| session_id: str, | ||
| *, | ||
| url: str, | ||
| engine_kwargs: dict[str, Any] | None = None, | ||
| **kwargs: Any, | ||
| ) -> SQLAlchemySession: | ||
| """Create a session from a database URL string. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| session_id | ||
| Conversation ID. | ||
| url | ||
| Any SQLAlchemy async URL – e.g. ``"postgresql+asyncpg://user:pass@host/db"``. | ||
| engine_kwargs | ||
| Additional kwargs forwarded to :pyfunc:`sqlalchemy.ext.asyncio.create_async_engine`. | ||
| kwargs | ||
| Forwarded to the main constructor (``create_tables``, custom table names, …). | ||
| """ | ||
| engine_kwargs = engine_kwargs or {} | ||
| engine = create_async_engine(url, **engine_kwargs) | ||
| return cls(session_id, engine=engine, **kwargs) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Session protocol implementation | ||
| # ------------------------------------------------------------------ | ||
| async def _ensure_tables(self) -> None: | ||
habema marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Ensure tables are created before any database operations.""" | ||
| if self._create_tables: | ||
| async with self._engine.begin() as conn: | ||
| await conn.run_sync(self._metadata.create_all) | ||
| self._create_tables = False # Only create once | ||
|
|
||
| async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: | ||
| await self._ensure_tables() | ||
| async with self._session_factory() as sess: | ||
| if limit is None: | ||
| stmt = ( | ||
| select(self._messages.c.message_data) | ||
| .where(self._messages.c.session_id == self.session_id) | ||
| .order_by(self._messages.c.created_at.asc()) | ||
| ) | ||
| else: | ||
| stmt = ( | ||
| select(self._messages.c.message_data) | ||
| .where(self._messages.c.session_id == self.session_id) | ||
| .order_by(self._messages.c.created_at.desc()) | ||
habema marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| .limit(limit) | ||
| ) | ||
|
|
||
| result = await sess.execute(stmt) | ||
| rows: list[str] = [row[0] for row in result.all()] | ||
|
|
||
| if limit is not None: | ||
| rows.reverse() # chronological order | ||
|
|
||
| items: list[TResponseInputItem] = [] | ||
| for raw in rows: | ||
| try: | ||
| items.append(json.loads(raw)) | ||
| except json.JSONDecodeError: | ||
| # Skip corrupted rows | ||
| continue | ||
| return items | ||
|
|
||
| async def add_items(self, items: list[TResponseInputItem]) -> None: | ||
| if not items: | ||
| return | ||
|
|
||
| await self._ensure_tables() | ||
| payload = [ | ||
| { | ||
| "session_id": self.session_id, | ||
| "message_data": json.dumps(item, separators=(",", ":")), | ||
habema marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| for item in items | ||
| ] | ||
|
|
||
| async with self._session_factory() as sess: | ||
| async with sess.begin(): | ||
| # Ensure the parent session row exists - use merge for cross-DB compatibility | ||
| # Check if session exists | ||
| existing = await sess.execute( | ||
| select(self._sessions.c.session_id).where( | ||
| self._sessions.c.session_id == self.session_id | ||
| ) | ||
| ) | ||
| if not existing.scalar_one_or_none(): | ||
| # Session doesn't exist, create it | ||
| await sess.execute( | ||
| insert(self._sessions).values({"session_id": self.session_id}) | ||
| ) | ||
|
|
||
| # Insert messages in bulk | ||
| await sess.execute(insert(self._messages), payload) | ||
|
|
||
| # Touch updated_at column | ||
| await sess.execute( | ||
| update(self._sessions) | ||
| .where(self._sessions.c.session_id == self.session_id) | ||
| .values(updated_at=sql_text("CURRENT_TIMESTAMP")) | ||
| ) | ||
|
|
||
| async def pop_item(self) -> TResponseInputItem | None: | ||
| await self._ensure_tables() | ||
| async with self._session_factory() as sess: | ||
| async with sess.begin(): | ||
| # Fallback for all dialects - get ID first, then delete | ||
| subq = ( | ||
| select(self._messages.c.id) | ||
| .where(self._messages.c.session_id == self.session_id) | ||
| .order_by(self._messages.c.created_at.desc()) | ||
| .limit(1) | ||
| ) | ||
| res = await sess.execute(subq) | ||
| row_id = res.scalar_one_or_none() | ||
| if row_id is None: | ||
| return None | ||
| # Fetch data before deleting | ||
| res_data = await sess.execute( | ||
| select(self._messages.c.message_data).where(self._messages.c.id == row_id) | ||
| ) | ||
| row = res_data.scalar_one_or_none() | ||
| await sess.execute(delete(self._messages).where(self._messages.c.id == row_id)) | ||
|
|
||
| if row is None: | ||
| return None | ||
| try: | ||
| return json.loads(row) # type: ignore[no-any-return] | ||
habema marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| except json.JSONDecodeError: | ||
| return None | ||
|
|
||
| async def clear_session(self) -> None: # noqa: D401 – imperative mood is fine | ||
| await self._ensure_tables() | ||
| async with self._session_factory() as sess: | ||
| async with sess.begin(): | ||
| await sess.execute( | ||
| delete(self._messages).where(self._messages.c.session_id == self.session_id) | ||
| ) | ||
| await sess.execute( | ||
| delete(self._sessions).where(self._sessions.c.session_id == self.session_id) | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.