Skip to content

Commit b32bb91

Browse files
committed
bug: preserve usage tracking on streaming errors
1 parent 4bc33e3 commit b32bb91

File tree

7 files changed

+101
-16
lines changed

7 files changed

+101
-16
lines changed

src/agents/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
MaxTurnsExceeded,
2121
ModelBehaviorError,
2222
OutputGuardrailTripwireTriggered,
23+
RunError,
2324
RunErrorDetails,
2425
ToolInputGuardrailTripwireTriggered,
2526
ToolOutputGuardrailTripwireTriggered,
@@ -212,6 +213,8 @@ def enable_verbose_stdout_logging():
212213
"OutputGuardrailTripwireTriggered",
213214
"ToolInputGuardrailTripwireTriggered",
214215
"ToolOutputGuardrailTripwireTriggered",
216+
"RunError",
217+
"RunErrorDetails",
215218
"DynamicPromptFunction",
216219
"GenerateDynamicPromptData",
217220
"Prompt",

src/agents/exceptions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,20 @@ def __init__(self, guardrail: ToolOutputGuardrail[Any], output: ToolGuardrailFun
129129
self.guardrail = guardrail
130130
self.output = output
131131
super().__init__(f"Tool output guardrail {guardrail.__class__.__name__} triggered tripwire")
132+
133+
134+
class RunError(AgentsException):
135+
"""Wrapper exception for non-AgentsException errors that occur during agent runs.
136+
137+
This exception wraps external errors (API errors, connection failures, etc.) to ensure
138+
that run data including usage information is preserved and accessible.
139+
"""
140+
141+
original_exception: Exception
142+
"""The original exception that was raised."""
143+
144+
def __init__(self, original_exception: Exception):
145+
self.original_exception = original_exception
146+
super().__init__(str(original_exception))
147+
# Preserve the original exception as the cause
148+
self.__cause__ = original_exception

src/agents/result.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
AgentsException,
1616
InputGuardrailTripwireTriggered,
1717
MaxTurnsExceeded,
18+
RunError,
1819
RunErrorDetails,
1920
)
2021
from .guardrail import InputGuardrailResult, OutputGuardrailResult
@@ -299,23 +300,40 @@ def _check_errors(self):
299300
if self._run_impl_task and self._run_impl_task.done():
300301
run_impl_exc = self._run_impl_task.exception()
301302
if run_impl_exc and isinstance(run_impl_exc, Exception):
302-
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
303-
run_impl_exc.run_data = self._create_error_details()
304-
self._stored_exception = run_impl_exc
303+
if isinstance(run_impl_exc, AgentsException):
304+
# For AgentsException, attach run_data if missing
305+
if run_impl_exc.run_data is None:
306+
run_impl_exc.run_data = self._create_error_details()
307+
self._stored_exception = run_impl_exc
308+
else:
309+
# For non-AgentsException, wrap it to preserve run_data
310+
wrapped_exc = RunError(run_impl_exc)
311+
wrapped_exc.run_data = self._create_error_details()
312+
self._stored_exception = wrapped_exc
305313

306314
if self._input_guardrails_task and self._input_guardrails_task.done():
307315
in_guard_exc = self._input_guardrails_task.exception()
308316
if in_guard_exc and isinstance(in_guard_exc, Exception):
309-
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
310-
in_guard_exc.run_data = self._create_error_details()
311-
self._stored_exception = in_guard_exc
317+
if isinstance(in_guard_exc, AgentsException):
318+
if in_guard_exc.run_data is None:
319+
in_guard_exc.run_data = self._create_error_details()
320+
self._stored_exception = in_guard_exc
321+
else:
322+
wrapped_exc = RunError(in_guard_exc)
323+
wrapped_exc.run_data = self._create_error_details()
324+
self._stored_exception = wrapped_exc
312325

313326
if self._output_guardrails_task and self._output_guardrails_task.done():
314327
out_guard_exc = self._output_guardrails_task.exception()
315328
if out_guard_exc and isinstance(out_guard_exc, Exception):
316-
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
317-
out_guard_exc.run_data = self._create_error_details()
318-
self._stored_exception = out_guard_exc
329+
if isinstance(out_guard_exc, AgentsException):
330+
if out_guard_exc.run_data is None:
331+
out_guard_exc.run_data = self._create_error_details()
332+
self._stored_exception = out_guard_exc
333+
else:
334+
wrapped_exc = RunError(out_guard_exc)
335+
wrapped_exc.run_data = self._create_error_details()
336+
self._stored_exception = wrapped_exc
319337

320338
def _cleanup_tasks(self):
321339
if self._run_impl_task and not self._run_impl_task.done():

src/agents/run.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
MaxTurnsExceeded,
3636
ModelBehaviorError,
3737
OutputGuardrailTripwireTriggered,
38+
RunError,
3839
RunErrorDetails,
3940
UserError,
4041
)
@@ -702,6 +703,19 @@ async def run(
702703
output_guardrail_results=[],
703704
)
704705
raise
706+
except Exception as exc:
707+
# Wrap non-AgentsException to preserve run_data including usage
708+
wrapped_exc = RunError(exc)
709+
wrapped_exc.run_data = RunErrorDetails(
710+
input=original_input,
711+
new_items=generated_items,
712+
raw_responses=model_responses,
713+
last_agent=current_agent,
714+
context_wrapper=context_wrapper,
715+
input_guardrail_results=input_guardrail_results,
716+
output_guardrail_results=[],
717+
)
718+
raise wrapped_exc from exc
705719
finally:
706720
if current_span:
707721
current_span.finish(reset_current=True)

