diff --git a/docs/sessions.md b/docs/sessions.md
index c66cb85ae..1cfe6cd4d 100644
--- a/docs/sessions.md
+++ b/docs/sessions.md
@@ -141,6 +141,46 @@ result = await Runner.run(
 )
 ```
 
+### Structured metadata
+
+By default, `SQLiteSession` stores all conversation events as JSON blobs in a single messages table. You can enable structured metadata to create additional tables for messages, tool calls, and per-response usage:
+
+```python
+from agents import SQLiteSession
+
+# Enable structured metadata storage
+session = SQLiteSession(
+    "user_123",
+    "conversations.db",
+    structured_metadata=True,
+)
+
+# This creates additional tables:
+# - agent_conversation_messages: stores user, assistant, and system messages
+# - agent_tool_calls: stores tool call requests and outputs
+# - agent_usage: stores per-response usage (model name, token counts) with trace/span attribution
+```
+
+With structured metadata enabled, you can query conversations and usage using standard SQL:
+
+```sql
+-- Get all user messages in a session
+SELECT content FROM agent_conversation_messages
+WHERE session_id = 'user_123' AND role = 'user';
+
+-- Get all tool calls and their results
+SELECT tool_name, arguments, output, status
+FROM agent_tool_calls
+WHERE session_id = 'user_123';
+
+-- Inspect usage records (model, token counts) and spans
+SELECT response_id, model, requests, input_tokens, output_tokens, total_tokens,
+       trace_id, span_id, created_at
+FROM agent_usage
+WHERE session_id = 'user_123'
+ORDER BY created_at DESC;
+```
+
 ### Multiple sessions
 
 ```python
diff --git a/examples/basic/structured_metadata_session.py b/examples/basic/structured_metadata_session.py
new file mode 100644
index 000000000..8400e20d9
--- /dev/null
+++ b/examples/basic/structured_metadata_session.py
@@ -0,0 +1,133 @@
+"""A script to test and demonstrate the structured metadata session storage feature."""
+
+import asyncio
+import os
+import random
+import sqlite3
+import sys
+
+# Add the parent directory to the path to import from the local package
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+
+from agents import Agent, Runner, SQLiteSession, function_tool
+
+
+async def main():
+    # Create a tool
+    @function_tool
+    def get_random_number(max_val: int) -> int:
+        """Get a random number between 0 and max_val."""
+        return random.randint(0, max_val)
+
+    # Create an agent
+    agent = Agent(
+        name="Assistant",
+        instructions="Reply very concisely. When using tools, explain what you're doing.",
+        tools=[get_random_number],
+    )
+
+    # Create a session with structured storage enabled
+    db_path = "structured_conversation_demo.db"
+    session = SQLiteSession("demo_session", db_path, structured_metadata=True)
+
+    print("=== Structured Session Storage Demo ===")
+    print("This demo shows structured storage that makes conversations easy to query.\n")
+
+    # First turn
+    print("First turn:")
+    print("User: Pick 3 random numbers between 0 and 100")
+    result = await Runner.run(agent, "Pick 3 random numbers between 0 and 100", session=session)
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Second turn - the agent will remember the previous conversation
+    print("Second turn:")
+    print("User: What numbers did you pick for me?")
+    result = await Runner.run(agent, "What numbers did you pick for me?", session=session)
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Third turn - another tool call
+    print("Third turn:")
+    print("User: Now pick one more number between 0 and 50")
+    result = await Runner.run(agent, "Now pick one more number between 0 and 50", session=session)
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    print("=== Conversation Complete ===")
+    print(f"Data stored in: {db_path}")
+    print()
+
+    # Now demonstrate the structured storage benefits
+    print("=== Structured Storage Analysis ===")
+    print("With structured storage, you can easily query the conversation:")
+    print()
+
+    conn = sqlite3.connect(db_path)
+
+    # Show all messages
+    print("1. All conversation messages:")
+    cursor = conn.execute("""
+        SELECT role, content FROM agent_conversation_messages
+        WHERE session_id = 'demo_session'
+        ORDER BY created_at
+    """)
+    for role, content in cursor.fetchall():
+        content_preview = content[:60] + "..." if len(content) > 60 else content
+        print(f"  {role}: {content_preview}")
+    print()
+
+    # Show all tool calls
+    print("2. All tool calls and results:")
+    cursor = conn.execute("""
+        SELECT tool_name, arguments, output, status
+        FROM agent_tool_calls
+        WHERE session_id = 'demo_session'
+        ORDER BY created_at
+    """)
+    for tool_name, arguments, output, status in cursor.fetchall():
+        print(f"  Tool: {tool_name}")
+        print(f"    Args: {arguments}")
+        print(f"    Result: {output}")
+        print(f"    Status: {status}")
+    print()
+
+    # Show message count by role
+    print("3. Message count by role:")
+    cursor = conn.execute("""
+        SELECT role, COUNT(*) as count
+        FROM agent_conversation_messages
+        WHERE session_id = 'demo_session'
+        GROUP BY role
+    """)
+    for role, count in cursor.fetchall():
+        print(f"  {role}: {count} messages")
+    print()
+
+    # Show usage rows with model and spans
+    print("4. Usage records (per model response):")
+    cursor = conn.execute(
+        """
+        SELECT response_id, model, requests, input_tokens, output_tokens, total_tokens, trace_id, span_id, created_at
+        FROM agent_usage
+        WHERE session_id = 'demo_session'
+        ORDER BY created_at
+        """
+    )
+    usage_rows = cursor.fetchall()
+    if not usage_rows:
+        print("  (no usage rows found — ensure your model/provider returns usage)")
+    for row in usage_rows:
+        response_id, model, requests, in_toks, out_toks, total, trace_id, span_id, created_at = row
+        print(
+            f"  model={model} resp_id={response_id} reqs={requests} in={in_toks} out={out_toks} total={total}"
+        )
+        print(f"    trace={trace_id} span={span_id} at={created_at}")
+    print()
+
+    conn.close()
+    session.close()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py
index 73fcf3e56..04f3def43 100644
--- a/examples/realtime/app/server.py
+++ b/examples/realtime/app/server.py
@@ -4,11 +4,12 @@
 import logging
 import struct
 from contextlib import asynccontextmanager
