diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py index eeb2ace5d..1db1598ac 100644 --- a/src/agents/memory/__init__.py +++ b/src/agents/memory/__init__.py @@ -1,10 +1,12 @@ from .openai_conversations_session import OpenAIConversationsSession from .session import Session, SessionABC from .sqlite_session import SQLiteSession +from .util import SessionInputCallback __all__ = [ "Session", "SessionABC", + "SessionInputCallback", "SQLiteSession", "OpenAIConversationsSession", ] diff --git a/src/agents/memory/util.py b/src/agents/memory/util.py new file mode 100644 index 000000000..49f281151 --- /dev/null +++ b/src/agents/memory/util.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import Callable + +from ..items import TResponseInputItem +from ..util._types import MaybeAwaitable + +SessionInputCallback = Callable[ + [list[TResponseInputItem], list[TResponseInputItem]], + MaybeAwaitable[list[TResponseInputItem]], +] +"""A function that combines session history with new input items. + +Args: + history_items: The list of items from the session history. + new_items: The list of new input items for the current turn. + +Returns: + A list of combined items to be used as input for the agent. Can be sync or async. +""" diff --git a/src/agents/run.py b/src/agents/run.py index a77900dff..ee08ad134 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -54,7 +54,7 @@ ) from .lifecycle import RunHooks from .logger import logger -from .memory import Session +from .memory import Session, SessionInputCallback from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -179,6 +179,13 @@ class RunConfig: An optional dictionary of additional metadata to include with the trace. """ + session_input_callback: SessionInputCallback | None = None + """Defines how to handle session history when new input is provided. + - `None` (default): The new input is appended to the session history. + - `SessionInputCallback`: A custom function that receives the history and new input, and + returns the desired combined list of items. + """ + call_model_input_filter: CallModelInputFilter | None = None """ Optional callback that is invoked immediately before calling the model. It receives the current @@ -413,7 +420,9 @@ async def run( # Keep original user input separate from session-prepared input original_user_input = input - prepared_input = await self._prepare_input_with_session(input, session) + prepared_input = await self._prepare_input_with_session( + input, session, run_config.session_input_callback + ) tool_use_tracker = AgentToolUseTracker() @@ -781,7 +790,9 @@ async def _start_streaming( try: # Prepare input with session if enabled - prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session) + prepared_input = await AgentRunner._prepare_input_with_session( + starting_input, session, run_config.session_input_callback + ) # Update the streamed result with the prepared input streamed_result.input = prepared_input @@ -1474,19 +1485,20 @@ async def _prepare_input_with_session( cls, input: str | list[TResponseInputItem], session: Session | None, + session_input_callback: SessionInputCallback | None, ) -> str | list[TResponseInputItem]: """Prepare input by combining it with session history if enabled.""" if session is None: return input - # Validate that we don't have both a session and a list input, as this creates - # ambiguity about whether the list should append to or replace existing session history - if isinstance(input, list): + # If the user doesn't specify an input callback and pass a list as input + if isinstance(input, list) and not session_input_callback: raise UserError( - "Cannot provide both a session and a list of input items. " - "When using session memory, provide only a string input to append to the " - "conversation, or use session=None and provide a list to manually manage " - "conversation history." + "When using session memory, list inputs require a " + "`RunConfig.session_input_callback` to define how they should be merged " + "with the conversation history. If you don't want to use a callback, " + "provide your input as a string instead, or disable session memory " + "(session=None) and pass a list to manage the history manually." ) # Get previous conversation history @@ -1495,10 +1507,18 @@ async def _prepare_input_with_session( # Convert input to list format new_input_list = ItemHelpers.input_to_new_input_list(input) - # Combine history with new input - combined_input = history + new_input_list - - return combined_input + if session_input_callback is None: + return history + new_input_list + elif callable(session_input_callback): + res = session_input_callback(history, new_input_list) + if inspect.isawaitable(res): + return await res + return res + else: + raise UserError( + f"Invalid `session_input_callback` value: {session_input_callback}. " + "Choose between `None` or a custom callable function." + ) @classmethod async def _save_result_to_session( @@ -1507,7 +1527,11 @@ async def _save_result_to_session( original_input: str | list[TResponseInputItem], new_items: list[RunItem], ) -> None: - """Save the conversation turn to session.""" + """ + Save the conversation turn to session. + It does not account for any filtering or modification performed by + `RunConfig.session_input_callback`. + """ if session is None: return diff --git a/tests/test_session.py b/tests/test_session.py index 3b7c4a98c..d249e900d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,7 +6,7 @@ import pytest -from agents import Agent, Runner, SQLiteSession, TResponseInputItem +from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem from agents.exceptions import UserError from .fake_model import FakeModel @@ -394,11 +394,57 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method) await run_agent_async(runner_method, agent, list_input, session=session) # Verify the error message explains the issue - assert "Cannot provide both a session and a list of input items" in str(exc_info.value) - assert "manually manage conversation history" in str(exc_info.value) + assert "list inputs require a `RunConfig.session_input_callback" in str(exc_info.value) + assert "to manage the history manually" in str(exc_info.value) session.close() + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_callback_prepared_input(runner_method): + """Test if the user passes a list of items and want to append them.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Session + session_id = "session_1" + session = SQLiteSession(session_id, db_path) + + # Add first messages manually + initial_history: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello there."}, + {"role": "assistant", "content": "Hi, I'm here to assist you."}, + ] + await session.add_items(initial_history) + + def filter_assistant_messages(history, new_input): + # Only include user messages from history + return [item for item in history if item["role"] == "user"] + new_input + + new_turn_input = [{"role": "user", "content": "What your name?"}] + model.set_next_output([get_text_message("I'm gpt-4o")]) + + # Run the agent with the callable + await run_agent_async( + runner_method, + agent, + new_turn_input, + session=session, + run_config=RunConfig(session_input_callback=filter_assistant_messages), + ) + + expected_model_input = [ + initial_history[0], # From history + new_turn_input[0], # New input + ] + + assert len(model.last_turn_args["input"]) == 2 + assert model.last_turn_args["input"] == expected_model_input + @pytest.mark.asyncio async def test_sqlite_session_unicode_content(): """Test that session correctly stores and retrieves unicode/non-ASCII content."""