Skip to content

Commit da6a7e2

Browse files
committed
fix: #1840 roll back session changes when Guardrail tripwire is triggered
1 parent f3cac17 commit da6a7e2

File tree

4 files changed

+153
-10
lines changed

4 files changed

+153
-10
lines changed

src/agents/run.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,13 @@ async def run(
663663
tool_output_guardrail_results=tool_output_guardrail_results,
664664
context_wrapper=context_wrapper,
665665
)
666-
await self._save_result_to_session(session, [], turn_result.new_step_items)
666+
if not any(
667+
guardrail_result.output.tripwire_triggered
668+
for guardrail_result in input_guardrail_results
669+
):
670+
await self._save_result_to_session(
671+
session, [], turn_result.new_step_items
672+
)
667673

668674
return result
669675
elif isinstance(turn_result.next_step, NextStepHandoff):
@@ -672,7 +678,13 @@ async def run(
672678
current_span = None
673679
should_run_agent_start_hooks = True
674680
elif isinstance(turn_result.next_step, NextStepRunAgain):
675-
await self._save_result_to_session(session, [], turn_result.new_step_items)
681+
if not any(
682+
guardrail_result.output.tripwire_triggered
683+
for guardrail_result in input_guardrail_results
684+
):
685+
await self._save_result_to_session(
686+
session, [], turn_result.new_step_items
687+
)
676688
else:
677689
raise AgentsException(
678690
f"Unknown next step type: {type(turn_result.next_step)}"
@@ -1041,15 +1053,29 @@ async def _start_streaming(
10411053
streamed_result.is_complete = True
10421054

10431055
# Save the conversation to session if enabled
1044-
await AgentRunner._save_result_to_session(
1045-
session, [], turn_result.new_step_items
1046-
)
1056+
if session is not None:
1057+
should_skip_session_save = (
1058+
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
1059+
streamed_result
1060+
)
1061+
)
1062+
if should_skip_session_save is False:
1063+
await AgentRunner._save_result_to_session(
1064+
session, [], turn_result.new_step_items
1065+
)
10471066

10481067
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
10491068
elif isinstance(turn_result.next_step, NextStepRunAgain):
1050-
await AgentRunner._save_result_to_session(
1051-
session, [], turn_result.new_step_items
1052-
)
1069+
if session is not None:
1070+
should_skip_session_save = (
1071+
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
1072+
streamed_result
1073+
)
1074+
)
1075+
if should_skip_session_save is False:
1076+
await AgentRunner._save_result_to_session(
1077+
session, [], turn_result.new_step_items
1078+
)
10531079
except AgentsException as exc:
10541080
streamed_result.is_complete = True
10551081
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1719,6 +1745,24 @@ async def _save_result_to_session(
17191745
items_to_save = input_list + new_items_as_input
17201746
await session.add_items(items_to_save)
17211747

1748+
@staticmethod
1749+
async def _input_guardrail_tripwire_triggered_for_stream(
1750+
streamed_result: RunResultStreaming,
1751+
) -> bool:
1752+
"""Return True if any input guardrail triggered during a streamed run."""
1753+
1754+
task = streamed_result._input_guardrails_task
1755+
if task is None:
1756+
return False
1757+
1758+
if not task.done():
1759+
await task
1760+
1761+
return any(
1762+
guardrail_result.output.tripwire_triggered
1763+
for guardrail_result in streamed_result.input_guardrail_results
1764+
)
1765+
17221766

17231767
DEFAULT_AGENT_RUNNER = AgentRunner()
17241768
_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)

tests/test_agent_runner.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
import tempfile
56
from pathlib import Path
6-
from typing import Any
7+
from typing import Any, cast
78
from unittest.mock import patch
89

910
import pytest
@@ -39,6 +40,7 @@
3940
get_text_input_item,
4041
get_text_message,
4142
)
43+
from .utils.simple_session import SimpleListSession
4244

4345