-from typing import TYPE_CHECKING, Any, assert_never
+from typing import TYPE_CHECKING, Any
 
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import FileResponse
 from fastapi.staticfiles import StaticFiles
+from typing_extensions import assert_never
 
 from agents.realtime import RealtimeRunner, RealtimeSession, RealtimeSessionEvent
 
diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py
index 6c417b308..166631ee4 100644
--- a/src/agents/_run_impl.py
+++ b/src/agents/_run_impl.py
@@ -548,6 +548,14 @@ async def run_single_tool(
             func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
         ) -> Any:
             with function_span(func_tool.name) as span_fn:
+                # Register span info so session storage can attribute this tool call
+                try:
+                    # noqa: WPS433 import inside to avoid circular dependency
+                    from .memory.session import register_tool_call_span
+
+                    register_tool_call_span(tool_call.call_id, span_fn.trace_id, span_fn.span_id)
+                except Exception:
+                    pass  # Non-critical
                 tool_context = ToolContext.from_agent_context(
                     context_wrapper,
                     tool_call.call_id,
diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py
index 8db0971eb..446e40e2f 100644
--- a/src/agents/memory/session.py
+++ b/src/agents/memory/session.py
@@ -9,7 +9,72 @@
 from typing import TYPE_CHECKING, Protocol, runtime_checkable
 
 if TYPE_CHECKING:
-    from ..items import TResponseInputItem
+    from ..items import ModelResponse, TResponseInputItem
+
+from ..tracing import get_current_span
+
+# Registry mapping tool call IDs to their exact function span (trace_id, span_id)
+_TOOL_CALL_SPAN_REGISTRY: dict[str, tuple[str | None, str | None]] = {}
+
+# Registry mapping response IDs to their model span (trace_id, span_id)
+_RESPONSE_SPAN_REGISTRY: dict[str, tuple[str | None, str | None]] = {}
+
+# Registry mapping trace_id to the "last" model response span seen in that trace
+_LAST_RESPONSE_SPAN_BY_TRACE: dict[str | None, tuple[str | None, str | None]] = {}
+
+
+def register_tool_call_span(call_id: str, trace_id: str | None, span_id: str | None) -> None:
+    """Registers a mapping between a tool-call ID and the span that executed it."""
+    _TOOL_CALL_SPAN_REGISTRY[call_id] = (trace_id, span_id)
+
+
+def pop_tool_call_span(call_id: str) -> tuple[str | None, str | None] | None:
+    """Retrieve & remove a span mapping for the given tool-call ID, if present."""
+    return _TOOL_CALL_SPAN_REGISTRY.pop(call_id, None)
+
+
+def register_response_span(
+    response_id: str | None, trace_id: str | None, span_id: str | None
+) -> None:  # noqa: E501
+    """Registers a mapping between a model response ID and its response/generation span.
+
+    If response_id is None (provider doesn't return one), only the per-trace cache is updated.
+    """
+    _LAST_RESPONSE_SPAN_BY_TRACE[trace_id] = (trace_id, span_id)
+    if response_id:
+        _RESPONSE_SPAN_REGISTRY[response_id] = (trace_id, span_id)
+
+
+def get_response_span(response_id: str) -> tuple[str | None, str | None] | None:
+    """Retrieve a span mapping for the given response ID, if present."""
+    return _RESPONSE_SPAN_REGISTRY.get(response_id)
+
+
+def get_last_response_span_for_trace(trace_id: str | None) -> tuple[str | None, str | None] | None:
+    """Retrieve the last seen model response span for the given trace ID, if present."""
+    return _LAST_RESPONSE_SPAN_BY_TRACE.get(trace_id)
+
+
+# Registry for model names (by response_id and by trace)
+_RESPONSE_MODEL_REGISTRY: dict[str, str | None] = {}
+_LAST_MODEL_BY_TRACE: dict[str | None, str | None] = {}
+
+
+def register_response_model(
+    response_id: str | None, trace_id: str | None, model: str | None
+) -> None:  # noqa: E501
+    """Registers a mapping for model names by response_id and by trace."""
+    _LAST_MODEL_BY_TRACE[trace_id] = model
+    if response_id:
+        _RESPONSE_MODEL_REGISTRY[response_id] = model
+
+
+def get_response_model(response_id: str) -> str | None:
+    return _RESPONSE_MODEL_REGISTRY.get(response_id)
+
+
+def get_last_model_for_trace(trace_id: str | None) -> str | None:
+    return _LAST_MODEL_BY_TRACE.get(trace_id)
 
 
 @runtime_checkable
@@ -118,6 +183,11 @@ def __init__(
         db_path: str | Path = ":memory:",
         sessions_table: str = "agent_sessions",
         messages_table: str = "agent_messages",
+        *,
+        structured_metadata: bool = False,
+        conversation_table: str = "agent_conversation_messages",
+        tool_calls_table: str = "agent_tool_calls",
+        usage_table: str = "agent_usage",
     ):
         """Initialize the SQLite session.
 
@@ -127,11 +197,23 @@ def __init__(
             sessions_table: Name of the table to store session metadata.
                 Defaults to 'agent_sessions'
             messages_table: Name of the table to store message data.
                 Defaults to 'agent_messages'
+            structured_metadata: If True, enables structured storage mode, creating
+                additional tables for messages and tool calls. Defaults to False.
+            conversation_table: Name for the structured conversation messages table.
+                Defaults to 'agent_conversation_messages'.
+            tool_calls_table: Name for the structured tool calls table.
+                Defaults to 'agent_tool_calls'.
+            usage_table: Name for the structured usage table.
+                Defaults to 'agent_usage'.
""" self.session_id = session_id self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table + self.structured_metadata = structured_metadata + self.conversation_table = conversation_table + self.tool_calls_table = tool_calls_table + self.usage_table = usage_table self._local = threading.local() self._lock = threading.Lock() @@ -141,11 +223,13 @@ def __init__( if self._is_memory_db: self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False) self._shared_connection.execute("PRAGMA journal_mode=WAL") + self._shared_connection.execute("PRAGMA foreign_keys=ON") self._init_db_for_connection(self._shared_connection) else: # For file databases, initialize the schema once since it persists init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) init_conn.execute("PRAGMA journal_mode=WAL") + init_conn.execute("PRAGMA foreign_keys=ON") self._init_db_for_connection(init_conn) init_conn.close() @@ -162,6 +246,7 @@ def _get_connection(self) -> sqlite3.Connection: check_same_thread=False, ) self._local.connection.execute("PRAGMA journal_mode=WAL") + self._local.connection.execute("PRAGMA foreign_keys=ON") assert isinstance(self._local.connection, sqlite3.Connection), ( f"Expected sqlite3.Connection, got {type(self._local.connection)}" ) @@ -201,6 +286,127 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: conn.commit() + # Create additional structured tables if enabled + if self.structured_metadata: + # Conversation messages table + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.conversation_table} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + raw_event_id INTEGER NOT NULL, + role TEXT, + content TEXT, + parent_raw_event_id INTEGER, + trace_id TEXT, + span_id TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) + ON DELETE CASCADE, + FOREIGN KEY (raw_event_id) REFERENCES {self.messages_table} (id) + ON DELETE CASCADE + ) + """ + ) + + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.conversation_table}_session_id + ON {self.conversation_table} (session_id, created_at) + """ + ) + + # Tool calls table + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.tool_calls_table} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + raw_event_id INTEGER NOT NULL, + call_id TEXT, + tool_name TEXT, + arguments JSON, + output JSON, + status TEXT, + trace_id TEXT, + span_id TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) + ON DELETE CASCADE, + FOREIGN KEY (raw_event_id) REFERENCES {self.messages_table} (id) + ON DELETE CASCADE + ) + """ + ) + + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.tool_calls_table}_session_id + ON {self.tool_calls_table} (session_id, created_at) + """ + ) + + # Usage table (per LLM response) + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.usage_table} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + response_id TEXT, + model TEXT, + requests INTEGER, + input_tokens INTEGER, + output_tokens INTEGER, + total_tokens INTEGER, + input_tokens_details JSON, + output_tokens_details JSON, + trace_id TEXT, + span_id TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) + ON DELETE CASCADE + ) + """ + ) + + # Indexes for faster queries + conn.execute( + f""" + CREATE 
+                CREATE INDEX IF NOT EXISTS idx_{self.conversation_table}_trace
+                ON {self.conversation_table} (trace_id, created_at)
+            """
+            )
+            conn.execute(
+                f"""
+                CREATE INDEX IF NOT EXISTS idx_{self.conversation_table}_span
+                ON {self.conversation_table} (span_id, created_at)
+            """
+            )
+            conn.execute(
+                f"""
+                CREATE INDEX IF NOT EXISTS idx_{self.tool_calls_table}_trace
+                ON {self.tool_calls_table} (trace_id, created_at)
+            """
+            )
+            conn.execute(
+                f"""
+                CREATE INDEX IF NOT EXISTS idx_{self.tool_calls_table}_span
+                ON {self.tool_calls_table} (span_id, created_at)
+            """
+            )
+            conn.execute(
+                f"""
+                CREATE INDEX IF NOT EXISTS idx_{self.usage_table}_trace
+                ON {self.usage_table} (trace_id, created_at)
+            """
+            )
+            conn.execute(
+                f"""
+                CREATE INDEX IF NOT EXISTS idx_{self.usage_table}_response
+                ON {self.usage_table} (response_id)
+            """
+            )
+
     async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
         """Retrieve the conversation history for this session.
 
@@ -278,13 +484,131 @@ def _add_items_sync():
                 )
 
                 # Add items
