@function_tool
async def get_weather(city: str) -> str:
    """Return a canned weather report for ``city``.

    "new york" (compared case-insensitively, ignoring surrounding whitespace)
    is reported as cloudy; every other city is reported as sunny.
    """
    if city.strip().lower() == "new york":
        return f"The weather in {city} is cloudy."
    return f"The weather in {city} is sunny."


def _print_items(items) -> None:
    """Print a numbered ``role: content`` listing of session items.

    Function calls are rendered as ``name(arguments)``, function call outputs
    as their ``output`` value, and everything else as its ``content`` (falling
    back to ``output``).
    """
    for i, item in enumerate(items, 1):
        role = str(item.get("role", item.get("type", "unknown")))
        if item.get("type") == "function_call":
            content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})"
        elif item.get("type") == "function_call_output":
            content = str(item.get("output", ""))
        else:
            content = str(item.get("content", item.get("output", "")))
        print(f" {i}. {role}: {content}")


async def _run_demo(agent, session) -> None:
    """Drive the four-part walkthrough against an already-created session."""
    print("=== AdvancedSQLiteSession Comprehensive Example ===")
    print("This example demonstrates both basic and advanced session features.\n")

    # === PART 1: Basic Session Functionality ===
    print("=== PART 1: Basic Session Memory ===")
    print("The agent will remember previous messages with structured tracking.\n")

    # The three opening turns only differ in their label and prompt.
    opening_turns = (
        ("First turn:", "What city is the Golden Gate Bridge in?"),
        ("Second turn:", "What's the weather in that city?"),
        ("Third turn:", "What's the population of that city?"),
    )
    for label, prompt in opening_turns:
        print(label)
        print(f"User: {prompt}")
        result = await Runner.run(agent, prompt, session=session)
        print(f"Assistant: {result.final_output}")
        print(f"Usage: {result.context_wrapper.usage.total_tokens} tokens")

        # Store usage data for the turn that just completed
        await session.store_run_usage(result)
        print()

    # === PART 2: Usage Tracking and Analytics ===
    print("=== PART 2: Usage Tracking and Analytics ===")
    session_usage = await session.get_session_usage()
    if session_usage:
        print("Session Usage (aggregated from turns):")
        print(f" Total requests: {session_usage['requests']}")
        print(f" Total tokens: {session_usage['total_tokens']}")
        print(f" Input tokens: {session_usage['input_tokens']}")
        print(f" Output tokens: {session_usage['output_tokens']}")
        print(f" Total turns: {session_usage['total_turns']}")

        # Show usage by turn
        turn_usage_list = await session.get_turn_usage()
        if turn_usage_list and isinstance(turn_usage_list, list):
            print("\nUsage by turn:")
            for turn_data in turn_usage_list:
                turn_num = turn_data["user_turn_number"]
                tokens = turn_data["total_tokens"]
                print(f" Turn {turn_num}: {tokens} tokens")
    else:
        print("No usage data found.")

    print("\n=== Structured Query Demo ===")
    conversation_turns = await session.get_conversation_by_turns()
    print("Conversation by turns:")
    for turn_num, items in conversation_turns.items():
        print(f" Turn {turn_num}: {len(items)} items")
        for item in items:
            if item["tool_name"]:
                print(f" - {item['type']} (tool: {item['tool_name']})")
            else:
                print(f" - {item['type']}")

    # Show tool usage
    tool_usage = await session.get_tool_usage()
    if tool_usage:
        print("\nTool usage:")
        for tool_name, count, turn in tool_usage:
            print(f" {tool_name}: used {count} times in turn {turn}")
    else:
        print("\nNo tool usage found.")

    print("\n=== Original Conversation Complete ===")

    # Show current conversation
    print("Current conversation:")
    current_items = await session.get_items()
    _print_items(current_items)

    print(f"\nTotal items: {len(current_items)}")

    # === PART 3: Conversation Branching ===
    print("\n=== PART 3: Conversation Branching ===")
    print("Let's explore a different path from turn 2...")

    # Show available turns for branching
    print("\nAvailable turns for branching:")
    turns = await session.get_conversation_turns()
    for turn in turns:
        print(f" Turn {turn['turn']}: {turn['content']}")

    # Create a branch from turn 2
    print("\nCreating new branch from turn 2...")
    branch_id = await session.create_branch_from_turn(2)
    print(f"Created branch: {branch_id}")

    # Show what was carried over into the new branch
    branch_items = await session.get_items()
    print(f"Items copied to new branch: {len(branch_items)}")
    print("New branch contains:")
    _print_items(branch_items)

    # Continue conversation in new branch
    print("\nContinuing conversation in new branch...")
    print("Turn 2 (new branch): User asks about New York instead")
    result = await Runner.run(
        agent,
        "Actually, what's the weather in New York instead?",
        session=session,
    )
    print(f"Assistant: {result.final_output}")
    await session.store_run_usage(result)

    # Continue the new branch
    print("Turn 3 (new branch): User asks about NYC attractions")
    result = await Runner.run(
        agent,
        "What are some famous attractions in New York?",
        session=session,
    )
    print(f"Assistant: {result.final_output}")
    await session.store_run_usage(result)

    # Show the new conversation
    print("\n=== New Conversation Branch ===")
    new_conversation = await session.get_items()
    print("New conversation with branch:")
    _print_items(new_conversation)

    print(f"\nTotal items in new branch: {len(new_conversation)}")

    # === PART 4: Branch Management ===
    print("\n=== PART 4: Branch Management ===")
    # Show all branches
    branches = await session.list_branches()
    print("All branches in this session:")
    for branch in branches:
        current = " (current)" if branch["is_current"] else ""
        print(
            f" {branch['branch_id']}: {branch['user_turns']} user turns, {branch['message_count']} total messages{current}"  # noqa: E501
        )

    # Show conversation turns in current branch
    print("\nConversation turns in current branch:")
    current_turns = await session.get_conversation_turns()
    for turn in current_turns:
        print(f" Turn {turn['turn']}: {turn['content']}")

    print("\n=== Branch Switching Demo ===")
    print("We can switch back to the main branch...")

    # Switch back to main branch
    await session.switch_to_branch("main")
    print("Switched to main branch")

    # Show what's in main branch
    main_items = await session.get_items()
    print(f"Items in main branch: {len(main_items)}")

    # Switch back to new branch
    await session.switch_to_branch(branch_id)
    branch_items = await session.get_items()
    print(f"Items in new branch: {len(branch_items)}")

    print("\n=== Final Summary ===")
    await session.switch_to_branch("main")
    main_final = len(await session.get_items())
    await session.switch_to_branch(branch_id)
    branch_final = len(await session.get_items())

    print(f"Main branch items: {main_final}")
    print(f"New branch items: {branch_final}")

    # Show that branches are completely independent
    print("\nBranches are completely independent:")
    print("- Main branch has full original conversation")
    print("- New branch has turn 1 + new conversation path")
    print("- No interference between branches!")

    print("\n=== Comprehensive Example Complete ===")
    print("This demonstrates the full AdvancedSQLiteSession capabilities!")
    print("Key features:")
    print("- Structured conversation tracking with usage analytics")
    print("- Turn-based organization and querying")
    print("- Create branches from any user message")
    print("- Branches inherit conversation history up to the branch point")
    print("- Complete branch isolation - no interference between branches")
    print("- Easy branch switching and management")
    print("- No complex soft deletion - clean branch-based architecture")
    print("- Perfect for building AI systems with conversation editing capabilities!")


