diff --git a/src/agents/run.py b/src/agents/run.py
index 52d395a13..bbf9dd6ac 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -663,7 +663,13 @@ async def run(
                             tool_output_guardrail_results=tool_output_guardrail_results,
                             context_wrapper=context_wrapper,
                         )
-                        await self._save_result_to_session(session, [], turn_result.new_step_items)
+                        if not any(
+                            guardrail_result.output.tripwire_triggered
+                            for guardrail_result in input_guardrail_results
+                        ):
+                            await self._save_result_to_session(
+                                session, [], turn_result.new_step_items
+                            )
 
                         return result
                     elif isinstance(turn_result.next_step, NextStepHandoff):
@@ -672,7 +678,13 @@ async def run(
                         current_span = None
                         should_run_agent_start_hooks = True
                     elif isinstance(turn_result.next_step, NextStepRunAgain):
-                        await self._save_result_to_session(session, [], turn_result.new_step_items)
+                        if not any(
+                            guardrail_result.output.tripwire_triggered
+                            for guardrail_result in input_guardrail_results
+                        ):
+                            await self._save_result_to_session(
+                                session, [], turn_result.new_step_items
+                            )
                     else:
                         raise AgentsException(
                             f"Unknown next step type: {type(turn_result.next_step)}"
@@ -1041,15 +1053,29 @@ async def _start_streaming(
                         streamed_result.is_complete = True
 
                         # Save the conversation to session if enabled
-                        await AgentRunner._save_result_to_session(
-                            session, [], turn_result.new_step_items
-                        )
+                        if session is not None:
+                            should_skip_session_save = (
+                                await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
+                                    streamed_result
+                                )
+                            )
+                            if not should_skip_session_save:
+                                await AgentRunner._save_result_to_session(
+                                    session, [], turn_result.new_step_items
+                                )
 
                         streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
                     elif isinstance(turn_result.next_step, NextStepRunAgain):
-                        await AgentRunner._save_result_to_session(
-                            session, [], turn_result.new_step_items
-                        )
+                        if session is not None:
+                            should_skip_session_save = (
+                                await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
+                                    streamed_result
+                                )
+                            )
+                            if not should_skip_session_save:
+                                await AgentRunner._save_result_to_session(
+                                    session, [], turn_result.new_step_items
+                                )
                 except AgentsException as exc:
                     streamed_result.is_complete = True
                     streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1719,6 +1745,29 @@ async def _save_result_to_session(
         items_to_save = input_list + new_items_as_input
         await session.add_items(items_to_save)
 
+    @staticmethod
+    async def _input_guardrail_tripwire_triggered_for_stream(
+        streamed_result: RunResultStreaming,
+    ) -> bool:
+        """Return True if any input guardrail triggered during a streamed run.
+
+        If the input guardrails task is still running, it is awaited first so
+        the recorded guardrail results are inspected only once it has finished.
+        """
+
+        task = streamed_result._input_guardrails_task
+        if task is None:
+            # No guardrails task was recorded, so nothing can have tripped.
+            return False
+
+        if not task.done():
+            await task
+
+        return any(
+            guardrail_result.output.tripwire_triggered
+            for guardrail_result in streamed_result.input_guardrail_results
+        )
+
 
 DEFAULT_AGENT_RUNNER = AgentRunner()
 _TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)
diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py
index dae68fc4c..441054dd4 100644
--- a/tests/test_agent_runner.py
+++ b/tests/test_agent_runner.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
+import asyncio
 import json
 import tempfile
 from pathlib import Path
-from typing import Any
+from typing import Any, cast
 from unittest.mock import patch
 
 import pytest
@@ -39,6 +40,7 @@
     get_text_input_item,
     get_text_message,
 )
+from .utils.simple_session import SimpleListSession
 
 
 @pytest.mark.asyncio
