diff --git a/examples/memory/openai_session_example.py b/examples/memory/openai_session_example.py
new file mode 100644
index 000000000..9254195b3
--- /dev/null
+++ b/examples/memory/openai_session_example.py
@@ -0,0 +1,78 @@
+"""
+Example demonstrating session memory functionality.
+
+This example shows how to use session memory to maintain conversation history
+across multiple agent runs without manually handling .to_input_list().
+"""
+
+import asyncio
+
+from agents import Agent, OpenAIConversationsSession, Runner
+
+
+async def main():
+    # Create an agent
+    agent = Agent(
+        name="Assistant",
+        instructions="Reply very concisely.",
+    )
+
+    # Create a session instance that will persist across runs
+    session = OpenAIConversationsSession()
+
+    print("=== Session Example ===")
+    print("The agent will remember previous messages automatically.\n")
+
+    # First turn
+    print("First turn:")
+    print("User: What city is the Golden Gate Bridge in?")
+    result = await Runner.run(
+        agent,
+        "What city is the Golden Gate Bridge in?",
+        session=session,
+    )
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Second turn - the agent will remember the previous conversation
+    print("Second turn:")
+    print("User: What state is it in?")
+    result = await Runner.run(agent, "What state is it in?", session=session)
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Third turn - continuing the conversation
+    print("Third turn:")
+    print("User: What's the population of that state?")
+    result = await Runner.run(
+        agent,
+        "What's the population of that state?",
+        session=session,
+    )
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    print("=== Conversation Complete ===")
+    print("Notice how the agent remembered the context from previous turns!")
+    print("Sessions automatically handle conversation history.")
+
+    # Demonstrate the limit parameter - get only the latest 2 items
+    print("\n=== Latest Items Demo ===")
+    latest_items = await session.get_items(limit=2)
+    # print(latest_items)
+    print("Latest 2 items:")
+    for i, msg in enumerate(latest_items, 1):
+        role = msg.get("role", "unknown")
+        content = msg.get("content", "")
+        print(f"  {i}. {role}: {content}")
+
+    print(f"\nFetched {len(latest_items)} items out of the full conversation history.")
+
+    # Get all items to show the difference
+    all_items = await session.get_items()
+    # print(all_items)
+    print(f"Total items in session: {len(all_items)}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/basic/sqlalchemy_session_example.py b/examples/memory/sqlalchemy_session_example.py
similarity index 50%
rename from examples/basic/sqlalchemy_session_example.py
rename to examples/memory/sqlalchemy_session_example.py
index 2aec270f5..84a6c754f 100644
--- a/examples/basic/sqlalchemy_session_example.py
+++ b/examples/memory/sqlalchemy_session_example.py
@@ -20,28 +20,56 @@ async def main():
         create_tables=True,
     )

-    print("=== SQLAlchemySession Example ===")
+    print("=== Session Example ===")
     print("The agent will remember previous messages automatically.\n")

     # First turn
+    print("First turn:")
     print("User: What city is the Golden Gate Bridge in?")
     result = await Runner.run(
         agent,
         "What city is the Golden Gate Bridge in?",
         session=session,
     )
-    print(f"Assistant: {result.final_output}\n")
+    print(f"Assistant: {result.final_output}")
+    print()

     # Second turn - the agent will remember the previous conversation
+    print("Second turn:")
     print("User: What state is it in?")
+    result = await Runner.run(agent, "What state is it in?", session=session)
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Third turn - continuing the conversation
+    print("Third turn:")
+    print("User: What's the population of that state?")
     result = await Runner.run(
         agent,
-        "What state is it in?",
+        "What's the population of that state?",
         session=session,
     )
-    print(f"Assistant: {result.final_output}\n")
+    print(f"Assistant: {result.final_output}")
+    print()

     print("=== Conversation Complete ===")
+    print("Notice how the agent remembered the context from previous turns!")
+    print("Sessions automatically handle conversation history.")
+
+    # Demonstrate the limit parameter - get only the latest 2 items
+    print("\n=== Latest Items Demo ===")
+    latest_items = await session.get_items(limit=2)
+    print("Latest 2 items:")
+    for i, msg in enumerate(latest_items, 1):
+        role = msg.get("role", "unknown")
+        content = msg.get("content", "")
+        print(f"  {i}. {role}: {content}")
+
+    print(f"\nFetched {len(latest_items)} items out of the full conversation history.")
+
+    # Get all items to show the difference
+    all_items = await session.get_items()
+    print(f"Total items in session: {len(all_items)}")


 if __name__ == "__main__":
diff --git a/examples/basic/session_example.py b/examples/memory/sqlite_session_example.py
similarity index 100%
rename from examples/basic/session_example.py
rename to examples/memory/sqlite_session_example.py
diff --git a/examples/reasoning_content/main.py b/examples/reasoning_content/main.py
index 9da2a5690..e83c0d4d4 100644
--- a/examples/reasoning_content/main.py
+++ b/examples/reasoning_content/main.py
@@ -47,6 +47,7 @@ async def stream_with_reasoning_content():
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         if event.type == "response.reasoning_summary_text.delta":
@@ -83,6 +84,7 @@ async def get_response_with_reasoning_content():
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     )
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index 02830bb29..3a8260f29 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -46,7 +46,7 @@
     TResponseInputItem,
 )
 from .lifecycle import AgentHooks, RunHooks
-from .memory import Session, SQLiteSession
+from .memory import OpenAIConversationsSession, Session, SessionABC, SQLiteSession
 from .model_settings import ModelSettings
 from .models.interface import Model, ModelProvider, ModelTracing
 from .models.multi_provider import MultiProvider
@@ -221,7 +221,9 @@ def enable_verbose_stdout_logging():
     "RunHooks",
     "AgentHooks",
     "Session",
+    "SessionABC",
     "SQLiteSession",
+    "OpenAIConversationsSession",
     "RunContextWrapper",
     "TContext",
     "RunErrorDetails",
diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index fca172fff..b20c673af 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -82,7 +82,8 @@ async def get_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        previous_response_id: str | None,
+        previous_response_id: str | None = None,  # unused
+        conversation_id: str | None = None,  # unused
         prompt: Any | None = None,
     ) -> ModelResponse:
         with generation_span(
@@ -171,7 +172,8 @@ async def stream_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        previous_response_id: str | None,
+        previous_response_id: str | None = None,  # unused
+        conversation_id: str | None = None,  # unused
         prompt: Any | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         with generation_span(
diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py
index 059ca57ab..eeb2ace5d 100644
--- a/src/agents/memory/__init__.py
+++ b/src/agents/memory/__init__.py
@@ -1,3 +1,10 @@
-from .session import Session, SQLiteSession
+from .openai_conversations_session import OpenAIConversationsSession
+from .session import Session, SessionABC
+from .sqlite_session import SQLiteSession

-__all__ = ["Session", "SQLiteSession"]
+__all__ = [
+    "Session",
+    "SessionABC",
+    "SQLiteSession",
+    "OpenAIConversationsSession",
+]
diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py
new file mode 100644
index 000000000..9bf5ccdac
--- /dev/null
+++ b/src/agents/memory/openai_conversations_session.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from openai import AsyncOpenAI
+
+from agents.models._openai_shared import get_default_openai_client
+
+from ..items import TResponseInputItem
+from .session import SessionABC
+
+
+async def start_openai_conversations_session(openai_client: AsyncOpenAI | None = None) -> str:
+    _maybe_openai_client = openai_client
+    if openai_client is None:
+        _maybe_openai_client = get_default_openai_client() or AsyncOpenAI()
+    # this will never be None here
+    _openai_client: AsyncOpenAI = _maybe_openai_client  # type: ignore [assignment]
+
+    response = await _openai_client.conversations.create(items=[])
+    return response.id
+
+
+_EMPTY_SESSION_ID = ""
+
+
+class OpenAIConversationsSession(SessionABC):
+    def __init__(
+        self,
+        *,
+        conversation_id: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+    ):
+        self._session_id: str | None = conversation_id
+        _openai_client = openai_client
+        if _openai_client is None:
+            _openai_client = get_default_openai_client() or AsyncOpenAI()
+        # this will never be None here
+        self._openai_client: AsyncOpenAI = _openai_client
+
+    async def _get_session_id(self) -> str:
+        if self._session_id is None:
+            self._session_id = await start_openai_conversations_session(self._openai_client)
+        return self._session_id
+
+    async def _clear_session_id(self) -> None:
+        self._session_id = None
+
+    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+        session_id = await self._get_session_id()
+        all_items = []
+        if limit is None:
+            async for item in self._openai_client.conversations.items.list(
+                conversation_id=session_id,
+                order="asc",
+            ):
+                # calling model_dump() to make this serializable
+                all_items.append(item.model_dump())
+        else:
+            async for item in self._openai_client.conversations.items.list(
+                conversation_id=session_id,
+                limit=limit,
+                order="desc",
+            ):
+                # calling model_dump() to make this serializable
+                all_items.append(item.model_dump())
+                if limit is not None and len(all_items) >= limit:
+                    break
+            all_items.reverse()
+
+        return all_items  # type: ignore
+
+    async def add_items(self, items: list[TResponseInputItem]) -> None:
+        session_id = await self._get_session_id()
+        await self._openai_client.conversations.items.create(
+            conversation_id=session_id,
+            items=items,
+        )
+
+    async def pop_item(self) -> TResponseInputItem | None:
+        session_id = await self._get_session_id()
+        items = await self.get_items(limit=1)
+        if not items:
+            return None
+        item_id: str = str(items[0]["id"])  # type: ignore [typeddict-item]
+        await self._openai_client.conversations.items.delete(
+            conversation_id=session_id, item_id=item_id
+        )
+        return items[0]
+
+    async def clear_session(self) -> None:
+        session_id = await self._get_session_id()
+        await self._openai_client.conversations.delete(
+            conversation_id=session_id,
+        )
+        await self._clear_session_id()
diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py
index 8db0971eb..9c85af6dd 100644
--- a/src/agents/memory/session.py
+++ b/src/agents/memory/session.py
@@ -1,11 +1,6 @@
 from __future__ import annotations

-import asyncio
-import json
-import sqlite3
-import threading
 from abc import ABC, abstractmethod
-from pathlib import Path
 from typing import TYPE_CHECKING, Protocol, runtime_checkable

 if TYPE_CHECKING:
@@ -102,268 +97,3 @@ async def pop_item(self) -> TResponseInputItem | None:
     async def clear_session(self) -> None:
         """Clear all items for this session."""
         ...
-
-
-class SQLiteSession(SessionABC):
-    """SQLite-based implementation of session storage.
-
-    This implementation stores conversation history in a SQLite database.
-    By default, uses an in-memory database that is lost when the process ends.
-    For persistent storage, provide a file path.
-    """
-
-    def __init__(
-        self,
-        session_id: str,
-        db_path: str | Path = ":memory:",
-        sessions_table: str = "agent_sessions",
-        messages_table: str = "agent_messages",
-    ):
-        """Initialize the SQLite session.
-
-        Args:
-            session_id: Unique identifier for the conversation session
-            db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database)
-            sessions_table: Name of the table to store session metadata. Defaults to
-                'agent_sessions'
-            messages_table: Name of the table to store message data. Defaults to 'agent_messages'
-        """
-        self.session_id = session_id
-        self.db_path = db_path
-        self.sessions_table = sessions_table
-        self.messages_table = messages_table
-        self._local = threading.local()
-        self._lock = threading.Lock()
-
-        # For in-memory databases, we need a shared connection to avoid thread isolation
-        # For file databases, we use thread-local connections for better concurrency
-        self._is_memory_db = str(db_path) == ":memory:"
-        if self._is_memory_db:
-            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
-            self._shared_connection.execute("PRAGMA journal_mode=WAL")
-            self._init_db_for_connection(self._shared_connection)
-        else:
-            # For file databases, initialize the schema once since it persists
-            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
-            init_conn.execute("PRAGMA journal_mode=WAL")
-            self._init_db_for_connection(init_conn)
-            init_conn.close()
-
-    def _get_connection(self) -> sqlite3.Connection:
-        """Get a database connection."""
-        if self._is_memory_db:
-            # Use shared connection for in-memory database to avoid thread isolation
-            return self._shared_connection
-        else:
-            # Use thread-local connections for file databases
-            if not hasattr(self._local, "connection"):
-                self._local.connection = sqlite3.connect(
-                    str(self.db_path),
-                    check_same_thread=False,
-                )
-                self._local.connection.execute("PRAGMA journal_mode=WAL")
-            assert isinstance(self._local.connection, sqlite3.Connection), (
-                f"Expected sqlite3.Connection, got {type(self._local.connection)}"
-            )
-            return self._local.connection
-
-    def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
-        """Initialize the database schema for a specific connection."""
-        conn.execute(
-            f"""
-            CREATE TABLE IF NOT EXISTS {self.sessions_table} (
-                session_id TEXT PRIMARY KEY,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-            )
-        """
-        )
-
-        conn.execute(
-            f"""
-            CREATE TABLE IF NOT EXISTS {self.messages_table} (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                session_id TEXT NOT NULL,
-                message_data TEXT NOT NULL,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
-                    ON DELETE CASCADE
-            )
-        """
-        )
-
-        conn.execute(
-            f"""
-            CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
-            ON {self.messages_table} (session_id, created_at)
-        """
-        )
-
-        conn.commit()
-
-    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
-        """Retrieve the conversation history for this session.
-
-        Args:
-            limit: Maximum number of items to retrieve. If None, retrieves all items.
-                When specified, returns the latest N items in chronological order.
-
-        Returns:
-            List of input items representing the conversation history
-        """
-
-        def _get_items_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
-                if limit is None:
-                    # Fetch all items in chronological order
-                    cursor = conn.execute(
-                        f"""
-                        SELECT message_data FROM {self.messages_table}
-                        WHERE session_id = ?
-                        ORDER BY created_at ASC
-                    """,
-                        (self.session_id,),
-                    )
-                else:
-                    # Fetch the latest N items in chronological order
-                    cursor = conn.execute(
-                        f"""
-                        SELECT message_data FROM {self.messages_table}
-                        WHERE session_id = ?
-                        ORDER BY created_at DESC
-                        LIMIT ?
-                    """,
-                        (self.session_id, limit),
-                    )
-
-                rows = cursor.fetchall()
-
-                # Reverse to get chronological order when using DESC
-                if limit is not None:
-                    rows = list(reversed(rows))
-
-                items = []
-                for (message_data,) in rows:
-                    try:
-                        item = json.loads(message_data)
-                        items.append(item)
-                    except json.JSONDecodeError:
-                        # Skip invalid JSON entries
-                        continue
-
-                return items
-
-        return await asyncio.to_thread(_get_items_sync)
-
-    async def add_items(self, items: list[TResponseInputItem]) -> None:
-        """Add new items to the conversation history.
-
-        Args:
-            items: List of input items to add to the history
-        """
-        if not items:
-            return
-
-        def _add_items_sync():
-            conn = self._get_connection()
-
-            with self._lock if self._is_memory_db else threading.Lock():
-                # Ensure session exists
-                conn.execute(
-                    f"""
-                    INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
-                """,
-                    (self.session_id,),
-                )
-
-                # Add items
-                message_data = [(self.session_id, json.dumps(item)) for item in items]
-                conn.executemany(
-                    f"""
-                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
-                """,
-                    message_data,
-                )
-
-                # Update session timestamp
-                conn.execute(
-                    f"""
-                    UPDATE {self.sessions_table}
-                    SET updated_at = CURRENT_TIMESTAMP
-                    WHERE session_id = ?
-                """,
-                    (self.session_id,),
-                )
-
-                conn.commit()
-
-        await asyncio.to_thread(_add_items_sync)
-
-    async def pop_item(self) -> TResponseInputItem | None:
-        """Remove and return the most recent item from the session.
-
-        Returns:
-            The most recent item if it exists, None if the session is empty
-        """
-
-        def _pop_item_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
-                # Use DELETE with RETURNING to atomically delete and return the most recent item
-                cursor = conn.execute(
-                    f"""
-                    DELETE FROM {self.messages_table}
-                    WHERE id = (
-                        SELECT id FROM {self.messages_table}
-                        WHERE session_id = ?
-                        ORDER BY created_at DESC
-                        LIMIT 1
-                    )
-                    RETURNING message_data
-                """,
-                    (self.session_id,),
-                )
-
-                result = cursor.fetchone()
-                conn.commit()
-
-                if result:
-                    message_data = result[0]
-                    try:
-                        item = json.loads(message_data)
-                        return item
-                    except json.JSONDecodeError:
-                        # Return None for corrupted JSON entries (already deleted)
-                        return None
-
-                return None
-
-        return await asyncio.to_thread(_pop_item_sync)
-
-    async def clear_session(self) -> None:
-        """Clear all items for this session."""
-
-        def _clear_session_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
-                conn.execute(
-                    f"DELETE FROM {self.messages_table} WHERE session_id = ?",
-                    (self.session_id,),
-                )
-                conn.execute(
-                    f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
-                    (self.session_id,),
-                )
-                conn.commit()
-
-        await asyncio.to_thread(_clear_session_sync)
-
-    def close(self) -> None:
-        """Close the database connection."""
-        if self._is_memory_db:
-            if hasattr(self, "_shared_connection"):
-                self._shared_connection.close()
-        else:
-            if hasattr(self._local, "connection"):
-                self._local.connection.close()
diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py
new file mode 100644
index 000000000..2c2386ec7
--- /dev/null
+++ b/src/agents/memory/sqlite_session.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import sqlite3
+import threading
+from pathlib import Path
+
+from ..items import TResponseInputItem
+from .session import SessionABC
+
+
+class SQLiteSession(SessionABC):
+    """SQLite-based implementation of session storage.
+
+    This implementation stores conversation history in a SQLite database.
+    By default, uses an in-memory database that is lost when the process ends.
+    For persistent storage, provide a file path.
+    """
+
+    def __init__(
+        self,
+        session_id: str,
+        db_path: str | Path = ":memory:",
+        sessions_table: str = "agent_sessions",
+        messages_table: str = "agent_messages",
+    ):
+        """Initialize the SQLite session.
+
+        Args:
+            session_id: Unique identifier for the conversation session
+            db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database)
+            sessions_table: Name of the table to store session metadata. Defaults to
+                'agent_sessions'
+            messages_table: Name of the table to store message data. Defaults to 'agent_messages'
+        """
+        self.session_id = session_id
+        self.db_path = db_path
+        self.sessions_table = sessions_table
+        self.messages_table = messages_table
+        self._local = threading.local()
+        self._lock = threading.Lock()
+
+        # For in-memory databases, we need a shared connection to avoid thread isolation
+        # For file databases, we use thread-local connections for better concurrency
+        self._is_memory_db = str(db_path) == ":memory:"
+        if self._is_memory_db:
+            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
+            self._shared_connection.execute("PRAGMA journal_mode=WAL")
+            self._init_db_for_connection(self._shared_connection)
+        else:
+            # For file databases, initialize the schema once since it persists
+            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
+            init_conn.execute("PRAGMA journal_mode=WAL")
+            self._init_db_for_connection(init_conn)
+            init_conn.close()
+
+    def _get_connection(self) -> sqlite3.Connection:
+        """Get a database connection."""
+        if self._is_memory_db:
+            # Use shared connection for in-memory database to avoid thread isolation
+            return self._shared_connection
+        else:
+            # Use thread-local connections for file databases
+            if not hasattr(self._local, "connection"):
+                self._local.connection = sqlite3.connect(
+                    str(self.db_path),
+                    check_same_thread=False,
+                )
+                self._local.connection.execute("PRAGMA journal_mode=WAL")
+            assert isinstance(self._local.connection, sqlite3.Connection), (
+                f"Expected sqlite3.Connection, got {type(self._local.connection)}"
+            )
+            return self._local.connection
+
+    def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
+        """Initialize the database schema for a specific connection."""
+        conn.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {self.sessions_table} (
+                session_id TEXT PRIMARY KEY,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            )
+        """
+        )
+
+        conn.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {self.messages_table} (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                session_id TEXT NOT NULL,
+                message_data TEXT NOT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
+                    ON DELETE CASCADE
+            )
+        """
+        )
+
+        conn.execute(
+            f"""
+            CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
+            ON {self.messages_table} (session_id, created_at)
+        """
+        )
+
+        conn.commit()
+
+    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+        """Retrieve the conversation history for this session.
+
+        Args:
+            limit: Maximum number of items to retrieve. If None, retrieves all items.
+                When specified, returns the latest N items in chronological order.
+
+        Returns:
+            List of input items representing the conversation history
+        """
+
+        def _get_items_sync():
+            conn = self._get_connection()
+            with self._lock if self._is_memory_db else threading.Lock():
+                if limit is None:
+                    # Fetch all items in chronological order
+                    cursor = conn.execute(
+                        f"""
+                        SELECT message_data FROM {self.messages_table}
+                        WHERE session_id = ?
+                        ORDER BY created_at ASC
+                    """,
+                        (self.session_id,),
+                    )
+                else:
+                    # Fetch the latest N items in chronological order
+                    cursor = conn.execute(
+                        f"""
+                        SELECT message_data FROM {self.messages_table}
+                        WHERE session_id = ?
+                        ORDER BY created_at DESC
+                        LIMIT ?
+                    """,
+                        (self.session_id, limit),
+                    )
+
+                rows = cursor.fetchall()
+
+                # Reverse to get chronological order when using DESC
+                if limit is not None:
+                    rows = list(reversed(rows))
+
+                items = []
+                for (message_data,) in rows:
+                    try:
+                        item = json.loads(message_data)
+                        items.append(item)
+                    except json.JSONDecodeError:
+                        # Skip invalid JSON entries
+                        continue
+
+                return items
+
+        return await asyncio.to_thread(_get_items_sync)
+
+    async def add_items(self, items: list[TResponseInputItem]) -> None:
+        """Add new items to the conversation history.
+
+        Args:
+            items: List of input items to add to the history
+        """
+        if not items:
+            return
+
+        def _add_items_sync():
+            conn = self._get_connection()
+
+            with self._lock if self._is_memory_db else threading.Lock():
+                # Ensure session exists
+                conn.execute(
+                    f"""
+                    INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
+                """,
+                    (self.session_id,),
+                )
+
+                # Add items
+                message_data = [(self.session_id, json.dumps(item)) for item in items]
+                conn.executemany(
+                    f"""
+                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
+                """,
+                    message_data,
+                )
+
+                # Update session timestamp
+                conn.execute(
+                    f"""
+                    UPDATE {self.sessions_table}
+                    SET updated_at = CURRENT_TIMESTAMP
+                    WHERE session_id = ?
+                """,
+                    (self.session_id,),
+                )
+
+                conn.commit()
+
+        await asyncio.to_thread(_add_items_sync)
+
+    async def pop_item(self) -> TResponseInputItem | None:
+        """Remove and return the most recent item from the session.
+
+        Returns:
+            The most recent item if it exists, None if the session is empty
+        """
+
+        def _pop_item_sync():
+            conn = self._get_connection()
+            with self._lock if self._is_memory_db else threading.Lock():
+                # Use DELETE with RETURNING to atomically delete and return the most recent item
+                cursor = conn.execute(
+                    f"""
+                    DELETE FROM {self.messages_table}
+                    WHERE id = (
+                        SELECT id FROM {self.messages_table}
+                        WHERE session_id = ?
+                        ORDER BY created_at DESC
+                        LIMIT 1
+                    )
+                    RETURNING message_data
+                """,
+                    (self.session_id,),
+                )
+
+                result = cursor.fetchone()
+                conn.commit()
+
+                if result:
+                    message_data = result[0]
+                    try:
+                        item = json.loads(message_data)
+                        return item
+                    except json.JSONDecodeError:
+                        # Return None for corrupted JSON entries (already deleted)
+                        return None
+
+                return None
+
+        return await asyncio.to_thread(_pop_item_sync)
+
+    async def clear_session(self) -> None:
+        """Clear all items for this session."""
+
+        def _clear_session_sync():
+            conn = self._get_connection()
+            with self._lock if self._is_memory_db else threading.Lock():
+                conn.execute(
+                    f"DELETE FROM {self.messages_table} WHERE session_id = ?",
+                    (self.session_id,),
+                )
+                conn.execute(
+                    f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
+                    (self.session_id,),
+                )
+                conn.commit()
+
+        await asyncio.to_thread(_clear_session_sync)
+
+    def close(self) -> None:
+        """Close the database connection."""
+        if self._is_memory_db:
+            if hasattr(self, "_shared_connection"):
+                self._shared_connection.close()
+        else:
+            if hasattr(self._local, "connection"):
+                self._local.connection.close()
diff --git a/src/agents/models/interface.py b/src/agents/models/interface.py
index 5a185806c..f25934780 100644
--- a/src/agents/models/interface.py
+++ b/src/agents/models/interface.py
@@ -48,6 +48,7 @@ async def get_response(
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        conversation_id: str | None,
         prompt: ResponsePromptParam | None,
     ) -> ModelResponse:
         """Get a response from the model.
@@ -62,6 +63,7 @@ async def get_response(
             tracing: Tracing configuration.
             previous_response_id: the ID of the previous response. Generally not used by the
                 model, except for the OpenAI Responses API.
+            conversation_id: The ID of the stored conversation, if any.
             prompt: The prompt config to use for the model.

         Returns:
@@ -81,6 +83,7 @@ def stream_response(
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        conversation_id: str | None,
         prompt: ResponsePromptParam | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         """Stream a response from the model.
@@ -95,6 +98,7 @@ def stream_response(
             tracing: Tracing configuration.
             previous_response_id: the ID of the previous response. Generally not used by the
                 model, except for the OpenAI Responses API.
+            conversation_id: The ID of the stored conversation, if any.
             prompt: The prompt config to use for the model.

         Returns:
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index c6d1d7d22..f4d75d833 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -55,7 +55,8 @@ async def get_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        previous_response_id: str | None,
+        previous_response_id: str | None = None,  # unused
+        conversation_id: str | None = None,  # unused
         prompt: ResponsePromptParam | None = None,
     ) -> ModelResponse:
         with generation_span(
@@ -142,7 +143,8 @@ async def stream_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        previous_response_id: str | None,
+        previous_response_id: str | None = None,  # unused
+        conversation_id: str | None = None,  # unused
         prompt: ResponsePromptParam | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         """
diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index 6405bd586..85d8a0224 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -74,7 +74,8 @@ async def get_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        previous_response_id: str | None,
+        previous_response_id: str | None = None,
+        conversation_id: str | None = None,
         prompt: ResponsePromptParam | None = None,
     ) -> ModelResponse:
         with response_span(disabled=tracing.is_disabled()) as span_response:
@@ -86,7 +87,8 @@ async def get_response(
                     tools,
                     output_schema,
                     handoffs,
-                    previous_response_id,
+                    previous_response_id=previous_response_id,
+                    conversation_id=conversation_id,
                     stream=False,
                     prompt=prompt,
                 )
@@ -149,7 +151,8 @@ async def stream_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        previous_response_id: str | None,
+        previous_response_id: str | None = None,
+        conversation_id: str | None = None,
         prompt: ResponsePromptParam | None = None,
     ) -> AsyncIterator[ResponseStreamEvent]:
         """
@@ -164,7 +167,8 @@ async def stream_response(
                 tools,
                 output_schema,
                 handoffs,
-                previous_response_id,
+                previous_response_id=previous_response_id,
+                conversation_id=conversation_id,
                 stream=True,
                 prompt=prompt,
             )
@@ -202,6 +206,7 @@ async def _fetch_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         previous_response_id: str | None,
+        conversation_id: str | None,
         stream: Literal[True],
         prompt: ResponsePromptParam | None = None,
     ) -> AsyncStream[ResponseStreamEvent]: ...
@@ -216,6 +221,7 @@ async def _fetch_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         previous_response_id: str | None,
+        conversation_id: str | None,
         stream: Literal[False],
         prompt: ResponsePromptParam | None = None,
     ) -> Response: ...
@@ -228,7 +234,8 @@ async def _fetch_response(
         tools: list[Tool],
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
-        previous_response_id: str | None,
+        previous_response_id: str | None = None,
+        conversation_id: str | None = None,
         stream: Literal[True] | Literal[False] = False,
         prompt: ResponsePromptParam | None = None,
     ) -> Response | AsyncStream[ResponseStreamEvent]:
@@ -264,6 +271,7 @@ async def _fetch_response(
                 f"Tool choice: {tool_choice}\n"
                 f"Response format: {response_format}\n"
                 f"Previous response id: {previous_response_id}\n"
+                f"Conversation id: {conversation_id}\n"
             )

         extra_args = dict(model_settings.extra_args or {})
@@ -277,6 +285,7 @@ async def _fetch_response(

         return await self._client.responses.create(
             previous_response_id=self._non_null_or_not_given(previous_response_id),
+            conversation=self._non_null_or_not_given(conversation_id),
             instructions=self._non_null_or_not_given(system_instructions),
             model=self.model,
             input=list_input,
diff --git a/src/agents/run.py b/src/agents/run.py
index 727927b08..742917b87 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -208,6 +208,9 @@ class RunOptions(TypedDict, Generic[TContext]):
     previous_response_id: NotRequired[str | None]
     """The ID of the previous response, if any."""

+    conversation_id: NotRequired[str | None]
+    """The ID of the stored conversation, if any."""
+
     session: NotRequired[Session | None]
     """The session for the run."""

@@ -224,6 +227,7 @@ async def run(
         hooks: RunHooks[TContext] | None = None,
         run_config: RunConfig | None = None,
         previous_response_id: str | None = None,
+        conversation_id: str | None = None,
         session: Session | None = None,
     ) -> RunResult:
         """Run a workflow starting at the given agent. The agent will run in a loop until a final
@@ -248,6 +252,13 @@ async def run(
            run_config: Global settings for the entire agent run.
            previous_response_id: The ID of the previous response, if using OpenAI models via the
                Responses API, this allows you to skip passing in input from the previous turn.
+            conversation_id: The conversation ID (https://platform.openai.com/docs/guides/conversation-state?api-mode=responses).
+                If provided, the conversation will be used to read and write items.
+                Every agent will have access to the conversation history so far,
+                and its output items will be written to the conversation.
+                We recommend only using this if you are exclusively using OpenAI models;
+                other model providers don't write to the Conversation object,
+                so you'll end up having partial conversations stored.
        Returns:
            A run result containing all the inputs, guardrail results and the output of the last
            agent. Agents may perform handoffs, so we don't know the specific type of the output.
@@ -261,6 +272,7 @@ async def run(
            hooks=hooks,
            run_config=run_config,
            previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
            session=session,
        )

@@ -275,6 +287,7 @@ def run_sync(
        hooks: RunHooks[TContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
+        conversation_id: str | None = None,
        session: Session | None = None,
    ) -> RunResult:
        """Run a workflow synchronously, starting at the given agent. Note that this just wraps the
@@ -302,6 +315,7 @@ def run_sync(
            run_config: Global settings for the entire agent run.
            previous_response_id: The ID of the previous response, if using OpenAI models via the
                Responses API, this allows you to skip passing in input from the previous turn.
+            conversation_id: The ID of the stored conversation, if any.
        Returns:
            A run result containing all the inputs, guardrail results and the output of the last
            agent. Agents may perform handoffs, so we don't know the specific type of the output.
@@ -315,6 +329,7 @@ def run_sync(
            hooks=hooks,
            run_config=run_config,
            previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
            session=session,
        )

@@ -328,6 +343,7 @@ def run_streamed(
        hooks: RunHooks[TContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
+        conversation_id: str | None = None,
        session: Session | None = None,
    ) -> RunResultStreaming:
        """Run a workflow starting at the given agent in streaming mode. The returned result object
@@ -353,6 +369,7 @@ def run_streamed(
            run_config: Global settings for the entire agent run.
            previous_response_id: The ID of the previous response, if using OpenAI models via the
                Responses API, this allows you to skip passing in input from the previous turn.
+            conversation_id: The ID of the stored conversation, if any.
        Returns:
            A result object that contains data about the run, as well as a method to stream events.
        """
@@ -365,6 +382,7 @@ def run_streamed(
            hooks=hooks,
            run_config=run_config,
            previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
            session=session,
        )

@@ -386,6 +404,7 @@ async def run(
        hooks = kwargs.get("hooks")
        run_config = kwargs.get("run_config")
        previous_response_id = kwargs.get("previous_response_id")
+        conversation_id = kwargs.get("conversation_id")
        session = kwargs.get("session")
        if hooks is None:
            hooks = RunHooks[Any]()
@@ -478,6 +497,7 @@ async def run(
                            should_run_agent_start_hooks=should_run_agent_start_hooks,
                            tool_use_tracker=tool_use_tracker,
                            previous_response_id=previous_response_id,
+                            conversation_id=conversation_id,
                        ),
                    )
                else:
@@ -492,6 +512,7 @@ async def run(
                        should_run_agent_start_hooks=should_run_agent_start_hooks,
                        tool_use_tracker=tool_use_tracker,
                        previous_response_id=previous_response_id,
+                        conversation_id=conversation_id,
                    )

                should_run_agent_start_hooks = False
@@ -558,6 +579,7 @@ def run_sync(
        hooks = kwargs.get("hooks")
        run_config = kwargs.get("run_config")
        previous_response_id = kwargs.get("previous_response_id")
+        conversation_id = kwargs.get("conversation_id")
        session = kwargs.get("session")

        return asyncio.get_event_loop().run_until_complete(
@@ -570,6 +592,7 @@ def run_sync(
                hooks=hooks,
                run_config=run_config,
                previous_response_id=previous_response_id,
+                conversation_id=conversation_id,
            )
        )

@@ -584,6 +607,7 @@ def run_streamed(
        hooks = kwargs.get("hooks")
        run_config = kwargs.get("run_config")
        previous_response_id = kwargs.get("previous_response_id")
+        conversation_id = kwargs.get("conversation_id")
        session = kwargs.get("session")

        if hooks is None:
@@ -638,6 +662,7 @@ def run_streamed(
                context_wrapper=context_wrapper,
                run_config=run_config,
                previous_response_id=previous_response_id,
+                conversation_id=conversation_id,
                session=session,
            )
        )
@@ -738,6 +763,7 @@ async def _start_streaming(
        context_wrapper: RunContextWrapper[TContext],
        run_config: RunConfig,
        previous_response_id: str | None,
+        conversation_id: str | None,
        session: Session | None,
    ):
        if streamed_result.trace:
@@ -821,6 +847,7 @@ async def _start_streaming(
                            tool_use_tracker,
                            all_tools,
                            previous_response_id,
+                            conversation_id,
                        )
                        should_run_agent_start_hooks = False
@@ -923,6 +950,7 @@ async def _run_single_turn_streamed(
        tool_use_tracker: AgentToolUseTracker,
        all_tools: list[Tool],
        previous_response_id: str | None,
+        conversation_id: str | None,
    ) -> SingleStepResult:
        emitted_tool_call_ids: set[str] = set()
@@ -983,6 +1011,7 @@ async def _run_single_turn_streamed(
                run_config.tracing_disabled, run_config.trace_include_sensitive_data
            ),
            previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
            prompt=prompt_config,
        ):
            if isinstance(event, ResponseCompletedEvent):
@@ -1091,6 +1120,7 @@ async def _run_single_turn(
        should_run_agent_start_hooks: bool,
        tool_use_tracker: AgentToolUseTracker,
        previous_response_id: str | None,
+        conversation_id: str | None,
    ) -> SingleStepResult:
        # Ensure we run the hooks before anything else
        if should_run_agent_start_hooks:
@@ -1124,6 +1154,7 @@ async def _run_single_turn(
            run_config,
            tool_use_tracker,
            previous_response_id,
+            conversation_id,
            prompt_config,
        )

@@ -1318,6 +1349,7 @@ async def _get_new_response(
        run_config: RunConfig,
        tool_use_tracker: AgentToolUseTracker,
        previous_response_id: str | None,
+        conversation_id: str | None,
        prompt_config: ResponsePromptParam | None,
    ) -> ModelResponse:
        # Allow user to modify model input right before the call, if configured
@@ -1352,6 +1384,7 @@ async def _get_new_response(
            run_config.tracing_disabled, run_config.trace_include_sensitive_data
        ),
        previous_response_id=previous_response_id,
+        conversation_id=conversation_id,
        prompt=prompt_config,
    )
    # If the agent has hooks, we need to call them after the LLM call
@@ -1473,4 +1506,3 @@ def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResp
     if isinstance(input, str):
         return input
     return input.copy()
-
diff --git a/tests/fake_model.py b/tests/fake_model.py
index 6c1377e6d..7de629448 100644
--- a/tests/fake_model.py
+++ b/tests/fake_model.py
@@ -61,6 +61,7 @@ async def get_response(
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        conversation_id: str | None,
         prompt: Any | None,
     ) -> ModelResponse:
         self.last_turn_args = {
@@ -70,6 +71,7 @@ async def get_response(
             "tools": tools,
             "output_schema": output_schema,
             "previous_response_id": previous_response_id,
+            "conversation_id": conversation_id,
         }

         with generation_span(disabled=not self.tracing_enabled) as span:
@@ -103,8 +105,9 @@ async def stream_response(
         handoffs: list[Handoff],
         tracing: ModelTracing,
         *,
-        previous_response_id: str | None,
-        prompt: Any | None,
+        previous_response_id: str | None = None,
+        conversation_id: str | None = None,
+        prompt: Any | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         self.last_turn_args = {
             "system_instructions": system_instructions,
@@ -113,6 +116,7 @@ async def stream_response(
             "tools": tools,
             "output_schema": output_schema,
             "previous_response_id": previous_response_id,
+            "conversation_id": conversation_id,
         }
         with generation_span(disabled=not self.tracing_enabled) as span:
             output = self.get_next_output()
diff --git a/tests/models/test_kwargs_functionality.py b/tests/models/test_kwargs_functionality.py
index 210610a02..941fdc68d 100644
--- a/tests/models/test_kwargs_functionality.py
+++ b/tests/models/test_kwargs_functionality.py
@@ -47,6 +47,7 @@ async def fake_acompletion(model, messages=None, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
     )

     # Verify that all kwargs were passed through
diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py
index bd38f8759..d8b79d542 100644
--- a/tests/models/test_litellm_chatcompletions_stream.py
+++ b/tests/models/test_litellm_chatcompletions_stream.py
@@ -90,6 +90,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
@@ -183,6 +184,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
@@ -273,6 +275,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
@@ -389,6 +392,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
diff --git a/tests/test_agent_prompt.py b/tests/test_agent_prompt.py
index 010717d66..3d5ed5a3f 100644
--- a/tests/test_agent_prompt.py
+++ b/tests/test_agent_prompt.py
@@ -24,6 +24,7 @@ async def get_response(
         tracing,
         *,
         previous_response_id,
+        conversation_id,
         prompt,
     ):
         # Record the prompt that the agent resolved and passed in.
@@ -37,6 +38,7 @@ async def get_response(
             handoffs,
             tracing,
             previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
             prompt=prompt,
         )
diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py
index a6af30077..c6672374b 100644
--- a/tests/test_extra_headers.py
+++ b/tests/test_extra_headers.py
@@ -95,6 +95,7 @@ def __init__(self):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
     )
     assert "extra_headers" in called_kwargs
     assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py
index 6291418f6..d52d89b47 100644
--- a/tests/test_openai_chatcompletions.py
+++ b/tests/test_openai_chatcompletions.py
@@ -77,6 +77,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     )
     # Should have produced exactly one output message with one text part
@@ -129,6 +130,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     )
     assert len(resp.output) == 1
@@ -182,6 +184,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     )
     # Expect a message item followed by a function tool call item.
@@ -224,6 +227,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     )
     assert resp.output == []
diff --git a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py
index cbb3c5dae..947816f01 100644
--- a/tests/test_openai_chatcompletions_stream.py
+++ b/tests/test_openai_chatcompletions_stream.py
@@ -90,6 +90,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
@@ -183,6 +184,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
@@ -273,6 +275,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
@@ -390,6 +393,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
diff --git a/tests/test_reasoning_content.py b/tests/test_reasoning_content.py
index 69e9a7d0c..a64fdaf15 100644
--- a/tests/test_reasoning_content.py
+++ b/tests/test_reasoning_content.py
@@ -129,6 +129,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
@@ -216,6 +217,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     )

@@ -270,6 +272,7 @@ async def patched_fetch_response(self, *args, **kwargs):
         handoffs=[],
         tracing=ModelTracing.DISABLED,
         previous_response_id=None,
+        conversation_id=None,
         prompt=None,
     ):
         output_events.append(event)
diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py
index fe63e8ecb..a2d9b3c3d 100644
--- a/tests/test_responses_tracing.py
+++ b/tests/test_responses_tracing.py
@@ -69,7 +69,8 @@ async def dummy_fetch_response(
         tools,
         output_schema,
         handoffs,
-        prev_response_id,
+        previous_response_id,
+        conversation_id,
         stream,
         prompt,
     ):
@@ -114,7 +115,8 @@ async def dummy_fetch_response(
         tools,
         output_schema,
         handoffs,
-        prev_response_id,
+        previous_response_id,
+        conversation_id,
         stream,
         prompt,
     ):
@@ -157,7 +159,8 @@ async def dummy_fetch_response(
         tools,
         output_schema,
         handoffs,
-        prev_response_id,
+        previous_response_id,
+        conversation_id,
         stream,
         prompt,
     ):
@@ -197,7 +200,8 @@ async def dummy_fetch_response(
         tools,
         output_schema,
         handoffs,
-        prev_response_id,
+        previous_response_id,
+        conversation_id,
         stream,
         prompt,
     ):
@@ -251,7 +255,8 @@ async def dummy_fetch_response(
         tools,
         output_schema,
         handoffs,
-        prev_response_id,
+        previous_response_id,
+        conversation_id,
         stream,
         prompt,
     ):
@@ -304,7 +309,8 @@ async def dummy_fetch_response(
         tools,
         output_schema,
         handoffs,
-        prev_response_id,
+        previous_response_id,
+        conversation_id,
         stream,
         prompt,
     ):
diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py
index 611e6f255..94d87b994 100644
--- a/tests/voice/test_workflow.py
+++ b/tests/voice/test_workflow.py
@@ -55,6 +55,7 @@ async def get_response(
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        conversation_id: str | None,
         prompt: Any | None,
     ) -> ModelResponse:
         raise NotImplementedError("Not implemented")
@@ -70,6 +71,7 @@ async def stream_response(
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        conversation_id: str | None,
         prompt: Any | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         output = self.get_next_output()
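Reviewer note (not part of the patch): a minimal sketch of how the two user-facing additions above compose. `OpenAIConversationsSession` keeps history server-side and creates the conversation lazily on first use, while the new `conversation_id` argument to `Runner.run` points a run at an existing stored conversation. The `"conv_123"` ID is a placeholder, and a configured `OPENAI_API_KEY` is assumed.

```python
import asyncio

from agents import Agent, OpenAIConversationsSession, Runner


async def main() -> None:
    agent = Agent(name="Assistant", instructions="Reply very concisely.")

    # Server-side session memory: history is read and written automatically,
    # and the underlying conversation is created lazily on the first run.
    session = OpenAIConversationsSession()
    first = await Runner.run(agent, "What city is the Golden Gate Bridge in?", session=session)
    print(first.final_output)

    # Alternatively, reuse an existing stored conversation directly.
    # "conv_123" is a placeholder ID. Per the run.py docstring, this is only
    # recommended when using OpenAI models exclusively, since other providers
    # don't write their output items to the Conversation object.
    second = await Runner.run(
        agent,
        "What state is it in?",
        conversation_id="conv_123",
    )
    print(second.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```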
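The newly exported `SessionABC` is the base class that both `SQLiteSession` and `OpenAIConversationsSession` now extend. A sketch of a trivial third-party implementation, assuming only the four abstract methods shown in session.py (`get_items`, `add_items`, `pop_item`, `clear_session`) and that `TResponseInputItem` is importable from the top-level `agents` package:

```python
from __future__ import annotations

from agents import SessionABC, TResponseInputItem


class InMemorySession(SessionABC):
    """Toy list-backed session; fine for tests, not for concurrent use."""

    def __init__(self) -> None:
        self._items: list[TResponseInputItem] = []

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        # A limit returns the latest N items in chronological order,
        # matching the shipped implementations.
        if limit is None:
            return list(self._items)
        start = max(len(self._items) - limit, 0)
        return list(self._items[start:])

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        self._items.extend(items)

    async def pop_item(self) -> TResponseInputItem | None:
        # Remove and return the most recent item, or None when empty.
        return self._items.pop() if self._items else None

    async def clear_session(self) -> None:
        self._items.clear()
```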