async def main():
    """Run the AdvancedSQLiteSession walkthrough end to end.

    Builds the agent and the session, runs the demo, and closes the session
    even when one of the demo steps raises.
    """
    # Create an agent
    agent = Agent(
        name="Assistant",
        instructions="Reply very concisely.",
        tools=[get_weather],
    )

    # Create an advanced session instance
    session = AdvancedSQLiteSession(
        session_id="conversation_comprehensive",
        create_tables=True,
    )

    try:
        await _run_demo(agent, session)
    finally:
        # Cleanup: release the session even if a run or a query failed.
        session.close()


if __name__ == "__main__":
    asyncio.run(main())
def __init__(
    self,
    *,
    session_id: str,
    db_path: str | Path = ":memory:",
    create_tables: bool = False,
    logger: logging.Logger | None = None,
    **kwargs,
):
    """Set up an AdvancedSQLiteSession on top of the base SQLite session.

    Args:
        session_id: The ID of the session.
        db_path: Path to the SQLite database file. Defaults to ``:memory:``.
        create_tables: When true, the structure and usage tables are created
            right after the base session is initialised.
        logger: Logger to use. Falls back to this module's logger.
        **kwargs: Extra keyword arguments forwarded to the superclass.
    """
    super().__init__(session_id, db_path, **kwargs)

    # Table creation is opt-in; it runs before the branch pointer is set.
    if create_tables:
        self._init_structure_tables()

    # Every session starts on the "main" branch.
    self._current_branch_id = "main"
    self._logger = logger if logger else logging.getLogger(__name__)
+ """ + conn = self._get_connection() + + # Message structure with branch support + conn.execute(""" + CREATE TABLE IF NOT EXISTS message_structure ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + message_id INTEGER NOT NULL, + branch_id TEXT NOT NULL DEFAULT 'main', + message_type TEXT NOT NULL, + sequence_number INTEGER NOT NULL, + user_turn_number INTEGER, + branch_turn_number INTEGER, + tool_name TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE, + FOREIGN KEY (message_id) REFERENCES agent_messages(id) ON DELETE CASCADE + ) + """) + + # Turn-level usage tracking with branch support and full JSON details + conn.execute(""" + CREATE TABLE IF NOT EXISTS turn_usage ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + branch_id TEXT NOT NULL DEFAULT 'main', + user_turn_number INTEGER NOT NULL, + requests INTEGER DEFAULT 0, + input_tokens INTEGER DEFAULT 0, + output_tokens INTEGER DEFAULT 0, + total_tokens INTEGER DEFAULT 0, + input_tokens_details JSON, + output_tokens_details JSON, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE, + UNIQUE(session_id, branch_id, user_turn_number) + ) + """) + + # Indexes + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_structure_session_seq + ON message_structure(session_id, sequence_number) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_structure_branch + ON message_structure(session_id, branch_id) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_structure_turn + ON message_structure(session_id, branch_id, user_turn_number) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_structure_branch_seq + ON message_structure(session_id, branch_id, sequence_number) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_turn_usage_session_turn + ON turn_usage(session_id, branch_id, 
async def get_items(
    self,
    limit: int | None = None,
    branch_id: str | None = None,
) -> list[TResponseInputItem]:
    """Get items from the current or a specified branch.

    Args:
        limit: Maximum number of items to return. If None, returns all items.
            When given, the ``limit`` items with the highest sequence numbers
            are selected and returned in ascending sequence order.
        branch_id: Branch to get items from. If None, uses the current branch.

    Returns:
        List of conversation items from the specified branch, ordered by
        ascending sequence number. Rows whose stored payload is not valid
        JSON are skipped.
    """
    if branch_id is None:
        branch_id = self._current_branch_id

    # With a limit we read newest-first so LIMIT keeps the latest rows, then
    # flip the result back to chronological order below.
    direction = "ASC" if limit is None else "DESC"
    # The table name comes from the session's own configuration (the same
    # attribute the structure-insert path uses), not from caller input.
    query = f"""
        SELECT m.message_data
        FROM {self.messages_table} m
        JOIN message_structure s ON m.id = s.message_id
        WHERE m.session_id = ? AND s.branch_id = ?
        ORDER BY s.sequence_number {direction}
    """
    params = [self.session_id, branch_id]
    if limit is not None:
        query += " LIMIT ?"
        params.append(limit)

    def _get_items_sync():
        """Synchronous helper that runs the query and decodes the rows."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                cursor.execute(query, params)
                rows = cursor.fetchall()

        if limit is not None:
            rows.reverse()

        items = []
        for (message_data,) in rows:
            try:
                items.append(json.loads(message_data))
            except json.JSONDecodeError:
                # Skip undecodable rows rather than failing the whole read.
                continue
        return items

    return await asyncio.to_thread(_get_items_sync)
def _get_next_turn_number(self, branch_id: str) -> int:
    """Return the turn number the next user turn on ``branch_id`` should get.

    Args:
        branch_id: The branch to inspect.

    Returns:
        One more than the highest ``user_turn_number`` recorded for this
        session and branch; 1 when the branch has no rows yet.
    """
    query = """
        SELECT COALESCE(MAX(user_turn_number), 0)
        FROM message_structure
        WHERE session_id = ? AND branch_id = ?
    """
    with closing(self._get_connection().cursor()) as cur:
        cur.execute(query, (self.session_id, branch_id))
        row = cur.fetchone()
    highest = row[0] if row else 0
    return highest + 1


def _get_next_branch_turn_number(self, branch_id: str) -> int:
    """Return the next branch-local turn number for ``branch_id``.

    Args:
        branch_id: The branch to inspect.

    Returns:
        One more than the highest ``branch_turn_number`` recorded for this
        session and branch; 1 when the branch has no rows yet.
    """
    query = """
        SELECT COALESCE(MAX(branch_turn_number), 0)
        FROM message_structure
        WHERE session_id = ? AND branch_id = ?
    """
    with closing(self._get_connection().cursor()) as cur:
        cur.execute(query, (self.session_id, branch_id))
        row = cur.fetchone()
    highest = row[0] if row else 0
    return highest + 1