@@ -542,6 +544,41 @@ def guardrail_function(
         await Runner.run(agent, input="user_message")
 
 
+@pytest.mark.asyncio
+async def test_input_guardrail_tripwire_does_not_save_assistant_message_to_session():
+    """A tripped input guardrail must leave only the user message in the session."""
+    async def guardrail_function(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        # Delay to ensure the agent has time to produce output before the guardrail finishes.
+        await asyncio.sleep(0.01)
+        return GuardrailFunctionOutput(
+            output_info=None,
+            tripwire_triggered=True,
+        )
+
+    session = SimpleListSession()
+
+    model = FakeModel()
+    model.set_next_output([get_text_message("should_not_be_saved")])
+
+    agent = Agent(
+        name="test",
+        model=model,
+        input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
+    )
+
+    with pytest.raises(InputGuardrailTripwireTriggered):
+        await Runner.run(agent, input="user_message", session=session)
+
+    items = await session.get_items()
+
+    assert len(items) == 1
+    first_item = cast(dict[str, Any], items[0])
+    assert "role" in first_item
+    assert first_item["role"] == "user"
+
+
 @pytest.mark.asyncio
 async def test_output_guardrail_tripwire_triggered_causes_exception():
     def guardrail_function(
diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py
index 90071a3d7..00c98eed0 100644
--- a/tests/test_agent_runner_streamed.py
+++ b/tests/test_agent_runner_streamed.py
@@ -2,7 +2,7 @@
 
 import asyncio
 import json
-from typing import Any
+from typing import Any, cast
 
 import pytest
 from typing_extensions import TypedDict
@@ -35,6 +35,7 @@
     get_text_input_item,
     get_text_message,
 )
+from .utils.simple_session import SimpleListSession
 
 
 @pytest.mark.asyncio
@@ -524,6 +525,39 @@ def guardrail_function(
             pass
 
 
+@pytest.mark.asyncio
+async def test_input_guardrail_streamed_does_not_save_assistant_message_to_session():
+    """A tripped input guardrail must leave only the user message in the session."""
+    async def guardrail_function(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        await asyncio.sleep(0.01)
+        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True)
+
+    session = SimpleListSession()
+
+    model = FakeModel()
+    model.set_next_output([get_text_message("should_not_be_saved")])
+
+    agent = Agent(
+        name="test",
+        model=model,
+        input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
+    )
+
+    with pytest.raises(InputGuardrailTripwireTriggered):
+        result = Runner.run_streamed(agent, input="user_message", session=session)
+        async for _ in result.stream_events():
+            pass
+
+    items = await session.get_items()
+
+    assert len(items) == 1
+    first_item = cast(dict[str, Any], items[0])
+    assert "role" in first_item
+    assert first_item["role"] == "user"
+
+
 @pytest.mark.asyncio
 async def test_slow_input_guardrail_still_raises_exception_streamed():
     async def guardrail_function(
diff --git a/tests/utils/simple_session.py b/tests/utils/simple_session.py
new file mode 100644
index 000000000..b18d6fb92
--- /dev/null
+++ b/tests/utils/simple_session.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+from agents.items import TResponseInputItem
+from agents.memory.session import Session
+
+
+class SimpleListSession(Session):
+    """A minimal in-memory session implementation for tests."""
+
+    def __init__(self, session_id: str = "test") -> None:
+        self.session_id = session_id
+        self._items: list[TResponseInputItem] = []
+
+    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+        """Return a copy of the stored items, or only the last ``limit`` of them."""
+        if limit is None:
+            return list(self._items)
+        if limit <= 0:
+            return []
+        return self._items[-limit:]
+
+    async def add_items(self, items: list[TResponseInputItem]) -> None:
+        """Append ``items`` to the end of the stored list."""
+        self._items.extend(items)
+
+    async def pop_item(self) -> TResponseInputItem | None:
+        """Remove and return the most recently added item, or None if empty."""
+        if not self._items:
+            return None
+        return self._items.pop()
+
+    async def clear_session(self) -> None:
+        """Remove every stored item."""
+        self._items.clear()