tests/test_run_hooks.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ async def test_async_run_hooks_with_agent_hooks_with_llm():
172172

173173
@pytest.mark.asyncio
174174
async def test_run_hooks_llm_error_non_streaming(monkeypatch):
175+
from agents import RunError
176+
175177
hooks = RunHooksForTests()
176178
model = FakeModel()
177179
agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
@@ -181,9 +183,16 @@ async def boom(*args, **kwargs):
181183

182184
monkeypatch.setattr(FakeModel, "get_response", boom, raising=True)
183185

184-
with pytest.raises(RuntimeError, match="boom"):
186+
with pytest.raises(RunError) as exc_info:
185187
await Runner.run(agent, input="hello", hooks=hooks)
186188

189+
# Verify the original exception is preserved
190+
assert isinstance(exc_info.value.original_exception, RuntimeError)
191+
assert str(exc_info.value.original_exception) == "boom"
192+
# Verify run_data is attached
193+
assert exc_info.value.run_data is not None
194+
assert exc_info.value.run_data.context_wrapper is not None
195+
187196
# Current behavior: start hooks still fire on LLM failure; presumably only the end hooks are skipped
188197
assert hooks.events["on_agent_start"] == 1
189198
assert hooks.events["on_llm_start"] == 1
@@ -229,16 +238,26 @@ async def test_streamed_run_hooks_llm_error(monkeypatch):
229238
Verify that when the streaming path raises, we still emit on_llm_start
230239
but do NOT emit on_llm_end (current behavior), and the exception propagates as a RunError.
231240
"""
241+
from agents import RunError
242+
232243
hooks = RunHooksForTests()
233244
agent = Agent(name="A", model=BoomModel(), tools=[get_function_tool("f", "res")], handoffs=[])
234245

235246
stream = Runner.run_streamed(agent, input="hello", hooks=hooks)
236247

237-
# Consuming the stream should surface the exception
238-
with pytest.raises(RuntimeError, match="stream blew up"):
248+
# Consuming the stream should surface the exception (wrapped in RunError to preserve usage data)
249+
with pytest.raises(RunError) as exc_info:
239250
async for _ in stream.stream_events():
240251
pass
241252

253+
# Verify the original exception is preserved and accessible
254+
assert isinstance(exc_info.value.original_exception, RuntimeError)
255+
assert str(exc_info.value.original_exception) == "stream blew up"
256+
# Verify run_data is attached with usage information
257+
assert exc_info.value.run_data is not None
258+
assert exc_info.value.run_data.context_wrapper is not None
259+
assert exc_info.value.run_data.context_wrapper.usage is not None
260+
242261
# Current behavior: success-only on_llm_end; ensure starts fired but ends did not.
243262
assert hooks.events["on_agent_start"] == 1
244263
assert hooks.events["on_llm_start"] == 1

tests/test_tracing_errors.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
MaxTurnsExceeded,
1616
ModelBehaviorError,
1717
RunContextWrapper,
18+
RunError,
1819
Runner,
1920
TResponseInputItem,
2021
)
@@ -39,9 +40,12 @@ async def test_single_turn_model_error():
3940
name="test_agent",
4041
model=model,
4142
)
42-
with pytest.raises(ValueError):
43+
with pytest.raises(RunError) as exc_info:
4344
await Runner.run(agent, input="first_test")
4445

46+
# Verify the original exception is preserved
47+
assert isinstance(exc_info.value.original_exception, ValueError)
48+
4549
assert fetch_normalized_spans() == snapshot(
4650
[
4751
{
@@ -92,9 +96,12 @@ async def test_multi_turn_no_handoffs():
9296
]
9397
)
9498

95-
with pytest.raises(ValueError):
99+
with pytest.raises(RunError) as exc_info:
96100
await Runner.run(agent, input="first_test")
97101

102+
# Verify the original exception is preserved
103+
assert isinstance(exc_info.value.original_exception, ValueError)
104+
98105
assert fetch_normalized_spans() == snapshot(
99106
[
100107
{

tests/test_tracing_errors_streamed.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
OutputGuardrail,
1919
OutputGuardrailTripwireTriggered,
2020
RunContextWrapper,
21+
RunError,
2122
Runner,
2223
TResponseInputItem,
2324
)
@@ -42,11 +43,14 @@ async def test_single_turn_model_error():
4243
name="test_agent",
4344
model=model,
4445
)
45-
with pytest.raises(ValueError):
46+
with pytest.raises(RunError) as exc_info:
4647
result = Runner.run_streamed(agent, input="first_test")
4748
async for _ in result.stream_events():
4849
pass
4950

51+
# Verify the original exception is preserved
52+
assert isinstance(exc_info.value.original_exception, ValueError)
53+
5054
assert fetch_normalized_spans() == snapshot(
5155
[
5256
{
@@ -98,11 +102,14 @@ async def test_multi_turn_no_handoffs():
98102
]
99103
)
100104

101-
with pytest.raises(ValueError):
105+
with pytest.raises(RunError) as exc_info:
102106
result = Runner.run_streamed(agent, input="first_test")
103107
async for _ in result.stream_events():
104108
pass
105109

110+
# Verify the original exception is preserved
111+
assert isinstance(exc_info.value.original_exception, ValueError)
112+
106113
assert fetch_normalized_spans() == snapshot(
107114
[
108115
{

0 commit comments

Comments
 (0)