async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
    """Record branch-aware structure rows for messages that were just stored.

    This method:
    - Looks up the IDs of the ``len(items)`` most recently inserted messages
      for this session and pairs them with ``items`` in order
    - Assigns sequence numbers continuing from the session-wide maximum
    - Assigns turn numbers continuing from the current branch's maximum; each
      user message in the batch starts a new turn and every following
      non-user item shares that turn

    Failures are logged, followed by a best-effort cleanup of messages that
    have no structure row; nothing is re-raised.

    Args:
        items: The items that were just added to the session.
    """

    def _add_structure_sync() -> None:
        """Synchronous helper to add structure metadata to database."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            # IDs of the messages we just inserted, restored to insert order.
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    f"SELECT id FROM {self.messages_table} "
                    f"WHERE session_id = ? ORDER BY id DESC LIMIT ?",
                    (self.session_id, len(items)),
                )
                message_ids = [row[0] for row in cursor.fetchall()]
                message_ids.reverse()

            # Sequence numbers are global to the session (all branches).
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT COALESCE(MAX(sequence_number), 0)
                    FROM message_structure
                    WHERE session_id = ?
                    """,
                    (self.session_id,),
                )
                seq_start = cursor.fetchone()[0]

            # Turn counters are read together, scoped to the current branch.
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT
                        COALESCE(MAX(user_turn_number), 0) as max_global_turn,
                        COALESCE(MAX(branch_turn_number), 0) as max_branch_turn
                    FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    (self.session_id, self._current_branch_id),
                )
                result = cursor.fetchone()
                current_turn = result[0] if result else 0
                current_branch_turn = result[1] if result else 0

            structure_data = []
            user_message_count = 0
            for offset, (item, msg_id) in enumerate(zip(items, message_ids), start=1):
                # A user message opens a new turn; any other item stays in
                # the turn opened by the most recent user message.
                if self._is_user_message(item):
                    user_message_count += 1

                structure_data.append(
                    (
                        self.session_id,
                        msg_id,
                        self._current_branch_id,
                        self._classify_message_type(item),
                        seq_start + offset,  # Global sequence
                        current_turn + user_message_count,  # Global turn number
                        current_branch_turn + user_message_count,  # Branch turn number
                        self._extract_tool_name(item),
                    )
                )

            with closing(conn.cursor()) as cursor:
                cursor.executemany(
                    """
                    INSERT INTO message_structure
                    (session_id, message_id, branch_id, message_type, sequence_number,
                     user_turn_number, branch_turn_number, tool_name)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    structure_data,
                )
                conn.commit()

    try:
        await asyncio.to_thread(_add_structure_sync)
    except Exception as e:
        self._logger.error(
            f"Failed to add structure metadata for session {self.session_id}: {e}"
        )
        # Try to clean up any orphaned messages to maintain consistency
        try:
            await self._cleanup_orphaned_messages()
        except Exception as cleanup_error:
            self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
        # Don't re-raise - structure metadata is supplementary


async def _cleanup_orphaned_messages(self) -> int:
    """Remove this session's messages that have no message_structure row.

    This can happen if _add_structure_metadata fails after super().add_items()
    succeeds. Used for maintaining data consistency.

    Returns:
        The number of messages deleted (0 when nothing was orphaned).
    """

    def _cleanup_sync() -> int:
        """Synchronous helper to cleanup orphaned messages."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                # Find messages without structure metadata. The table name is
                # the session's configured one, matching the insert path.
                cursor.execute(
                    f"""
                    SELECT am.id
                    FROM {self.messages_table} am
                    LEFT JOIN message_structure ms ON am.id = ms.message_id
                    WHERE am.session_id = ? AND ms.message_id IS NULL
                    """,
                    (self.session_id,),
                )
                orphaned_ids = [row[0] for row in cursor.fetchall()]
                if not orphaned_ids:
                    return 0

                # Delete orphaned messages
                placeholders = ",".join("?" * len(orphaned_ids))
                cursor.execute(
                    f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})",
                    orphaned_ids,
                )
                deleted_count = cursor.rowcount
                conn.commit()

                self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
                return deleted_count

    return await asyncio.to_thread(_cleanup_sync)
def _extract_tool_name(self, item: TResponseInputItem) -> str | None:
    """Derive a tool name from a tool call/output item.

    Args:
        item: The message item to inspect.

    Returns:
        For MCP items carrying a ``server_label``: ``"<label>.<name>"`` when
        both are set, otherwise whichever one is set, otherwise None. For the
        built-in call types that have no ``name`` field: the type itself. For
        any other dict with a ``name`` key: that name as a string (None if the
        value is None). None for everything else, including non-dict items.
    """
    if not isinstance(item, dict):
        return None

    kind = item.get("type")

    # MCP tools: prefer a label-qualified name when a server label is present.
    if kind in {"mcp_call", "mcp_approval_request"} and "server_label" in item:
        label = item.get("server_label")
        name = item.get("name")
        if name and label:
            return f"{label}.{name}"
        if label:
            return str(label)
        if name:
            return str(name)
        return None

    # Tool types without a 'name' field are identified by their type.
    if kind in {
        "computer_call",
        "file_search_call",
        "web_search_call",
        "code_interpreter_call",
    }:
        return kind

    # Most other tool calls carry the tool in a 'name' field.
    if "name" in item:
        name = item.get("name")
        return None if name is None else str(name)

    return None


def _is_user_message(self, item: TResponseInputItem) -> bool:
    """Tell whether ``item`` is a dict whose ``role`` is ``"user"``.

    Args:
        item: The message item to check.

    Returns:
        True if the item is a user message, False otherwise.
    """
    if not isinstance(item, dict):
        return False
    return item.get("role") == "user"
+ + Args: + turn_number: The branch turn number of the user message to branch from + branch_name: Optional name for the branch (auto-generated if None) + + Returns: + The branch_id of the newly created branch + + Raises: + ValueError: If turn doesn't exist or doesn't contain a user message + """ + import time + + # Validate the turn exists and contains a user message + def _validate_turn(): + """Synchronous helper to validate turn exists and contains user message.""" + conn = self._get_connection() + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT am.message_data + FROM message_structure ms + JOIN agent_messages am ON ms.message_id = am.id + WHERE ms.session_id = ? AND ms.branch_id = ? + AND ms.branch_turn_number = ? AND ms.message_type = 'user' + """, + (self.session_id, self._current_branch_id, turn_number), + ) + + result = cursor.fetchone() + if not result: + raise ValueError( + f"Turn {turn_number} does not contain a user message " + f"in branch '{self._current_branch_id}'" + ) + + message_data = result[0] + try: + content = json.loads(message_data).get("content", "") + return content[:50] + "..." 
async def create_branch_from_content(
    self, search_term: str, branch_name: str | None = None
) -> str:
    """Branch from the first user turn returned for ``search_term``.

    Args:
        search_term: Text to search for in user messages.
        branch_name: Optional name for the branch; passed through unchanged to
            ``create_branch_from_turn``.

    Returns:
        The branch_id of the newly created branch.

    Raises:
        ValueError: If no matching turns are found.
    """
    hits = await self.find_turns_by_content(search_term)
    if hits:
        # Branch at the first entry of the search result.
        first_hit = hits[0]
        return await self.create_branch_from_turn(first_hit["turn"], branch_name)

    raise ValueError(f"No user turns found containing '{search_term}'")
