diff --git a/src/agents/run.py b/src/agents/run.py index 727927b08..9ac4eed49 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -419,6 +419,9 @@ async def run( current_agent = starting_agent should_run_agent_start_hooks = True + # save the original input to the session if enabled + await self._save_result_to_session(session, original_input, []) + try: while True: all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) @@ -516,9 +519,7 @@ async def run( output_guardrail_results=output_guardrail_results, context_wrapper=context_wrapper, ) - - # Save the conversation to session if enabled - await self._save_result_to_session(session, input, result) + await self._save_result_to_session(session, [], turn_result.new_step_items) return result elif isinstance(turn_result.next_step, NextStepHandoff): @@ -527,7 +528,7 @@ async def run( current_span = None should_run_agent_start_hooks = True elif isinstance(turn_result.next_step, NextStepRunAgain): - pass + await self._save_result_to_session(session, [], turn_result.new_step_items) else: raise AgentsException( f"Unknown next step type: {type(turn_result.next_step)}" @@ -758,6 +759,8 @@ async def _start_streaming( # Update the streamed result with the prepared input streamed_result.input = prepared_input + await AgentRunner._save_result_to_session(session, starting_input, []) + while True: if streamed_result.is_complete: break @@ -860,24 +863,15 @@ async def _start_streaming( streamed_result.is_complete = True # Save the conversation to session if enabled - # Create a temporary RunResult for session saving - temp_result = RunResult( - input=streamed_result.input, - new_items=streamed_result.new_items, - raw_responses=streamed_result.raw_responses, - final_output=streamed_result.final_output, - _last_agent=current_agent, - input_guardrail_results=streamed_result.input_guardrail_results, - output_guardrail_results=streamed_result.output_guardrail_results, - context_wrapper=context_wrapper, - ) await AgentRunner._save_result_to_session( - session, starting_input, temp_result + session, [], turn_result.new_step_items ) streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) elif isinstance(turn_result.next_step, NextStepRunAgain): - pass + await AgentRunner._save_result_to_session( + session, [], turn_result.new_step_items + ) except AgentsException as exc: streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) @@ -1448,7 +1442,7 @@ async def _save_result_to_session( cls, session: Session | None, original_input: str | list[TResponseInputItem], - result: RunResult, + new_items: list[RunItem], ) -> None: """Save the conversation turn to session.""" if session is None: @@ -1458,7 +1452,7 @@ async def _save_result_to_session( input_list = ItemHelpers.input_to_new_input_list(original_input) # Convert new items to input format - new_items_as_input = [item.to_input_item() for item in result.new_items] + new_items_as_input = [item.to_input_item() for item in new_items] # Save all items from this turn items_to_save = input_list + new_items_as_input diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index c8ae5b5f2..887defa5b 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -1,7 +1,10 @@ from __future__ import annotations import json +import tempfile +from pathlib import Path from typing import Any +from unittest.mock import patch import pytest from typing_extensions import TypedDict @@ -20,6 +23,7 @@ RunConfig, RunContextWrapper, Runner, + SQLiteSession, UserError, handoff, ) @@ -780,3 +784,96 @@ async def add_tool() -> str: assert executed["called"] is True assert result.final_output == "done" + + +@pytest.mark.asyncio +async def test_session_add_items_called_multiple_times_for_multi_turn_completion(): + """Test that SQLiteSession.add_items is called multiple times + during a multi-turn agent completion. + + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_agent_runner_session_multi_turn_calls.db" + session_id = "runner_session_multi_turn_calls" + session = SQLiteSession(session_id, db_path) + + # Define a tool that will be called by the orchestrator agent + @function_tool + async def echo_tool(text: str) -> str: + return f"Echo: {text}" + + # Orchestrator agent that calls the tool multiple times in one completion + orchestrator_agent = Agent( + name="orchestrator_agent", + instructions=( + "Call echo_tool twice with inputs of 'foo' and 'bar', then return a summary." + ), + tools=[echo_tool], + ) + + # Patch the model to simulate two tool calls and a final message + model = FakeModel() + orchestrator_agent.model = model + model.add_multiple_turn_outputs( + [ + # First turn: tool call + [get_function_tool_call("echo_tool", json.dumps({"text": "foo"}), call_id="1")], + # Second turn: tool call + [get_function_tool_call("echo_tool", json.dumps({"text": "bar"}), call_id="2")], + # Third turn: final output + [get_final_output_message("Summary: Echoed foo and bar")], + ] + ) + + # Patch add_items to count calls + with patch.object(SQLiteSession, "add_items", wraps=session.add_items) as mock_add_items: + result = await Runner.run(orchestrator_agent, input="foo and bar", session=session) + + expected_items = [ + {"content": "foo and bar", "role": "user"}, + { + "arguments": '{"text": "foo"}', + "call_id": "1", + "name": "echo_tool", + "type": "function_call", + "id": "1", + }, + {"call_id": "1", "output": "Echo: foo", "type": "function_call_output"}, + { + "arguments": '{"text": "bar"}', + "call_id": "2", + "name": "echo_tool", + "type": "function_call", + "id": "1", + }, + {"call_id": "2", "output": "Echo: bar", "type": "function_call_output"}, + { + "id": "1", + "content": [ + { + "annotations": [], + "text": "Summary: Echoed foo and bar", + "type": "output_text", + } + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + ] + + expected_calls = [ + # First call is the initial input + (([expected_items[0]],),), + # Second call is the first tool call and its result + (([expected_items[1], expected_items[2]],),), + # Third call is the second tool call and its result + (([expected_items[3], expected_items[4]],),), + # Fourth call is the final output + (([expected_items[5]],),), + ] + assert mock_add_items.call_args_list == expected_calls + assert result.final_output == "Summary: Echoed foo and bar" + assert (await session.get_items()) == expected_items + + session.close()