-                message_data = [(self.session_id, json.dumps(item)) for item in items]
-                conn.executemany(
-                    f"""
-                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
-                    """,
-                    message_data,
-                )
+                if not self.structured_metadata:
+                    # Flat storage: bulk insert for performance
+                    message_data = [(self.session_id, json.dumps(item)) for item in items]
+                    conn.executemany(
+                        f"""
+                        INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
+                        """,
+                        message_data,
+                    )
+                else:
+                    # Structured storage: insert each item individually so we can capture rowid
+                    current_span = get_current_span()
+                    _trace_id = current_span.trace_id if current_span else None
+                    _span_id = current_span.span_id if current_span else None
+
+                    last_user_raw_event_id: int | None = None
+                    assistant_seen_count = 0
+                    for item in items:
+                        raw_json = json.dumps(item)
+                        cursor = conn.execute(
+                            f"""
+                            INSERT INTO {self.messages_table} (session_id, message_data)
+                            VALUES (?, ?)
+                            RETURNING id
+                            """,
+                            (self.session_id, raw_json),
+                        )
+                        raw_event_id = cursor.fetchone()[0]
+
+                        # Handle structured inserts
+                        if "role" in item:
+                            role = item.get("role")
+                            content_val = item.get("content")
+                            try:
+                                content_str = (
+                                    json.dumps(content_val) if content_val is not None else None
+                                )
+                            except TypeError:
+                                content_str = str(content_val)
+
+                            parent_raw_event_id = (
+                                last_user_raw_event_id if role == "assistant" else None
+                            )
+
+                            # Attribute assistant messages to the model response span if available
+                            _msg_trace_id = _trace_id
+                            _msg_span_id = _span_id
+                            if role == "assistant":
+                                try:
+                                    maybe_span = get_last_response_span_for_trace(_trace_id)
+                                    if maybe_span:
+                                        _msg_trace_id, _msg_span_id = maybe_span
+                                except Exception:
+                                    pass
+
+                            conn.execute(
+                                f"""
+                                INSERT INTO {self.conversation_table}
+                                (
+                                    session_id, raw_event_id, role, content,
+                                    parent_raw_event_id, trace_id, span_id
+                                ) VALUES (?, ?, ?, ?, ?, ?, ?)
+                                """,
+                                (
+                                    self.session_id,
+                                    raw_event_id,
+                                    role,
+                                    content_str,
+                                    parent_raw_event_id,
+                                    _msg_trace_id,
+                                    _msg_span_id,
+                                ),
+                            )
+
+                            if role == "user":
+                                last_user_raw_event_id = raw_event_id
+                            elif role == "assistant":
+                                assistant_seen_count += 1
+
+                        event_type = item.get("type")
+                        if event_type == "function_call":
+                            call_id = item.get("call_id")
+                            tool_name = item.get("name")
+                            arguments_val = item.get("arguments")
+                            # If a precise function-span mapping exists, use it
+                            _tc_trace_id = _trace_id
+                            _tc_span_id = _span_id
+                            if call_id:
+                                mapped = pop_tool_call_span(str(call_id))
+                                if mapped:
+                                    _tc_trace_id, _tc_span_id = mapped
+                            conn.execute(
+                                f"""
+                                INSERT INTO {self.tool_calls_table}
+                                (
+                                    session_id, raw_event_id, call_id, tool_name,
+                                    arguments, status, trace_id, span_id
+                                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+                                """,
+                                (
+                                    self.session_id,
+                                    raw_event_id,
+                                    call_id,
+                                    tool_name,
+                                    arguments_val,
+                                    item.get("status"),
+                                    _tc_trace_id,
+                                    _tc_span_id,
+                                ),
+                            )
+                        elif event_type == "function_call_output":
+                            call_id = item.get("call_id")
+                            output_val = item.get("output")
+                            conn.execute(
+                                f"""
+                                UPDATE {self.tool_calls_table}
+                                SET output = ?, status = 'completed'
+                                WHERE session_id = ? AND call_id = ?
+                                """,
+                                (
+                                    json.dumps(output_val) if output_val is not None else None,
+                                    self.session_id,
+                                    call_id,
+                                ),
+                            )
 
                 # Update session timestamp
                 conn.execute(
@@ -300,6 +624,103 @@ def _add_items_sync():
 
         await asyncio.to_thread(_add_items_sync)
 
+    async def add_usage_records(self, responses: list[ModelResponse]) -> None:
+        """Optionally store usage rows for a set of model responses.
+
+        Best-effort and only active when structured_metadata=True. It is safe to call even if
+        structured_metadata=False.
+        """
+        if not self.structured_metadata or not responses:
+            return
+
+        def _add_usage_sync():
+            conn = self._get_connection()
+            with self._lock if self._is_memory_db else threading.Lock():
+                current_span = get_current_span()
+                _trace_id = current_span.trace_id if current_span else None
+                _span_id = current_span.span_id if current_span else None
+
+                def _to_json_text(obj: object | None) -> str | None:
+                    if obj is None:
+                        return None
+                    try:
+                        return json.dumps(obj)
+                    except TypeError:
+                        # Try common object-to-dict conversions (e.g., Pydantic models)
+                        try:
+                            if hasattr(obj, "model_dump"):
+                                return json.dumps(obj.model_dump())
+                            if hasattr(obj, "dict"):
+                                return json.dumps(obj.dict())
+                            if hasattr(obj, "__dict__"):
+                                return json.dumps(obj.__dict__)
+                        except Exception:
+                            pass
+                        # Fallback to string representation
+                        return json.dumps(str(obj))
+
+                for resp in responses:
+                    usage = getattr(resp, "usage", None)
+                    response_id = getattr(resp, "response_id", None)
+                    if usage is None:
+                        continue
+
+                    # Details may not be JSON-serializable; store as JSON-encoded strings
+                    input_details = _to_json_text(getattr(usage, "input_tokens_details", None))
+                    output_details = _to_json_text(getattr(usage, "output_tokens_details", None))
+
+                    # Prefer the precise response span if available
+                    _usage_trace_id = _trace_id
+                    _usage_span_id = _span_id
+                    try:
+                        if response_id is not None:
+                            mapped = get_response_span(response_id)
+                            if mapped:
+                                _usage_trace_id, _usage_span_id = mapped
+                        else:
+                            maybe = get_last_response_span_for_trace(_trace_id)
+                            if maybe:
+                                _usage_trace_id, _usage_span_id = maybe
+                    except Exception:
+                        pass
+
+                    # Prefer model in response_id; fall back to last seen model for this trace.
+                    _model_name: str | None = None
+                    try:
+                        if response_id is not None:
+                            _model_name = get_response_model(response_id)
+                        if _model_name is None:
+                            _model_name = get_last_model_for_trace(_usage_trace_id)
+                    except Exception:
+                        pass
+
+                    conn.execute(
+                        f"""
+                        INSERT INTO {self.usage_table} (
+                            session_id, response_id, model, requests, input_tokens,
+                            output_tokens, total_tokens, input_tokens_details,
+                            output_tokens_details, trace_id, span_id
+                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                        """,
+                        (
+                            self.session_id,
+                            response_id,
+                            _model_name,
+                            getattr(usage, "requests", None),
+                            getattr(usage, "input_tokens", None),
+                            getattr(usage, "output_tokens", None),
+                            getattr(usage, "total_tokens", None),
+                            input_details,
+                            output_details,
+                            _usage_trace_id,
+                            _usage_span_id,
+                        ),
+                    )
+
+                conn.commit()
+
+        await asyncio.to_thread(_add_usage_sync)
+
     async def pop_item(self) -> TResponseInputItem | None:
         """Remove and return the most recent item from the session.
 
@@ -326,6 +747,7 @@ def _pop_item_sync():
                 )
 
                 result = cursor.fetchone()
+                conn.commit()
 
                 if result:
 
@@ -334,7 +756,6 @@ def _pop_item_sync():
                         item = json.loads(message_data)
                         return item
                     except json.JSONDecodeError:
-                        # Return None for corrupted JSON entries (already deleted)
                        return None
 
                 return None
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index c6d1d7d22..68affbe80 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -127,6 +127,24 @@ async def get_response(
 
             items = Converter.message_to_output_items(message) if message is not None else []
 
+            # Register this generation span and model so sessions can attribute rows correctly
+            try:
+                # noqa: WPS433 import inside to avoid circular dependency
+                from ..memory.session import register_response_model, register_response_span
+
+                register_response_span(
+                    response_id=None,  # Chat Completions doesn't produce a Response ID
+                    trace_id=span_generation.trace_id,
+                    span_id=span_generation.span_id,
+                )
+                register_response_model(
+                    response_id=None,
+                    trace_id=span_generation.trace_id,
+                    model=str(self.model) if self.model is not None else None,
+                )
+            except Exception:
+                pass
+
             return ModelResponse(
                 output=items,
                 usage=usage,
@@ -182,6 +200,23 @@ async def stream_response(
                     "output_tokens": final_response.usage.output_tokens,
                 }
 
+            # Register this generation span and model as last seen for the current trace
+            try:
+                from ..memory.session import register_response_model, register_response_span
+
+                register_response_span(
+                    response_id=None,
+                    trace_id=span_generation.trace_id,
+                    span_id=span_generation.span_id,
+                )
+                register_response_model(
+                    response_id=None,
+                    trace_id=span_generation.trace_id,
+                    model=str(self.model) if self.model is not None else None,
+                )
+            except Exception:
+                pass
+
     @overload
     async def _fetch_response(
         self,
diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index 4352c99c7..87f250317 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -122,6 +122,23 @@ async def get_response(
                 if tracing.include_data():
                     span_response.span_data.response = response
                     span_response.span_data.input = input
+                # Register this model response span so sessions can attribute rows correctly
+                try:
+                    # noqa: WPS433 import inside to avoid circular dependency
+                    from ..memory.session import register_response_model, register_response_span
+
+                    register_response_span(
+                        response_id=response.id,
+                        trace_id=span_response.trace_id,
+                        span_id=span_response.span_id,
+                    )
+                    register_response_model(
+                        response_id=response.id,
+                        trace_id=span_response.trace_id,
+                        model=str(self.model) if self.model is not None else None,
+                    )
+                except Exception:
+                    pass
             except Exception as e:
                 span_response.set_error(
                     SpanError(
@@ -180,6 +197,22 @@ async def stream_response(
                 if final_response and tracing.include_data():
                     span_response.span_data.response = final_response
                     span_response.span_data.input = input
+                # Register the span using final response (if any). Some providers omit IDs.
+                try:
+                    from ..memory.session import register_response_model, register_response_span
+
+                    register_response_span(
+                        response_id=(final_response.id if final_response else None),
+                        trace_id=span_response.trace_id,
+                        span_id=span_response.span_id,
+                    )
+                    register_response_model(
+                        response_id=(final_response.id if final_response else None),
+                        trace_id=span_response.trace_id,
+                        model=str(self.model) if self.model is not None else None,
+                    )
+                except Exception:
+                    pass
 
             except Exception as e:
                 span_response.set_error(
diff --git a/src/agents/run.py b/src/agents/run.py
index e63d7751e..0c441e919 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -500,6 +500,13 @@ async def run(
 
                         # Save the conversation to session if enabled
                         await self._save_result_to_session(session, input, result)
+                        # Optionally persist usage if the session supports it (non-breaking)
+                        if session is not None and hasattr(session, "add_usage_records"):
+                            try:
+                                await session.add_usage_records(model_responses)
+                            except Exception:
+                                # Do not fail the run on usage write errors
+                                pass
 
                         return result
                     elif isinstance(turn_result.next_step, NextStepHandoff):
@@ -855,6 +862,12 @@ async def _start_streaming(
                         await AgentRunner._save_result_to_session(
                             session, starting_input, temp_result
                         )
+                        # Optionally persist usage if supported
+                        if session is not None and hasattr(session, "add_usage_records"):
+                            try:
+                                await session.add_usage_records(streamed_result.raw_responses)
+                            except Exception:
+                                pass
 
                         streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
                     elif isinstance(turn_result.next_step, NextStepRunAgain):
diff --git a/tests/test_structured_session.py b/tests/test_structured_session.py
new file mode 100644
index 000000000..11ef18ff2
--- /dev/null
+++ b/tests/test_structured_session.py
@@ -0,0 +1,275 @@
+"""Tests for structured session storage functionality."""
+
+import sqlite3
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from agents import Agent, Runner, SQLiteSession, function_tool
+from agents.items import TResponseInputItem
+
+from .fake_model import FakeModel
+from .test_responses import get_text_message
+
+
+@pytest.mark.asyncio
+async def test_structured_session_creation():
+    """Test that structured session creates the additional tables."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_structured.db"
+        session = SQLiteSession("test_session", db_path, structured_metadata=True)
+
+        # Check that the structured tables were created
+        conn = sqlite3.connect(str(db_path))
+        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
+        tables = [row[0] for row in cursor.fetchall()]
+        conn.close()
+
+        expected_tables = [
+            "agent_conversation_messages",
+            "agent_messages",
+            "agent_sessions",
+            "agent_tool_calls",
+            "agent_usage",
+        ]
+        for table in expected_tables:
+            assert table in tables
+
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_structured_session_disabled_by_default():
+    """Test that structured tables are not created when structured_metadata=False."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_flat.db"
+        session = SQLiteSession("test_session", db_path, structured_metadata=False)
+
+        # Check that only the basic tables were created
+        conn = sqlite3.connect(str(db_path))
+        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
+        tables = [row[0] for row in cursor.fetchall()]
+        conn.close()
+
+        expected_tables = ["agent_messages", "agent_sessions"]
+        for table in expected_tables:
+            assert table in tables
+
+        # Structured tables should not exist
+        assert "agent_conversation_messages" not in tables
+        assert "agent_tool_calls" not in tables
+        assert "agent_usage" not in tables
+
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_structured_session_conversation_flow():
+    """Test a full conversation flow with structured storage."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_conversation.db"
+        session = SQLiteSession("test_session", db_path, structured_metadata=True)
+
+        # Create a simple tool for testing
+        @function_tool
+        def get_test_number(max_val: int = 100) -> int:
+            """Get a test number."""
+            return 42
+
+        model = FakeModel()
+        agent = Agent(name="test", model=model, tools=[get_test_number])
+
+        # Simulate a simple message without tool calls for this test
+        model.set_next_output([get_text_message("I'll pick a random number: 42")])
+
+        await Runner.run(agent, "Pick a random number", session=session)
+
+        # Check that data was stored in structured tables
+        conn = sqlite3.connect(str(db_path))
+
+        # Check conversation messages table
+        cursor = conn.execute(
+            """SELECT role, content FROM agent_conversation_messages
+               WHERE session_id = ? ORDER BY created_at""",
+            ("test_session",),
+        )
+        conversation_rows = cursor.fetchall()
+
+        # Should have user message and potentially assistant message
+        assert len(conversation_rows) >= 1
+        assert conversation_rows[0][0] == "user"  # First should be user role
+        assert "Pick a random number" in conversation_rows[0][1]
+
+        # Check tool calls table (should be empty for this simple message test)
+        cursor = conn.execute(
+            "SELECT COUNT(*) FROM agent_tool_calls WHERE session_id = ?", ("test_session",)
+        )
+        tool_call_count = cursor.fetchone()[0]
+        assert tool_call_count == 0  # No tool calls in this simple test
+
+        # Usage table exists; rows may be 0 depending on provider, but schema should be present
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='agent_usage'"
+        )
+        assert cursor.fetchone() is not None
+
+        conn.close()
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_structured_session_backward_compatibility():
+    """Test that structured_metadata=True doesn't break existing functionality."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_compat.db"
+        session = SQLiteSession("test_session", db_path, structured_metadata=True)
+
+        model = FakeModel()
+        agent = Agent(name="test", model=model)
+
+        # First turn
+        model.set_next_output([get_text_message("Hello!")])
+        result1 = await Runner.run(agent, "Hi there", session=session)
+        assert result1.final_output == "Hello!"
+
+        # Second turn - should have conversation history
+        model.set_next_output([get_text_message("I remember you said hi")])
+        result2 = await Runner.run(agent, "What did I say?", session=session)
+        assert result2.final_output == "I remember you said hi"
+
+        # Verify conversation history is working
+        items = await session.get_items()
+        assert len(items) >= 2  # Should have multiple items from the conversation
+
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_structured_session_pop_item():
+    """Test that pop_item works correctly with structured storage."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_pop.db"
+        session = SQLiteSession("test_session", db_path, structured_metadata=True)
+
+        # Add some test items
+        items: list[TResponseInputItem] = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        await session.add_items(items)
+
+        # Pop the last item
+        popped = await session.pop_item()
+        assert popped is not None
+        assert popped.get("role") == "assistant"
+        assert popped.get("content") == "Hi there!"
+
+        # Check that structured tables are also cleaned up
+        conn = sqlite3.connect(str(db_path))
+        cursor = conn.execute(
+            "SELECT COUNT(*) FROM agent_conversation_messages WHERE session_id = ?",
+            ("test_session",),
+        )
+        count = cursor.fetchone()[0]
+        conn.close()
+
+        # Should only have 1 message left (the user message)
+        assert count == 1
+
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_structured_session_clear():
+    """Test that clear_session works correctly with structured storage."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_clear.db"
+        session = SQLiteSession("test_session", db_path, structured_metadata=True)
+
+        # Add some test items
+        items: list[TResponseInputItem] = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+            {
+                "type": "function_call",
+                "call_id": "call_123",
+                "name": "test_tool",
+                "arguments": '{"param": "value"}',
+                "status": "completed",
+            },
+        ]
+        await session.add_items(items)
+
+        # Clear the session
+        await session.clear_session()
+
+        # Check that all tables are empty
+        conn = sqlite3.connect(str(db_path))
+
+        cursor = conn.execute(
+            "SELECT COUNT(*) FROM agent_messages WHERE session_id = ?", ("test_session",)
+        )
+        assert cursor.fetchone()[0] == 0
+
+        cursor = conn.execute(
+            "SELECT COUNT(*) FROM agent_conversation_messages WHERE session_id = ?",
+            ("test_session",),
+        )
+        assert cursor.fetchone()[0] == 0
+
+        cursor = conn.execute(
+            "SELECT COUNT(*) FROM agent_tool_calls WHERE session_id = ?", ("test_session",)
+        )
+        assert cursor.fetchone()[0] == 0
+
+        cursor = conn.execute(
+            "SELECT COUNT(*) FROM agent_usage WHERE session_id = ?", ("test_session",)
+        )
+        assert cursor.fetchone()[0] == 0
+
+        conn.close()
+        session.close()
+
+
+@pytest.mark.asyncio
+async def test_flat_vs_structured_storage_equivalence():
+    """Test that flat and structured storage produce equivalent get_items results."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path_flat = Path(temp_dir) / "test_flat.db"
+        db_path_structured = Path(temp_dir) / "test_structured.db"
+
+        session_flat = SQLiteSession("test_session", db_path_flat, structured_metadata=False)
+        session_structured = SQLiteSession(
+            "test_session",
+            db_path_structured,
+            structured_metadata=True,
+        )
+
+        # Add the same items to both sessions
+        items: list[TResponseInputItem] = [
+            {"role": "user", "content": "Hello"},
"assistant", "content": "Hi there!"}, + { + "type": "function_call", + "call_id": "call_123", + "name": "test_tool", + "arguments": '{"param": "value"}', + "status": "completed", + }, + {"type": "function_call_output", "call_id": "call_123", "output": "result"}, + ] + + await session_flat.add_items(items) + await session_structured.add_items(items) + + # Get items from both sessions + items_flat = await session_flat.get_items() + items_structured = await session_structured.get_items() + + # Should be identical + assert len(items_flat) == len(items_structured) + assert items_flat == items_structured + + session_flat.close() + session_structured.close()