async def delete_branch(self, branch_id: str, force: bool = False) -> None:
    """Delete a branch's structure rows and usage rows.

    Args:
        branch_id: The branch to delete (surrounding whitespace is ignored).
        force: If True, allows deleting the current branch; the session
            switches to 'main' before the delete runs.

    Raises:
        ValueError: If the branch ID is empty, is 'main', is the current
            branch without ``force``, or has no rows in message_structure.
    """
    branch_id = branch_id.strip() if branch_id else ""
    if not branch_id:
        raise ValueError("Branch ID cannot be empty")

    # The main branch is never deletable.
    if branch_id == "main":
        raise ValueError("Cannot delete the 'main' branch")

    if branch_id == self._current_branch_id:
        if not force:
            raise ValueError(
                f"Cannot delete current branch '{branch_id}'. Use force=True or switch branches first"  # noqa: E501
            )
        # Move off the branch before removing it.
        await self.switch_to_branch("main")

    scope = (self.session_id, branch_id)

    def _delete_sync():
        """Synchronous helper returning (usage rows deleted, structure rows deleted)."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cur:
                # A branch exists only if it has structure rows.
                cur.execute(
                    """
                    SELECT COUNT(*) FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    scope,
                )
                if cur.fetchone()[0] == 0:
                    raise ValueError(f"Branch '{branch_id}' does not exist")

                # Usage rows go first, then the structure rows.
                cur.execute(
                    """
                    DELETE FROM turn_usage
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    scope,
                )
                removed_usage = cur.rowcount

                cur.execute(
                    """
                    DELETE FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    scope,
                )
                removed_structure = cur.rowcount

                conn.commit()
                return removed_usage, removed_structure

    usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)

    self._logger.info(
        f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries"  # noqa: E501
    )
async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_number: int) -> None:
    """Duplicate the current branch's history before a turn into another branch.

    Only ``message_structure`` rows are duplicated. Each copy keeps the source
    row's ``message_id``, so the new branch points at the same stored message.

    Args:
        new_branch_id: Branch that receives the copied rows.
        from_turn_number: Rows whose branch turn number is strictly below this
            value are copied.
    """

    def _copy_sync():
        """Blocking worker; run through ``asyncio.to_thread``."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                # Rows of the current branch that precede the branch point.
                cursor.execute(
                    """
                    SELECT
                        ms.message_id,
                        ms.message_type,
                        ms.sequence_number,
                        ms.user_turn_number,
                        ms.branch_turn_number,
                        ms.tool_name
                    FROM message_structure ms
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.branch_turn_number < ?
                    ORDER BY ms.sequence_number
                    """,
                    (self.session_id, self._current_branch_id, from_turn_number),
                )
                source_rows = cursor.fetchall()

                if source_rows:
                    # New sequence numbers continue after the session-wide maximum.
                    cursor.execute(
                        """
                        SELECT COALESCE(MAX(sequence_number), 0)
                        FROM message_structure
                        WHERE session_id = ?
                        """,
                        (self.session_id,),
                    )
                    first_new_seq = cursor.fetchone()[0] + 1

                    copies = [
                        (
                            self.session_id,
                            message_id,  # Same message_id (sharing the actual message data)
                            new_branch_id,
                            message_type,
                            new_seq,
                            user_turn,  # Global turn number is preserved
                            branch_turn,  # Branch turn number is preserved
                            tool_name,
                        )
                        for new_seq, (
                            message_id,
                            message_type,
                            _old_seq,
                            user_turn,
                            branch_turn,
                            tool_name,
                        ) in enumerate(source_rows, start=first_new_seq)
                    ]

                    cursor.executemany(
                        """
                        INSERT INTO message_structure
                        (session_id, message_id, branch_id, message_type, sequence_number,
                         user_turn_number, branch_turn_number, tool_name)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                        """,
                        copies,
                    )

                conn.commit()

    await asyncio.to_thread(_copy_sync)
async def find_turns_by_content(
    self, search_term: str, branch_id: str | None = None
) -> list[dict[str, Any]]:
    """Find user turns whose message content contains ``search_term``.

    The SQL ``LIKE`` filter runs against the raw JSON stored in
    ``agent_messages.message_data``, so on its own it also matches JSON keys and
    unrelated fields (searching for "content" would hit every message that has a
    ``content`` key). Each candidate row is therefore re-checked against the
    decoded ``content`` value before it is returned.

    The final check is a case-insensitive, literal substring test: ``%`` and ``_``
    in ``search_term`` are not treated as wildcards.

    Args:
        search_term: Text to search for in user messages.
        branch_id: Branch to search in (current branch if None).

    Returns:
        Matching turns ordered by branch turn number. Each dict has the keys
        'turn', 'content', 'full_content', 'timestamp' and 'can_branch', like
        get_conversation_turns(), except that 'content' is not truncated here.
    """
    if branch_id is None:
        branch_id = self._current_branch_id

    needle = search_term.casefold()

    def _searchable_text(content: Any) -> str:
        """Return the text of a decoded ``content`` value that should be searched."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # NOTE(review): assumes structured content is a list of parts that carry
            # their text under a "text" key - confirm against the stored item schema.
            return " ".join(
                part["text"]
                for part in content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )
        return ""

    def _search_sync():
        """Synchronous helper to search turns by content."""
        conn = self._get_connection()
        with closing(conn.cursor()) as cursor:
            # LIKE is only a cheap pre-filter; the precise check happens below.
            cursor.execute(
                """
                SELECT
                    ms.branch_turn_number,
                    am.message_data,
                    ms.created_at
                FROM message_structure ms
                JOIN agent_messages am ON ms.message_id = am.id
                WHERE ms.session_id = ? AND ms.branch_id = ?
                AND ms.message_type = 'user'
                AND am.message_data LIKE ?
                ORDER BY ms.branch_turn_number
                """,
                (self.session_id, branch_id, f"%{search_term}%"),
            )

            matches = []
            for turn_num, message_data, created_at in cursor.fetchall():
                try:
                    content = json.loads(message_data).get("content", "")
                except (json.JSONDecodeError, AttributeError):
                    # Undecodable or non-object payloads cannot be searched.
                    continue

                # Drop rows where the term only matched JSON scaffolding.
                if needle not in _searchable_text(content).casefold():
                    continue

                matches.append(
                    {
                        "turn": turn_num,
                        "content": content,
                        "full_content": content,
                        "timestamp": created_at,
                        "can_branch": True,
                    }
                )

            return matches

    return await asyncio.to_thread(_search_sync)