4446
@pytest.mark.asyncio
@@ -542,6 +544,40 @@ def guardrail_function(
542544
await Runner.run(agent, input="user_message")
543545

544546

547+
@pytest.mark.asyncio
548+
async def test_input_guardrail_tripwire_does_not_save_assistant_message_to_session():
549+
async def guardrail_function(
550+
context: RunContextWrapper[Any], agent: Agent[Any], input: Any
551+
) -> GuardrailFunctionOutput:
552+
# Delay to ensure the agent has time to produce output before the guardrail finishes.
553+
await asyncio.sleep(0.01)
554+
return GuardrailFunctionOutput(
555+
output_info=None,
556+
tripwire_triggered=True,
557+
)
558+
559+
session = SimpleListSession()
560+
561+
model = FakeModel()
562+
model.set_next_output([get_text_message("should_not_be_saved")])
563+
564+
agent = Agent(
565+
name="test",
566+
model=model,
567+
input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
568+
)
569+
570+
with pytest.raises(InputGuardrailTripwireTriggered):
571+
await Runner.run(agent, input="user_message", session=session)
572+
573+
items = await session.get_items()
574+
575+
assert len(items) == 1
576+
first_item = cast(dict[str, Any], items[0])
577+
assert "role" in first_item
578+
assert first_item["role"] == "user"
579+
580+
545581
@pytest.mark.asyncio
546582
async def test_output_guardrail_tripwire_triggered_causes_exception():
547583
def guardrail_function(

tests/test_agent_runner_streamed.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
import json
5-
from typing import Any
5+
from typing import Any, cast
66

77
import pytest
88
from typing_extensions import TypedDict
@@ -35,6 +35,7 @@
3535
get_text_input_item,
3636
get_text_message,
3737
)
38+
from .utils.simple_session import SimpleListSession
3839

3940

4041
@pytest.mark.asyncio
@@ -524,6 +525,38 @@ def guardrail_function(
524525
pass
525526

526527

528+
@pytest.mark.asyncio
529+
async def test_input_guardrail_streamed_does_not_save_assistant_message_to_session():
530+
async def guardrail_function(
531+
context: RunContextWrapper[Any], agent: Agent[Any], input: Any
532+
) -> GuardrailFunctionOutput:
533+
await asyncio.sleep(0.01)
534+
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True)
535+
536+
session = SimpleListSession()
537+
538+
model = FakeModel()
539+
model.set_next_output([get_text_message("should_not_be_saved")])
540+
541+
agent = Agent(
542+
name="test",
543+
model=model,
544+
input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
545+
)
546+
547+
with pytest.raises(InputGuardrailTripwireTriggered):
548+
result = Runner.run_streamed(agent, input="user_message", session=session)
549+
async for _ in result.stream_events():
550+
pass
551+
552+
items = await session.get_items()
553+
554+
assert len(items) == 1
555+
first_item = cast(dict[str, Any], items[0])
556+
assert "role" in first_item
557+
assert first_item["role"] == "user"
558+
559+
527560
@pytest.mark.asyncio
528561
async def test_slow_input_guardrail_still_raises_exception_streamed():
529562
async def guardrail_function(

tests/utils/simple_session.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from __future__ import annotations
2+
3+
from agents.items import TResponseInputItem
4+
from agents.memory.session import Session
5+
6+
7+
class SimpleListSession(Session):
8+
"""A minimal in-memory session implementation for tests."""
9+
10+
def __init__(self, session_id: str = "test") -> None:
11+
self.session_id = session_id
12+
self._items: list[TResponseInputItem] = []
13+
14+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
15+
if limit is None:
16+
return list(self._items)
17+
if limit <= 0:
18+
return []
19+
return self._items[-limit:]
20+
21+
async def add_items(self, items: list[TResponseInputItem]) -> None:
22+
self._items.extend(items)
23+
24+
async def pop_item(self) -> TResponseInputItem | None:
25+
if not self._items:
26+
return None
27+
return self._items.pop()
28+
29+
async def clear_session(self) -> None:
30+
self._items.clear()

0 commit comments

Comments
 (0)