Skip to content

Commit 1989347

Browse files
committed
add unit tests
1 parent ae84338 commit 1989347

File tree

1 file changed

+236
-0
lines changed

1 file changed

+236
-0
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import asyncio
2+
from datetime import datetime
3+
from typing import Any
4+
5+
import pytest
6+
7+
from agents import Agent, GuardrailFunctionOutput, InputGuardrail, Runner, RunContextWrapper
8+
from agents.items import TResponseInputItem
9+
from agents.exceptions import InputGuardrailTripwireTriggered
10+
11+
from .fake_model import FakeModel
12+
from openai.types.responses import ResponseCompletedEvent
13+
from .test_responses import get_text_message
14+
15+
16+
def make_input_guardrail(delay_seconds: float, *, trip: bool) -> InputGuardrail[Any]:
17+
async def guardrail(
18+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
19+
) -> GuardrailFunctionOutput:
20+
# Simulate variable guardrail completion timing.
21+
if delay_seconds > 0:
22+
await asyncio.sleep(delay_seconds)
23+
return GuardrailFunctionOutput(
24+
output_info={"delay": delay_seconds}, tripwire_triggered=trip
25+
)
26+
27+
name = "tripping_input_guardrail" if trip else "delayed_input_guardrail"
28+
return InputGuardrail(guardrail_function=guardrail, name=name)
29+
30+
31+
@pytest.mark.asyncio
32+
@pytest.mark.parametrize("guardrail_delay", [0.0, 0.2])
33+
async def test_run_streamed_input_guardrail_timing_is_consistent(guardrail_delay: float):
34+
"""Ensure streaming behavior matches whether input guardrail finishes before or after LLM stream.
35+
36+
We verify that:
37+
- The sequence of streamed event types is identical.
38+
- Final output matches.
39+
- Exactly one input guardrail result is recorded and does not trigger.
40+
"""
41+
42+
# Arrange: Agent with a single text output and a delayed input guardrail
43+
model = FakeModel()
44+
model.set_next_output([get_text_message("Final response")])
45+
46+
agent = Agent(
47+
name="TimingAgent",
48+
model=model,
49+
input_guardrails=[make_input_guardrail(guardrail_delay, trip=False)],
50+
)
51+
52+
# Act: Run streamed and collect event types
53+
result = Runner.run_streamed(agent, input="Hello")
54+
event_types: list[str] = []
55+
56+
async for event in result.stream_events():
57+
event_types.append(event.type)
58+
59+
# Assert: Guardrail results populated and identical behavioral outcome
60+
assert len(result.input_guardrail_results) == 1, "Expected exactly one input guardrail result"
61+
assert result.input_guardrail_results[0].guardrail.get_name() == "delayed_input_guardrail", (
62+
"Guardrail name mismatch"
63+
)
64+
assert result.input_guardrail_results[0].output.tripwire_triggered is False, (
65+
"Guardrail should not trigger in this test"
66+
)
67+
68+
# Final output should be the text from the model's single message
69+
assert result.final_output == "Final response"
70+
71+
# Minimal invariants on event sequence to ensure stability across timing
72+
# Must start with agent update and include raw response events
73+
assert len(event_types) >= 3, f"Unexpectedly few events: {event_types}"
74+
assert event_types[0] == "agent_updated_stream_event"
75+
# Ensure we observed raw response events in the stream irrespective of guardrail timing
76+
assert any(t == "raw_response_event" for t in event_types)
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_run_streamed_input_guardrail_sequences_match_between_fast_and_slow():
81+
"""Run twice with fast vs slow input guardrail and compare event sequences exactly."""
82+
83+
async def run_once(delay: float) -> list[str]:
84+
model = FakeModel()
85+
model.set_next_output([get_text_message("Final response")])
86+
agent = Agent(
87+
name="TimingAgent",
88+
model=model,
89+
input_guardrails=[make_input_guardrail(delay, trip=False)],
90+
)
91+
result = Runner.run_streamed(agent, input="Hello")
92+
events: list[str] = []
93+
async for ev in result.stream_events():
94+
events.append(ev.type)
95+
return events
96+
97+
events_fast = await run_once(0.0)
98+
events_slow = await run_once(0.2)
99+
100+
assert events_fast == events_slow, (
101+
f"Event sequences differ between guardrail timings:\nfast={events_fast}\nslow={events_slow}"
102+
)
103+
104+
105+
@pytest.mark.asyncio
106+
@pytest.mark.parametrize("guardrail_delay", [0.0, 0.2])
107+
async def test_run_streamed_input_guardrail_tripwire_raises(guardrail_delay: float):
108+
"""Guardrail tripwire must raise from stream_events regardless of timing."""
109+
110+
model = FakeModel()
111+
model.set_next_output([get_text_message("Final response")])
112+
113+
agent = Agent(
114+
name="TimingAgentTrip",
115+
model=model,
116+
input_guardrails=[make_input_guardrail(guardrail_delay, trip=True)],
117+
)
118+
119+
result = Runner.run_streamed(agent, input="Hello")
120+
121+
with pytest.raises(InputGuardrailTripwireTriggered) as excinfo:
122+
async for _ in result.stream_events():
123+
pass
124+
125+
# Exception contains the guardrail result and run data
126+
exc = excinfo.value
127+
assert exc.guardrail_result.output.tripwire_triggered is True
128+
assert exc.run_data is not None
129+
assert len(exc.run_data.input_guardrail_results) == 1
130+
assert (
131+
exc.run_data.input_guardrail_results[0].guardrail.get_name() == "tripping_input_guardrail"
132+
)
133+
134+
135+
class SlowCompleteFakeModel(FakeModel):
136+
"""A FakeModel that delays just before emitting ResponseCompletedEvent in streaming."""
137+
138+
def __init__(self, delay_seconds: float, tracing_enabled: bool = True):
139+
super().__init__(tracing_enabled=tracing_enabled)
140+
self._delay_seconds = delay_seconds
141+
142+
async def stream_response(self, *args, **kwargs): # type: ignore[override]
143+
async for ev in super().stream_response(*args, **kwargs):
144+
if isinstance(ev, ResponseCompletedEvent) and self._delay_seconds > 0:
145+
await asyncio.sleep(self._delay_seconds)
146+
yield ev
147+
148+
149+
def _get_span_by_type(spans, span_type: str):
150+
for s in spans:
151+
exported = s.export()
152+
if not exported:
153+
continue
154+
if exported.get("span_data", {}).get("type") == span_type:
155+
return s
156+
return None
157+
158+
159+
def _iso(s: str | None) -> datetime:
160+
assert s is not None
161+
return datetime.fromisoformat(s)
162+
163+
164+
@pytest.mark.asyncio
165+
async def test_parent_span_and_trace_finish_after_slow_input_guardrail():
166+
"""Agent span and trace finish after guardrail when guardrail completes last."""
167+
168+
model = FakeModel(tracing_enabled=True)
169+
model.set_next_output([get_text_message("Final response")])
170+
agent = Agent(
171+
name="TimingAgentTrace",
172+
model=model,
173+
input_guardrails=[make_input_guardrail(0.2, trip=False)], # guardrail slower than model
174+
)
175+
176+
result = Runner.run_streamed(agent, input="Hello")
177+
async for _ in result.stream_events():
178+
pass
179+
180+
from .testing_processor import fetch_ordered_spans
181+
182+
spans = fetch_ordered_spans()
183+
agent_span = _get_span_by_type(spans, "agent")
184+
guardrail_span = _get_span_by_type(spans, "guardrail")
185+
generation_span = _get_span_by_type(spans, "generation")
186+
187+
assert agent_span and guardrail_span and generation_span, (
188+
"Expected agent, guardrail, generation spans"
189+
)
190+
191+
# Agent span must finish last
192+
assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at)
193+
assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at)
194+
195+
# Trace should end after all spans end
196+
from .testing_processor import fetch_events
197+
198+
events = fetch_events()
199+
assert events[-1] == "trace_end"
200+
201+
202+
@pytest.mark.asyncio
203+
async def test_parent_span_and_trace_finish_after_slow_model():
204+
"""Agent span and trace finish after model when model completes last."""
205+
206+
model = SlowCompleteFakeModel(delay_seconds=0.2, tracing_enabled=True)
207+
model.set_next_output([get_text_message("Final response")])
208+
agent = Agent(
209+
name="TimingAgentTrace",
210+
model=model,
211+
input_guardrails=[make_input_guardrail(0.0, trip=False)], # guardrail faster than model
212+
)
213+
214+
result = Runner.run_streamed(agent, input="Hello")
215+
async for _ in result.stream_events():
216+
pass
217+
218+
from .testing_processor import fetch_ordered_spans
219+
220+
spans = fetch_ordered_spans()
221+
agent_span = _get_span_by_type(spans, "agent")
222+
guardrail_span = _get_span_by_type(spans, "guardrail")
223+
generation_span = _get_span_by_type(spans, "generation")
224+
225+
assert agent_span and guardrail_span and generation_span, (
226+
"Expected agent, guardrail, generation spans"
227+
)
228+
229+
# Agent span must finish last
230+
assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at)
231+
assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at)
232+
233+
from .testing_processor import fetch_events
234+
235+
events = fetch_events()
236+
assert events[-1] == "trace_end"

0 commit comments

Comments
 (0)