async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None:
    """Return summed token/request usage for the session or for one branch.

    Args:
        branch_id: Restrict the totals to this branch. A falsy value (None or "")
            aggregates over every branch of the session.

    Returns:
        Dict with 'requests', 'input_tokens', 'output_tokens', 'total_tokens' and
        'total_turns', or None when no usage rows match.
    """

    def _get_usage_sync():
        """Blocking worker; run through ``asyncio.to_thread``."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            if not branch_id:
                # All branches
                query = """
                    SELECT
                        SUM(requests) as total_requests,
                        SUM(input_tokens) as total_input_tokens,
                        SUM(output_tokens) as total_output_tokens,
                        SUM(total_tokens) as total_total_tokens,
                        COUNT(*) as total_turns
                    FROM turn_usage
                    WHERE session_id = ?
                """
                params: tuple[str, ...] = (self.session_id,)
            else:
                # Branch-specific usage
                query = """
                    SELECT
                        SUM(requests) as total_requests,
                        SUM(input_tokens) as total_input_tokens,
                        SUM(output_tokens) as total_output_tokens,
                        SUM(total_tokens) as total_total_tokens,
                        COUNT(*) as total_turns
                    FROM turn_usage
                    WHERE session_id = ? AND branch_id = ?
                """
                params = (self.session_id, branch_id)

            with closing(conn.cursor()) as cursor:
                cursor.execute(query, params)
                totals = cursor.fetchone()

            # SUM() over zero rows yields NULL, which is how "no data" shows up.
            if not totals or totals[0] is None:
                return None

            labels = ("requests", "input_tokens", "output_tokens", "total_tokens", "total_turns")
            return {label: value or 0 for label, value in zip(labels, totals)}

    result = await asyncio.to_thread(_get_usage_sync)

    return cast(Union[dict[str, int], None], result)
async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: Usage) -> None:
    """Write (or overwrite) the usage row of one turn on the current branch.

    The token-detail objects are stored as JSON built from their ``__dict__``;
    a detail object that cannot be serialized is logged and stored as NULL.

    Args:
        user_turn_number: The turn number to update usage for.
        usage_data: The usage data to store.
    """

    def _update_sync():
        """Blocking worker; run through ``asyncio.to_thread``."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            # Serialize token details as JSON; both default to NULL.
            input_details_json = None
            output_details_json = None

            input_details = getattr(usage_data, "input_tokens_details", None)
            if input_details:
                try:
                    input_details_json = json.dumps(input_details.__dict__)
                except (TypeError, ValueError) as e:
                    self._logger.warning(f"Failed to serialize input tokens details: {e}")

            output_details = getattr(usage_data, "output_tokens_details", None)
            if output_details:
                try:
                    output_details_json = json.dumps(output_details.__dict__)
                except (TypeError, ValueError) as e:
                    self._logger.warning(f"Failed to serialize output tokens details: {e}")

            row = (
                self.session_id,
                self._current_branch_id,
                user_turn_number,
                usage_data.requests or 0,
                usage_data.input_tokens or 0,
                usage_data.output_tokens or 0,
                usage_data.total_tokens or 0,
                input_details_json,
                output_details_json,
            )

            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO turn_usage
                    (session_id, branch_id, user_turn_number, requests, input_tokens, output_tokens,
                     total_tokens, input_tokens_details, output_tokens_details)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,  # noqa: E501
                    row,
                )
                conn.commit()

    await asyncio.to_thread(_update_sync)
Optional[Agent] = None +) -> RunResult: + """Helper function to create a mock RunResult for testing.""" + if agent is None: + agent = Agent(name="test", model=FakeModel()) + + if usage is None: + usage = Usage( + requests=1, + input_tokens=50, + output_tokens=30, + total_tokens=80, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens_details=OutputTokensDetails(reasoning_tokens=5), + ) + + context_wrapper = RunContextWrapper(context=None, usage=usage) + + return RunResult( + input="test input", + new_items=[], + raw_responses=[], + final_output="test output", + input_guardrail_results=[], + output_guardrail_results=[], + context_wrapper=context_wrapper, + _last_agent=agent, + ) + + +async def test_advanced_session_basic_functionality(agent: Agent): + """Test basic AdvancedSQLiteSession functionality.""" + session_id = "advanced_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Test basic session operations work + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + await session.add_items(items) + + # Get items and verify + retrieved = await session.get_items() + assert len(retrieved) == 2 + assert retrieved[0].get("content") == "Hello" + assert retrieved[1].get("content") == "Hi there!" 
async def test_message_structure_tracking(agent: Agent):
    """Check that every added item is recorded with its type under the right turn."""
    session_id = "structure_test"
    session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)

    # try/finally so the SQLite connection is released even when an assertion fails.
    try:
        # Add various types of messages
        items: list[TResponseInputItem] = [
            {"role": "user", "content": "What's 2+2?"},
            {"type": "function_call", "name": "calculator", "arguments": '{"expression": "2+2"}'},  # type: ignore
            {"type": "function_call_output", "output": "4"},  # type: ignore
            {"role": "assistant", "content": "The answer is 4"},
            {"type": "reasoning", "summary": [{"text": "Simple math", "type": "summary_text"}]},  # type: ignore
        ]
        await session.add_items(items)

        # Get conversation structure
        conversation_turns = await session.get_conversation_by_turns()
        assert len(conversation_turns) == 1  # Should be one user turn

        turn_1_items = conversation_turns[1]
        assert len(turn_1_items) == 5

        # Verify item types are classified correctly
        item_types = [item["type"] for item in turn_1_items]
        assert "user" in item_types
        assert "function_call" in item_types
        assert "function_call_output" in item_types
        assert "assistant" in item_types
        assert "reasoning" in item_types
    finally:
        session.close()
async def test_branching_functionality(agent: Agent):
    """Test branching functionality - create, switch, and delete branches."""
    session_id = "branching_test"
    session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)

    # try/finally so the SQLite connection is released even when an assertion fails.
    try:
        # Add multiple turns to main branch
        turn_1_items: list[TResponseInputItem] = [
            {"role": "user", "content": "First question"},
            {"role": "assistant", "content": "First answer"},
        ]
        await session.add_items(turn_1_items)

        turn_2_items: list[TResponseInputItem] = [
            {"role": "user", "content": "Second question"},
            {"role": "assistant", "content": "Second answer"},
        ]
        await session.add_items(turn_2_items)

        turn_3_items: list[TResponseInputItem] = [
            {"role": "user", "content": "Third question"},
            {"role": "assistant", "content": "Third answer"},
        ]
        await session.add_items(turn_3_items)

        # Verify all items are in main branch
        all_items = await session.get_items()
        assert len(all_items) == 6

        # Create a branch from turn 2
        branch_name = await session.create_branch_from_turn(2, "test_branch")
        assert branch_name == "test_branch"

        # Verify we're now on the new branch
        assert session._current_branch_id == "test_branch"

        # The branch holds only what came before turn 2, i.e. the first turn.
        branch_items = await session.get_items()
        assert len(branch_items) == 2  # Only first turn items (before turn 2)
        assert branch_items[0].get("content") == "First question"
        assert branch_items[1].get("content") == "First answer"

        # Switch back to main branch
        await session.switch_to_branch("main")
        assert session._current_branch_id == "main"

        # Verify main branch still has all items
        main_items = await session.get_items()
        assert len(main_items) == 6

        # List branches
        branches = await session.list_branches()
        assert len(branches) == 2
        branch_ids = [b["branch_id"] for b in branches]
        assert "main" in branch_ids
        assert "test_branch" in branch_ids

        # Delete the test branch
        await session.delete_branch("test_branch")

        # Verify branch is deleted
        branches_after_delete = await session.list_branches()
        assert len(branches_after_delete) == 1
        assert branches_after_delete[0]["branch_id"] == "main"
    finally:
        session.close()
