diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py index 059ca57ab..7f3b45dba 100644 --- a/src/agents/memory/__init__.py +++ b/src/agents/memory/__init__.py @@ -1,3 +1,4 @@ from .session import Session, SQLiteSession +from .util import SessionInputHandler, SessionMixerCallable -__all__ = ["Session", "SQLiteSession"] +__all__ = ["Session", "SessionInputHandler", "SessionMixerCallable", "SQLiteSession"] diff --git a/src/agents/memory/util.py b/src/agents/memory/util.py new file mode 100644 index 000000000..e530b2f30 --- /dev/null +++ b/src/agents/memory/util.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import Callable, Union + +from ..items import TResponseInputItem +from ..util._types import MaybeAwaitable + +SessionMixerCallable = Callable[ + [list[TResponseInputItem], list[TResponseInputItem]], + MaybeAwaitable[list[TResponseInputItem]], +] +"""A function that combines session history with new input items. + +Args: + history_items: The list of items from the session history. + new_items: The list of new input items for the current turn. + +Returns: + A list of combined items to be used as input for the agent. Can be sync or async. +""" + + +SessionInputHandler = Union[SessionMixerCallable, None] +"""Defines how to handle session history when new input is provided. + +- `None` (default): The new input is appended to the session history. +- `SessionMixerCallable`: A custom function that receives the history and new input, and + returns the desired combined list of items. +""" diff --git a/src/agents/run.py b/src/agents/run.py index 2dd9524bb..3646a380a 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -44,7 +44,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks from .logger import logger -from .memory import Session +from .memory import Session, SessionInputHandler from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -139,6 +139,14 @@ class RunConfig: An optional dictionary of additional metadata to include with the trace. """ + session_input_callback: SessionInputHandler = None + """Defines how to handle session history when new input is provided. + + - `None` (default): The new input is appended to the session history. + - `SessionMixerCallable`: A custom function that receives the history and new input, and + returns the desired combined list of items. + """ + class RunOptions(TypedDict, Generic[TContext]): """Arguments for ``AgentRunner`` methods.""" @@ -343,7 +351,9 @@ async def run( run_config = RunConfig() # Prepare input with session if enabled - prepared_input = await self._prepare_input_with_session(input, session) + prepared_input = await self._prepare_input_with_session( + input, session, run_config.session_input_callback + ) tool_use_tracker = AgentToolUseTracker() @@ -662,7 +672,9 @@ async def _start_streaming( try: # Prepare input with session if enabled - prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session) + prepared_input = await AgentRunner._prepare_input_with_session( + starting_input, session, run_config.session_input_callback + ) # Update the streamed result with the prepared input streamed_result.input = prepared_input @@ -1191,18 +1203,18 @@ async def _prepare_input_with_session( cls, input: str | list[TResponseInputItem], session: Session | None, + session_input_callback: SessionInputHandler, ) -> str | list[TResponseInputItem]: """Prepare input by combining it with session history if enabled.""" if session is None: return input - # Validate that we don't have both a session and a list input, as this creates - # ambiguity about whether the list should append to or replace existing session history - if isinstance(input, list): + # If the user doesn't explicitly specify a mode, raise an error + if isinstance(input, list) and not session_input_callback: raise UserError( - "Cannot provide both a session and a list of input items. " - "When using session memory, provide only a string input to append to the " - "conversation, or use session=None and provide a list to manually manage " + "You must specify the `session_input_callback` in the `RunConfig`. " + "Otherwise, when using session memory, provide only a string input to append to " + "the conversation, or use session=None and provide a list to manually manage " "conversation history." ) @@ -1212,10 +1224,18 @@ async def _prepare_input_with_session( # Convert input to list format new_input_list = ItemHelpers.input_to_new_input_list(input) - # Combine history with new input - combined_input = history + new_input_list - - return combined_input + if session_input_callback is None: + return history + new_input_list + elif callable(session_input_callback): + res = session_input_callback(history, new_input_list) + if inspect.isawaitable(res): + return await res + return res + else: + raise UserError( + f"Invalid `session_input_callback` value: {session_input_callback}. " + "Choose between `None` or a custom callable function." + ) @classmethod async def _save_result_to_session( @@ -1224,7 +1244,11 @@ async def _save_result_to_session( original_input: str | list[TResponseInputItem], result: RunResult, ) -> None: - """Save the conversation turn to session.""" + """ + Save the conversation turn to session. + It does not account for any filtering or modification performed by + `RunConfig.session_input_handling`. + """ if session is None: return diff --git a/tests/test_session.py b/tests/test_session.py index 032f2bb38..1cfc62a92 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,7 +6,7 @@ import pytest -from agents import Agent, Runner, SQLiteSession, TResponseInputItem +from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem from agents.exceptions import UserError from .fake_model import FakeModel @@ -394,7 +394,55 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method) await run_agent_async(runner_method, agent, list_input, session=session) # Verify the error message explains the issue - assert "Cannot provide both a session and a list of input items" in str(exc_info.value) + assert "You must specify the `session_input_callback` in" in str(exc_info.value) assert "manually manage conversation history" in str(exc_info.value) session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_callback_prepared_input(runner_method): + """Test if the user passes a list of items and want to append them.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Session + session_id = "session_1" + session = SQLiteSession(session_id, db_path) + + # Add first messages manually + initial_history: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello there."}, + {"role": "assistant", "content": "Hi, I'm here to assist you."}, + ] + await session.add_items(initial_history) + + def filter_assistant_messages(history, new_input): + # Only include user messages from history + return [item for item in history if item["role"] == "user"] + new_input + + new_turn_input = [{"role": "user", "content": "What your name?"}] + model.set_next_output([get_text_message("I'm gpt-4o")]) + + # Run the agent with the callable + await run_agent_async( + runner_method, + agent, + new_turn_input, + session=session, + run_config=RunConfig(session_input_callback=filter_assistant_messages), + ) + + expected_model_input = [ + initial_history[0], # From history + new_turn_input[0], # New input + ] + + assert len(model.last_turn_args["input"]) == 2 + assert model.last_turn_args["input"] == expected_model_input + + session.close()