Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,13 @@ async def run(
tool_output_guardrail_results=tool_output_guardrail_results,
context_wrapper=context_wrapper,
)
await self._save_result_to_session(session, [], turn_result.new_step_items)
if not any(
guardrail_result.output.tripwire_triggered
for guardrail_result in input_guardrail_results
):
await self._save_result_to_session(
session, [], turn_result.new_step_items
)

return result
elif isinstance(turn_result.next_step, NextStepHandoff):
Expand All @@ -672,7 +678,13 @@ async def run(
current_span = None
should_run_agent_start_hooks = True
elif isinstance(turn_result.next_step, NextStepRunAgain):
await self._save_result_to_session(session, [], turn_result.new_step_items)
if not any(
guardrail_result.output.tripwire_triggered
for guardrail_result in input_guardrail_results
):
await self._save_result_to_session(
session, [], turn_result.new_step_items
)
else:
raise AgentsException(
f"Unknown next step type: {type(turn_result.next_step)}"
Expand Down Expand Up @@ -1041,15 +1053,29 @@ async def _start_streaming(
streamed_result.is_complete = True

# Save the conversation to session if enabled
await AgentRunner._save_result_to_session(
session, [], turn_result.new_step_items
)
if session is not None:
should_skip_session_save = (
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
streamed_result
)
)
if should_skip_session_save is False:
await AgentRunner._save_result_to_session(
session, [], turn_result.new_step_items
)

streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
elif isinstance(turn_result.next_step, NextStepRunAgain):
await AgentRunner._save_result_to_session(
session, [], turn_result.new_step_items
)
if session is not None:
should_skip_session_save = (
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
streamed_result
)
)
if should_skip_session_save is False:
await AgentRunner._save_result_to_session(
session, [], turn_result.new_step_items
)
except AgentsException as exc:
streamed_result.is_complete = True
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
Expand Down Expand Up @@ -1719,6 +1745,24 @@ async def _save_result_to_session(
items_to_save = input_list + new_items_as_input
await session.add_items(items_to_save)

@staticmethod
async def _input_guardrail_tripwire_triggered_for_stream(
    streamed_result: RunResultStreaming,
) -> bool:
    """Report whether an input guardrail tripped during a streamed run.

    Returns ``False`` straight away when ``streamed_result`` carries no input
    guardrail task. Otherwise the task is awaited if it has not finished yet,
    and the guardrail results recorded on ``streamed_result`` are inspected.
    """
    guardrails_task = streamed_result._input_guardrails_task
    if guardrails_task is None:
        return False

    # Let a still-running guardrail task finish before reading its results.
    if not guardrails_task.done():
        await guardrails_task

    for recorded in streamed_result.input_guardrail_results:
        if recorded.output.tripwire_triggered:
            return True
    return False


# Module-level AgentRunner instance created at import time.
DEFAULT_AGENT_RUNNER = AgentRunner()
# Tuple of the member types obtained from ToolCallItemTypes via get_args
# (presumably typing.get_args — confirm against the file's imports).
_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)
Expand Down
38 changes: 37 additions & 1 deletion tests/test_agent_runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import asyncio
import json
import tempfile
from pathlib import Path
from typing import Any
from typing import Any, cast
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -39,6 +40,7 @@
get_text_input_item,
get_text_message,
)
from .utils.simple_session import SimpleListSession


@pytest.mark.asyncio
Expand Down Expand Up @@ -542,6 +544,40 @@ def guardrail_function(
await Runner.run(agent, input="user_message")


@pytest.mark.asyncio
async def test_input_guardrail_tripwire_does_not_save_assistant_message_to_session():
    """A tripped input guardrail leaves only the user message in the session."""

    async def guardrail_function(
        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
    ) -> GuardrailFunctionOutput:
        # Hold the guardrail back briefly so the agent can produce output
        # before the tripwire is reported.
        await asyncio.sleep(0.01)
        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True)

    fake_model = FakeModel()
    fake_model.set_next_output([get_text_message("should_not_be_saved")])

    tripping_guardrail = InputGuardrail(guardrail_function=guardrail_function)
    guarded_agent = Agent(
        name="test",
        model=fake_model,
        input_guardrails=[tripping_guardrail],
    )

    history = SimpleListSession()

    with pytest.raises(InputGuardrailTripwireTriggered):
        await Runner.run(guarded_agent, input="user_message", session=history)

    saved_items = await history.get_items()
    assert len(saved_items) == 1

    # The single stored entry must be the user's input, not assistant output.
    only_item = cast(dict[str, Any], saved_items[0])
    assert "role" in only_item
    assert only_item["role"] == "user"


@pytest.mark.asyncio
async def test_output_guardrail_tripwire_triggered_causes_exception():
def guardrail_function(
Expand Down
35 changes: 34 additions & 1 deletion tests/test_agent_runner_streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import json
from typing import Any
from typing import Any, cast

import pytest
from typing_extensions import TypedDict
Expand Down Expand Up @@ -35,6 +35,7 @@
get_text_input_item,
get_text_message,
)
from .utils.simple_session import SimpleListSession


@pytest.mark.asyncio
Expand Down Expand Up @@ -524,6 +525,38 @@ def guardrail_function(
pass


@pytest.mark.asyncio
async def test_input_guardrail_streamed_does_not_save_assistant_message_to_session():
    """In a streamed run, a tripped input guardrail leaves only the user message stored."""

    async def guardrail_function(
        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
    ) -> GuardrailFunctionOutput:
        # Yield control briefly before reporting the tripwire.
        await asyncio.sleep(0.01)
        return GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )

    fake_model = FakeModel()
    fake_model.set_next_output([get_text_message("should_not_be_saved")])

    tripping_guardrail = InputGuardrail(guardrail_function=guardrail_function)
    guarded_agent = Agent(
        name="test",
        model=fake_model,
        input_guardrails=[tripping_guardrail],
    )

    history = SimpleListSession()

    # Both starting the stream and draining its events sit inside the
    # expectation, so the exception may surface at either point.
    with pytest.raises(InputGuardrailTripwireTriggered):
        streamed = Runner.run_streamed(
            guarded_agent, input="user_message", session=history
        )
        async for _event in streamed.stream_events():
            pass

    saved_items = await history.get_items()
    assert len(saved_items) == 1

    # The single stored entry must be the user's input, not assistant output.
    only_item = cast(dict[str, Any], saved_items[0])
    assert "role" in only_item
    assert only_item["role"] == "user"


@pytest.mark.asyncio
async def test_slow_input_guardrail_still_raises_exception_streamed():
async def guardrail_function(
Expand Down
30 changes: 30 additions & 0 deletions tests/utils/simple_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

from agents.items import TResponseInputItem
from agents.memory.session import Session


class SimpleListSession(Session):
    """List-backed, in-memory session intended for use in tests."""

    def __init__(self, session_id: str = "test") -> None:
        # Items are kept in insertion order in a plain Python list.
        self._items: list[TResponseInputItem] = []
        self.session_id = session_id

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Return stored items, optionally restricted to the last ``limit``.

        With ``limit=None`` a copy of every stored item is returned; a
        non-positive ``limit`` yields an empty list.
        """
        if limit is not None:
            # Slicing from the end gives the most recently added items.
            return self._items[-limit:] if limit > 0 else []
        return self._items.copy()

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append ``items`` to the end of the stored history."""
        self._items += items

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the newest item, or ``None`` when nothing is stored."""
        try:
            return self._items.pop()
        except IndexError:
            return None

    async def clear_session(self) -> None:
        """Discard every stored item."""
        del self._items[:]