async def test_find_turns_by_content():
    """Test find_turns_by_content functionality."""
    session_id = "find_turns_test"
    session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)

    # try/finally so the SQLite connection is released even when an assertion fails.
    try:
        # Add multiple turns with different content
        turn_1_items: list[TResponseInputItem] = [
            {"role": "user", "content": "Tell me about cats"},
            {"role": "assistant", "content": "Cats are great pets"},
        ]
        await session.add_items(turn_1_items)

        turn_2_items: list[TResponseInputItem] = [
            {"role": "user", "content": "What about dogs?"},
            {"role": "assistant", "content": "Dogs are also great pets"},
        ]
        await session.add_items(turn_2_items)

        turn_3_items: list[TResponseInputItem] = [
            {"role": "user", "content": "Tell me about cats again"},
            {"role": "assistant", "content": "Cats are wonderful companions"},
        ]
        await session.add_items(turn_3_items)

        # Search for turns containing "cats"
        cat_turns = await session.find_turns_by_content("cats")
        assert len(cat_turns) == 2
        assert cat_turns[0]["turn"] == 1
        assert cat_turns[1]["turn"] == 3

        # Search for turns containing "dogs"
        dog_turns = await session.find_turns_by_content("dogs")
        assert len(dog_turns) == 1
        assert dog_turns[0]["turn"] == 2

        # Search for non-existent content
        no_turns = await session.find_turns_by_content("elephants")
        assert len(no_turns) == 0
    finally:
        session.close()
async def test_branch_specific_operations():
    """Test operations that work with specific branches."""
    session_id = "branch_specific_test"
    session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)

    # try/finally so the SQLite connection is released even when an assertion fails.
    try:
        # Add items to main branch
        turn_1_items: list[TResponseInputItem] = [
            {"role": "user", "content": "Main branch question"},
            {"role": "assistant", "content": "Main branch answer"},
        ]
        await session.add_items(turn_1_items)

        # Add usage data for main branch
        usage_main = Usage(requests=1, input_tokens=50, output_tokens=30, total_tokens=80)
        run_result_main = create_mock_run_result(usage_main)
        await session.store_run_usage(run_result_main)

        # Create a branch from turn 1 (copies messages before turn 1, so empty)
        await session.create_branch_from_turn(1, "test_branch")

        # Add items to the new branch
        turn_2_items: list[TResponseInputItem] = [
            {"role": "user", "content": "Branch question"},
            {"role": "assistant", "content": "Branch answer"},
        ]
        await session.add_items(turn_2_items)

        # Add usage data for branch
        usage_branch = Usage(requests=1, input_tokens=40, output_tokens=20, total_tokens=60)
        run_result_branch = create_mock_run_result(usage_branch)
        await session.store_run_usage(run_result_branch)

        # Test get_items with branch_id parameter
        main_items = await session.get_items(branch_id="main")
        assert len(main_items) == 2
        assert main_items[0].get("content") == "Main branch question"

        current_items = await session.get_items()  # Should get from current branch
        assert len(current_items) == 2  # Only the items added to the branch (copied branch is empty)

        # Test get_conversation_turns with branch_id
        main_turns = await session.get_conversation_turns(branch_id="main")
        assert len(main_turns) == 1
        assert main_turns[0]["content"] == "Main branch question"

        current_turns = await session.get_conversation_turns()  # Should get from current branch
        assert len(current_turns) == 1  # Only one turn in the current branch

        # Test get_session_usage with branch_id
        main_usage = await session.get_session_usage(branch_id="main")
        assert main_usage is not None
        assert main_usage["total_turns"] == 1

        all_usage = await session.get_session_usage()  # Should get from all branches
        assert all_usage is not None
        assert all_usage["total_turns"] == 2  # Main branch has 1, current branch has 1
    finally:
        session.close()
Test deleting non-existent branch + with pytest.raises(ValueError, match="Branch 'nonexistent' does not exist"): + await session.delete_branch("nonexistent") + + # Test deleting main branch + with pytest.raises(ValueError, match="Cannot delete the 'main' branch"): + await session.delete_branch("main") + + # Test deleting empty branch ID + with pytest.raises(ValueError, match="Branch ID cannot be empty"): + await session.delete_branch("") + + # Test deleting empty branch ID (whitespace only) + with pytest.raises(ValueError, match="Branch ID cannot be empty"): + await session.delete_branch(" ") + + session.close() + + +async def test_branch_deletion_with_force(): + """Test branch deletion with force parameter.""" + session_id = "force_delete_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Add items to main branch + await session.add_items([{"role": "user", "content": "Main question"}]) + await session.add_items([{"role": "user", "content": "Second question"}]) + + # Create and switch to a branch from turn 2 + await session.create_branch_from_turn(2, "temp_branch") + assert session._current_branch_id == "temp_branch" + + # Add some content to the branch so it exists + await session.add_items([{"role": "user", "content": "Branch question"}]) + + # Verify branch exists + branches = await session.list_branches() + branch_ids = [b["branch_id"] for b in branches] + assert "temp_branch" in branch_ids + + # Try to delete current branch without force (should fail) + with pytest.raises(ValueError, match="Cannot delete current branch"): + await session.delete_branch("temp_branch") + + # Delete current branch with force (should succeed and switch to main) + await session.delete_branch("temp_branch", force=True) + + # Verify we're back on main branch + assert session._current_branch_id == "main" + + # Verify branch is deleted + branches_after = await session.list_branches() + assert len(branches_after) == 1 + assert 
async def test_get_items_with_parameters():
    """Test get_items with its optional parameters (limit, branch_id)."""
    session_id = "get_items_params_test"
    session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)

    # try/finally so the SQLite connection is released even when an assertion fails.
    try:
        # Add items to main branch
        items: list[TResponseInputItem] = [
            {"role": "user", "content": "First question"},
            {"role": "assistant", "content": "First answer"},
            {"role": "user", "content": "Second question"},
            {"role": "assistant", "content": "Second answer"},
        ]
        await session.add_items(items)

        # limit=N returns the N most recent items, still in chronological order
        limited_items = await session.get_items(limit=2)
        assert len(limited_items) == 2
        assert limited_items[0].get("content") == "Second question"
        assert limited_items[1].get("content") == "Second answer"

        # Test get_items with branch_id
        main_items = await session.get_items(branch_id="main")
        assert len(main_items) == 4

        # Test get_items with no arguments (current branch, no limit)
        all_items = await session.get_items()
        assert len(all_items) == 4

        # Create a branch from turn 2 and test branch-specific get_items
        await session.create_branch_from_turn(2, "test_branch")

        # Add items to branch
        branch_items: list[TResponseInputItem] = [
            {"role": "user", "content": "Branch question"},
            {"role": "assistant", "content": "Branch answer"},
        ]
        await session.add_items(branch_items)

        # Test getting items from specific branch (should include copied items + new items)
        branch_items_result = await session.get_items(branch_id="test_branch")
        assert len(branch_items_result) == 4  # 2 copied from main (before turn 2) + 2 new items

        # Test getting items from main branch while on different branch
        main_items_from_branch = await session.get_items(branch_id="main")
        assert len(main_items_from_branch) == 4
    finally:
        session.close()
Agent, usage_data: Usage): + """Test usage data storage and retrieval.""" + session_id = "usage_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Simulate adding items for turn 1 to increment turn counter + await session.add_items([{"role": "user", "content": "First turn"}]) + run_result_1 = create_mock_run_result(usage_data) + await session.store_run_usage(run_result_1) + + # Create different usage data for turn 2 + usage_data_2 = Usage( + requests=2, + input_tokens=75, + output_tokens=45, + total_tokens=120, + input_tokens_details=InputTokensDetails(cached_tokens=20), + output_tokens_details=OutputTokensDetails(reasoning_tokens=15), + ) + + # Simulate adding items for turn 2 to increment turn counter + await session.add_items([{"role": "user", "content": "Second turn"}]) + run_result_2 = create_mock_run_result(usage_data_2) + await session.store_run_usage(run_result_2) + + # Test session-level usage aggregation + session_usage = await session.get_session_usage() + assert session_usage is not None + assert session_usage["requests"] == 3 # 1 + 2 + assert session_usage["total_tokens"] == 200 # 80 + 120 + assert session_usage["input_tokens"] == 125 # 50 + 75 + assert session_usage["output_tokens"] == 75 # 30 + 45 + assert session_usage["total_turns"] == 2 + + # Test turn-level usage retrieval + turn_1_usage = await session.get_turn_usage(1) + assert isinstance(turn_1_usage, dict) + assert turn_1_usage["requests"] == 1 + assert turn_1_usage["total_tokens"] == 80 + assert turn_1_usage["input_tokens_details"]["cached_tokens"] == 10 + assert turn_1_usage["output_tokens_details"]["reasoning_tokens"] == 5 + + turn_2_usage = await session.get_turn_usage(2) + assert isinstance(turn_2_usage, dict) + assert turn_2_usage["requests"] == 2 + assert turn_2_usage["total_tokens"] == 120 + assert turn_2_usage["input_tokens_details"]["cached_tokens"] == 20 + assert turn_2_usage["output_tokens_details"]["reasoning_tokens"] == 15 + + # Test getting 
all turn usage + all_turn_usage = await session.get_turn_usage() + assert isinstance(all_turn_usage, list) + assert len(all_turn_usage) == 2 + assert all_turn_usage[0]["user_turn_number"] == 1 + assert all_turn_usage[1]["user_turn_number"] == 2 + + session.close() + + +async def test_runner_integration_with_usage_tracking(agent: Agent): + """Test integration with Runner and automatic usage tracking pattern.""" + session_id = "integration_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + async def store_session_usage(result: Any, session: AdvancedSQLiteSession): + """Helper function to store usage after runner completes.""" + try: + await session.store_run_usage(result) + except Exception: + # Ignore errors in test helper + pass + + # Set up fake model responses + assert isinstance(agent.model, FakeModel) + fake_model = agent.model + fake_model.set_next_output([get_text_message("San Francisco")]) + + # First turn + result1 = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session, + ) + assert result1.final_output == "San Francisco" + await store_session_usage(result1, session) + + # Second turn + fake_model.set_next_output([get_text_message("California")]) + result2 = await Runner.run(agent, "What state is it in?", session=session) + assert result2.final_output == "California" + await store_session_usage(result2, session) + + # Verify conversation structure + conversation_turns = await session.get_conversation_by_turns() + assert len(conversation_turns) == 2 + + # Verify usage was tracked + session_usage = await session.get_session_usage() + assert session_usage is not None + assert session_usage["total_turns"] == 2 + # FakeModel doesn't generate realistic usage data, so we just check structure exists + assert "requests" in session_usage + assert "total_tokens" in session_usage + + session.close() + + +async def test_sequence_ordering(): + """Test that sequence ordering works correctly even with 
same timestamps.""" + session_id = "sequence_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Add multiple items quickly to test sequence ordering + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Message 1"}, + {"role": "assistant", "content": "Response 1"}, + {"role": "user", "content": "Message 2"}, + {"role": "assistant", "content": "Response 2"}, + ] + await session.add_items(items) + + # Get items and verify order is preserved + retrieved = await session.get_items() + assert len(retrieved) == 4 + assert retrieved[0].get("content") == "Message 1" + assert retrieved[1].get("content") == "Response 1" + assert retrieved[2].get("content") == "Message 2" + assert retrieved[3].get("content") == "Response 2" + + session.close() + + +async def test_conversation_structure_with_multiple_turns(): + """Test conversation structure tracking with multiple user turns.""" + session_id = "multi_turn_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Turn 1 + turn_1: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + await session.add_items(turn_1) + + # Turn 2 + turn_2: list[TResponseInputItem] = [ + {"role": "user", "content": "How are you?"}, + {"type": "function_call", "name": "mood_check", "arguments": "{}"}, # type: ignore + {"type": "function_call_output", "output": "I'm good"}, # type: ignore + {"role": "assistant", "content": "I'm doing well!"}, + ] + await session.add_items(turn_2) + + # Turn 3 + turn_3: list[TResponseInputItem] = [ + {"role": "user", "content": "Goodbye"}, + {"role": "assistant", "content": "See you later!"}, + ] + await session.add_items(turn_3) + + # Verify conversation structure + conversation_turns = await session.get_conversation_by_turns() + assert len(conversation_turns) == 3 + + # Turn 1 should have 2 items + assert len(conversation_turns[1]) == 2 + assert 
conversation_turns[1][0]["type"] == "user" + assert conversation_turns[1][1]["type"] == "assistant" + + # Turn 2 should have 4 items including tool calls + assert len(conversation_turns[2]) == 4 + turn_2_types = [item["type"] for item in conversation_turns[2]] + assert "user" in turn_2_types + assert "function_call" in turn_2_types + assert "function_call_output" in turn_2_types + assert "assistant" in turn_2_types + + # Turn 3 should have 2 items + assert len(conversation_turns[3]) == 2 + + session.close() + + +async def test_empty_session_operations(): + """Test operations on empty sessions.""" + session_id = "empty_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Test getting items from empty session + items = await session.get_items() + assert len(items) == 0 + + # Test getting conversation from empty session + conversation = await session.get_conversation_by_turns() + assert len(conversation) == 0 + + # Test getting tool usage from empty session + tool_usage = await session.get_tool_usage() + assert len(tool_usage) == 0 + + # Test getting session usage from empty session + session_usage = await session.get_session_usage() + assert session_usage is None + + # Test getting turns from empty session + turns = await session.get_conversation_turns() + assert len(turns) == 0 + + session.close() + + +async def test_json_serialization_edge_cases(usage_data: Usage): + """Test edge cases in JSON serialization of usage data.""" + session_id = "json_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Test with normal usage data (need to add user message first to create turn) + await session.add_items([{"role": "user", "content": "First test"}]) + run_result_1 = create_mock_run_result(usage_data) + await session.store_run_usage(run_result_1) + + # Test with None usage data + run_result_none = create_mock_run_result(None) + await session.store_run_usage(run_result_none) + + # Test with usage data 
missing details + minimal_usage = Usage( + requests=1, + input_tokens=10, + output_tokens=5, + total_tokens=15, + ) + await session.add_items([{"role": "user", "content": "Second test"}]) + run_result_2 = create_mock_run_result(minimal_usage) + await session.store_run_usage(run_result_2) + + # Verify we can retrieve the data + turn_1_usage = await session.get_turn_usage(1) + assert isinstance(turn_1_usage, dict) + assert turn_1_usage["requests"] == 1 + assert turn_1_usage["input_tokens_details"]["cached_tokens"] == 10 + + turn_2_usage = await session.get_turn_usage(2) + assert isinstance(turn_2_usage, dict) + assert turn_2_usage["requests"] == 1 + # Should have default values for minimal data (Usage class provides defaults) + assert turn_2_usage["input_tokens_details"]["cached_tokens"] == 0 + assert turn_2_usage["output_tokens_details"]["reasoning_tokens"] == 0 + + session.close() + + +async def test_session_isolation(): + """Test that different session IDs maintain separate data.""" + session1 = AdvancedSQLiteSession(session_id="session_1", create_tables=True) + session2 = AdvancedSQLiteSession(session_id="session_2", create_tables=True) + + # Add data to session 1 + await session1.add_items([{"role": "user", "content": "Session 1 message"}]) + + # Add data to session 2 + await session2.add_items([{"role": "user", "content": "Session 2 message"}]) + + # Verify isolation + session1_items = await session1.get_items() + session2_items = await session2.get_items() + + assert len(session1_items) == 1 + assert len(session2_items) == 1 + assert session1_items[0].get("content") == "Session 1 message" + assert session2_items[0].get("content") == "Session 2 message" + + # Test conversation structure isolation + session1_turns = await session1.get_conversation_by_turns() + session2_turns = await session2.get_conversation_by_turns() + + assert len(session1_turns) == 1 + assert len(session2_turns) == 1 + + session1.close() + session2.close() + + +async def 
test_error_handling_in_usage_tracking(usage_data: Usage): + """Test that usage tracking errors don't break the main flow.""" + session_id = "error_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Test normal operation + run_result = create_mock_run_result(usage_data) + await session.store_run_usage(run_result) + + # Close the session to simulate database errors + session.close() + + # This should not raise an exception (error should be caught) + await session.store_run_usage(run_result) + + +async def test_advanced_tool_name_extraction(): + """Test advanced tool name extraction for different tool types.""" + session_id = "advanced_tool_names_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Add items with various tool types and naming patterns + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Use various tools"}, + # MCP tools with server labels + {"type": "mcp_call", "server_label": "filesystem", "name": "read_file", "arguments": "{}"}, # type: ignore + { + "type": "mcp_approval_request", + "server_label": "database", + "name": "execute_query", + "arguments": "{}", + }, # type: ignore + # Built-in tool types + {"type": "computer_call", "arguments": "{}"}, # type: ignore + {"type": "file_search_call", "arguments": "{}"}, # type: ignore + {"type": "web_search_call", "arguments": "{}"}, # type: ignore + {"type": "code_interpreter_call", "arguments": "{}"}, # type: ignore + # Regular function calls + {"type": "function_call", "name": "calculator", "arguments": "{}"}, # type: ignore + {"type": "custom_tool_call", "name": "custom_tool", "arguments": "{}"}, # type: ignore + ] + await session.add_items(items) + + # Get conversation structure and verify tool names + conversation_turns = await session.get_conversation_by_turns() + turn_items = conversation_turns[1] + + tool_items = [item for item in turn_items if item["tool_name"]] + tool_names = [item["tool_name"] for item in 
tool_items] + + # Verify MCP tools get server_label.name format + assert "filesystem.read_file" in tool_names + assert "database.execute_query" in tool_names + + # Verify built-in tools use their type as name + assert "computer_call" in tool_names + assert "file_search_call" in tool_names + assert "web_search_call" in tool_names + assert "code_interpreter_call" in tool_names + + # Verify regular function calls use their name + assert "calculator" in tool_names + assert "custom_tool" in tool_names + + session.close() + + +async def test_branch_usage_tracking(): + """Test usage tracking across different branches.""" + session_id = "branch_usage_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Add items and usage to main branch + await session.add_items([{"role": "user", "content": "Main question"}]) + usage_main = Usage(requests=1, input_tokens=50, output_tokens=30, total_tokens=80) + run_result_main = create_mock_run_result(usage_main) + await session.store_run_usage(run_result_main) + + # Create a branch and add usage there + await session.create_branch_from_turn(1, "usage_branch") + await session.add_items([{"role": "user", "content": "Branch question"}]) + usage_branch = Usage(requests=2, input_tokens=100, output_tokens=60, total_tokens=160) + run_result_branch = create_mock_run_result(usage_branch) + await session.store_run_usage(run_result_branch) + + # Test branch-specific usage + main_usage = await session.get_session_usage(branch_id="main") + assert main_usage is not None + assert main_usage["requests"] == 1 + assert main_usage["total_tokens"] == 80 + assert main_usage["total_turns"] == 1 + + branch_usage = await session.get_session_usage(branch_id="usage_branch") + assert branch_usage is not None + assert branch_usage["requests"] == 2 + assert branch_usage["total_tokens"] == 160 + assert branch_usage["total_turns"] == 1 + + # Test total usage across all branches + total_usage = await session.get_session_usage() + 
assert total_usage is not None + assert total_usage["requests"] == 3 # 1 + 2 + assert total_usage["total_tokens"] == 240 # 80 + 160 + assert total_usage["total_turns"] == 2 + + # Test turn usage for specific branch + branch_turn_usage = await session.get_turn_usage(branch_id="usage_branch") + assert isinstance(branch_turn_usage, list) + assert len(branch_turn_usage) == 1 + assert branch_turn_usage[0]["requests"] == 2 + + session.close() + + +async def test_tool_name_extraction(): + """Test that tool names are correctly extracted from different item types.""" + session_id = "tool_names_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Add items with different ways of specifying tool names + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Use tools please"}, # Need user message to create turn + {"type": "function_call", "name": "search_web", "arguments": "{}"}, # type: ignore + {"type": "function_call_output", "tool_name": "search_web", "output": "result"}, # type: ignore + {"type": "function_call", "name": "calculator", "arguments": "{}"}, # type: ignore + ] + await session.add_items(items) + + # Get conversation structure and verify tool names + conversation_turns = await session.get_conversation_by_turns() + turn_items = conversation_turns[1] + + tool_items = [item for item in turn_items if item["tool_name"]] + tool_names = [item["tool_name"] for item in tool_items] + + assert "search_web" in tool_names + assert "calculator" in tool_names + + session.close() + + +async def test_tool_execution_integration(agent: Agent): + """Test integration with actual tool execution.""" + session_id = "tool_integration_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Set up the fake model to trigger a tool call + fake_model = cast(FakeModel, agent.model) + fake_model.set_next_output( + [ + { # type: ignore + "type": "function_call", + "name": "test_tool", + "arguments": '{"query": 
"test query"}', + "call_id": "call_123", + } + ] + ) + + # Then set the final response + fake_model.set_next_output([get_text_message("Tool executed successfully")]) + + # Run the agent + result = await Runner.run( + agent, + "Please use the test tool", + session=session, + ) + + # Verify the tool was executed + assert "Tool result for: test query" in str(result.new_items) + + # Verify tool usage was tracked + tool_usage = await session.get_tool_usage() + assert len(tool_usage) > 0 + + session.close()