Skip to content

Commit f083c48

Browse files
Return error details for streaming + tests
1 parent 7989e0d commit f083c48

File tree

7 files changed

+169
-68
lines changed

7 files changed

+169
-68
lines changed

src/agents/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
MaxTurnsExceeded,
1515
ModelBehaviorError,
1616
OutputGuardrailTripwireTriggered,
17+
RunErrorDetails,
1718
UserError,
1819
)
1920
from .guardrail import (
@@ -44,7 +45,7 @@
4445
from .models.openai_chatcompletions import OpenAIChatCompletionsModel
4546
from .models.openai_provider import OpenAIProvider
4647
from .models.openai_responses import OpenAIResponsesModel
47-
from .result import RunErrorDetails, RunResult, RunResultStreaming
48+
from .result import RunResult, RunResultStreaming
4849
from .run import RunConfig, Runner
4950
from .run_context import RunContextWrapper, TContext
5051
from .stream_events import (

src/agents/exceptions.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,49 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Any
45

56
if TYPE_CHECKING:
7+
from .agent import Agent
68
from .guardrail import InputGuardrailResult, OutputGuardrailResult
7-
from .result import RunErrorDetails
9+
from .items import ModelResponse, RunItem, TResponseInputItem
10+
from .run_context import RunContextWrapper
11+
12+
from .util._pretty_print import pretty_print_run_error_details
13+
14+
15+
@dataclass
16+
class RunErrorDetails:
17+
"""Data collected from an agent run when an exception occurs."""
18+
input: str | list[TResponseInputItem]
19+
new_items: list[RunItem]
20+
raw_responses: list[ModelResponse]
21+
last_agent: Agent[Any]
22+
context_wrapper: RunContextWrapper[Any]
23+
input_guardrail_results: list[InputGuardrailResult]
24+
output_guardrail_results: list[OutputGuardrailResult]
25+
26+
def __str__(self) -> str:
27+
return pretty_print_run_error_details(self)
828

929

1030
class AgentsException(Exception):
1131
"""Base class for all exceptions in the Agents SDK."""
32+
run_data: RunErrorDetails | None
33+
34+
def __init__(self, *args: object) -> None:
35+
super().__init__(*args)
36+
self.run_data = None
1237

1338

1439
class MaxTurnsExceeded(AgentsException):
1540
"""Exception raised when the maximum number of turns is exceeded."""
1641

1742
message: str
18-
run_error_details: RunErrorDetails | None
1943

20-
def __init__(self, message: str, run_error_details: RunErrorDetails | None = None):
44+
def __init__(self, message: str):
2145
self.message = message
22-
self.run_error_details = run_error_details
46+
super().__init__(message)
2347

2448

2549
class ModelBehaviorError(AgentsException):
@@ -31,6 +55,7 @@ class ModelBehaviorError(AgentsException):
3155

3256
def __init__(self, message: str):
3357
self.message = message
58+
super().__init__(message)
3459

3560

3661
class UserError(AgentsException):
@@ -40,6 +65,7 @@ class UserError(AgentsException):
4065

4166
def __init__(self, message: str):
4267
self.message = message
68+
super().__init__(message)
4369

4470

4571
class InputGuardrailTripwireTriggered(AgentsException):

src/agents/result.py

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
from ._run_impl import QueueCompleteSentinel
1212
from .agent import Agent
1313
from .agent_output import AgentOutputSchemaBase
14-
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
14+
from .exceptions import (
15+
AgentsException,
16+
InputGuardrailTripwireTriggered,
17+
MaxTurnsExceeded,
18+
RunErrorDetails,
19+
)
1520
from .guardrail import InputGuardrailResult, OutputGuardrailResult
1621
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
1722
from .logger import logger
@@ -20,7 +25,6 @@
2025
from .tracing import Trace
2126
from .util._pretty_print import (
2227
pretty_print_result,
23-
pretty_print_run_error_details,
2428
pretty_print_run_result_streaming,
2529
)
2630

@@ -212,29 +216,79 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
212216

213217
def _check_errors(self):
214218
if self.current_turn > self.max_turns:
215-
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
219+
max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
220+
max_turns_exc.run_data = RunErrorDetails(
221+
input=self.input,
222+
new_items=self.new_items,
223+
raw_responses=self.raw_responses,
224+
last_agent=self.current_agent,
225+
context_wrapper=self.context_wrapper,
226+
input_guardrail_results=self.input_guardrail_results,
227+
output_guardrail_results=self.output_guardrail_results,
228+
)
229+
self._stored_exception = max_turns_exc
216230

217231
# Fetch all the completed guardrail results from the queue and raise if needed
218232
while not self._input_guardrail_queue.empty():
219233
guardrail_result = self._input_guardrail_queue.get_nowait()
220234
if guardrail_result.output.tripwire_triggered:
221-
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
235+
tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result)
236+
tripwire_exc.run_data = RunErrorDetails(
237+
input=self.input,
238+
new_items=self.new_items,
239+
raw_responses=self.raw_responses,
240+
last_agent=self.current_agent,
241+
context_wrapper=self.context_wrapper,
242+
input_guardrail_results=self.input_guardrail_results,
243+
output_guardrail_results=self.output_guardrail_results,
244+
)
245+
self._stored_exception = tripwire_exc
222246

223247
# Check the tasks for any exceptions
224248
if self._run_impl_task and self._run_impl_task.done():
225-
exc = self._run_impl_task.exception()
226-
if exc and isinstance(exc, Exception):
227-
self._stored_exception = exc
249+
run_impl_exc = self._run_impl_task.exception()
250+
if run_impl_exc and isinstance(run_impl_exc, Exception):
251+
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
252+
run_impl_exc.run_data = RunErrorDetails(
253+
input=self.input,
254+
new_items=self.new_items,
255+
raw_responses=self.raw_responses,
256+
last_agent=self.current_agent,
257+
context_wrapper=self.context_wrapper,
258+
input_guardrail_results=self.input_guardrail_results,
259+
output_guardrail_results=self.output_guardrail_results,
260+
)
261+
self._stored_exception = run_impl_exc
228262

229263
if self._input_guardrails_task and self._input_guardrails_task.done():
230-
exc = self._input_guardrails_task.exception()
231-
if exc and isinstance(exc, Exception):
232-
self._stored_exception = exc
264+
in_guard_exc = self._input_guardrails_task.exception()
265+
if in_guard_exc and isinstance(in_guard_exc, Exception):
266+
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
267+
in_guard_exc.run_data = RunErrorDetails(
268+
input=self.input,
269+
new_items=self.new_items,
270+
raw_responses=self.raw_responses,
271+
last_agent=self.current_agent,
272+
context_wrapper=self.context_wrapper,
273+
input_guardrail_results=self.input_guardrail_results,
274+
output_guardrail_results=self.output_guardrail_results,
275+
)
276+
self._stored_exception = in_guard_exc
233277

234278
if self._output_guardrails_task and self._output_guardrails_task.done():
235-
exc = self._output_guardrails_task.exception()
236-
if exc and isinstance(exc, Exception):
237-
self._stored_exception = exc
279+
out_guard_exc = self._output_guardrails_task.exception()
280+
if out_guard_exc and isinstance(out_guard_exc, Exception):
281+
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
282+
out_guard_exc.run_data = RunErrorDetails(
283+
input=self.input,
284+
new_items=self.new_items,
285+
raw_responses=self.raw_responses,
286+
last_agent=self.current_agent,
287+
context_wrapper=self.context_wrapper,
288+
input_guardrail_results=self.input_guardrail_results,
289+
output_guardrail_results=self.output_guardrail_results,
290+
)
291+
self._stored_exception = out_guard_exc
238292

239293
def _cleanup_tasks(self):
240294
if self._run_impl_task and not self._run_impl_task.done():
@@ -249,34 +303,3 @@ def _cleanup_tasks(self):
249303
def __str__(self) -> str:
250304
return pretty_print_run_result_streaming(self)
251305

252-
253-
@dataclass
254-
class RunErrorDetails:
255-
input: str | list[TResponseInputItem]
256-
"""The original input items i.e. the items before run() was called. This may be a mutated
257-
version of the input, if there are handoff input filters that mutate the input.
258-
"""
259-
260-
new_items: list[RunItem]
261-
"""The new items generated during the agent run. These include things like new messages, tool
262-
calls and their outputs, etc.
263-
"""
264-
265-
raw_responses: list[ModelResponse]
266-
"""The raw LLM responses generated by the model during the agent run."""
267-
268-
input_guardrail_results: list[InputGuardrailResult]
269-
"""Guardrail results for the input messages."""
270-
271-
context_wrapper: RunContextWrapper[Any]
272-
"""The context wrapper for the agent run."""
273-
274-
_last_agent: Agent[Any]
275-
276-
@property
277-
def last_agent(self) -> Agent[Any]:
278-
"""The last agent that was run."""
279-
return self._last_agent
280-
281-
def __str__(self) -> str:
282-
return pretty_print_run_error_details(self)

src/agents/run.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
MaxTurnsExceeded,
2828
ModelBehaviorError,
2929
OutputGuardrailTripwireTriggered,
30+
RunErrorDetails,
3031
)
3132
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
3233
from .handoffs import Handoff, HandoffInputFilter, handoff
@@ -36,7 +37,7 @@
3637
from .model_settings import ModelSettings
3738
from .models.interface import Model, ModelProvider
3839
from .models.multi_provider import MultiProvider
39-
from .result import RunErrorDetails, RunResult, RunResultStreaming
40+
from .result import RunResult, RunResultStreaming
4041
from .run_context import RunContextWrapper, TContext
4142
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
4243
from .tool import Tool
@@ -286,23 +287,17 @@ async def run(
286287
raise AgentsException(
287288
f"Unknown next step type: {type(turn_result.next_step)}"
288289
)
289-
except Exception as e:
290-
run_error_details = RunErrorDetails(
290+
except AgentsException as exc:
291+
exc.run_data = RunErrorDetails(
291292
input=original_input,
292293
new_items=generated_items,
293294
raw_responses=model_responses,
294-
input_guardrail_results=input_guardrail_results,
295+
last_agent=current_agent,
295296
context_wrapper=context_wrapper,
296-
_last_agent=current_agent
297+
input_guardrail_results=input_guardrail_results,
298+
output_guardrail_results=[]
297299
)
298-
# Re-raise with the error details
299-
if isinstance(e, MaxTurnsExceeded):
300-
raise MaxTurnsExceeded(
301-
f"Max turns ({max_turns}) exceeded",
302-
run_error_details
303-
) from e
304-
else:
305-
raise
300+
raise
306301
finally:
307302
if current_span:
308303
current_span.finish(reset_current=True)
@@ -629,6 +624,19 @@ async def _run_streamed_impl(
629624
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
630625
elif isinstance(turn_result.next_step, NextStepRunAgain):
631626
pass
627+
except AgentsException as exc:
628+
streamed_result.is_complete = True
629+
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
630+
exc.run_data = RunErrorDetails(
631+
input=streamed_result.input,
632+
new_items=streamed_result.new_items,
633+
raw_responses=streamed_result.raw_responses,
634+
last_agent=current_agent,
635+
context_wrapper=context_wrapper,
636+
input_guardrail_results=streamed_result.input_guardrail_results,
637+
output_guardrail_results=streamed_result.output_guardrail_results,
638+
)
639+
raise
632640
except Exception as e:
633641
if current_span:
634642
_error_tracing.attach_error_to_span(

src/agents/util/_pretty_print.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from pydantic import BaseModel
44

55
if TYPE_CHECKING:
6-
from ..result import RunErrorDetails, RunResult, RunResultBase, RunResultStreaming
6+
from ..exceptions import RunErrorDetails
7+
from ..result import RunResult, RunResultBase, RunResultStreaming
78

89

910
def _indent(text: str, indent_level: int) -> str:
@@ -37,6 +38,7 @@ def pretty_print_result(result: "RunResult") -> str:
3738

3839
return output
3940

41+
4042
def pretty_print_run_error_details(result: "RunErrorDetails") -> str:
4143
output = "RunErrorDetails:"
4244
output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)'
@@ -47,6 +49,7 @@ def pretty_print_run_error_details(result: "RunErrorDetails") -> str:
4749

4850
return output
4951

52+
5053
def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
5154
output = "RunResultStreaming:"
5255
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'

tests/test_run_error_details.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import json
2+
3+
import pytest
4+
5+
from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner
6+
7+
from .fake_model import FakeModel
8+
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
9+
10+
11+
@pytest.mark.asyncio
12+
async def test_run_error_includes_data():
13+
model = FakeModel()
14+
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
15+
model.add_multiple_turn_outputs([
16+
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
17+
[get_text_message("done")],
18+
])
19+
with pytest.raises(MaxTurnsExceeded) as exc:
20+
await Runner.run(agent, input="hello", max_turns=1)
21+
data = exc.value.run_data
22+
assert isinstance(data, RunErrorDetails)
23+
assert data.last_agent == agent
24+
assert len(data.raw_responses) == 1
25+
assert len(data.new_items) > 0
26+
27+
28+
@pytest.mark.asyncio
29+
async def test_streamed_run_error_includes_data():
30+
model = FakeModel()
31+
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
32+
model.add_multiple_turn_outputs([
33+
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
34+
[get_text_message("done")],
35+
])
36+
result = Runner.run_streamed(agent, input="hello", max_turns=1)
37+
with pytest.raises(MaxTurnsExceeded) as exc:
38+
async for _ in result.stream_events():
39+
pass
40+
data = exc.value.run_data
41+
assert isinstance(data, RunErrorDetails)
42+
assert data.last_agent == agent
43+
assert len(data.raw_responses) == 1
44+
assert len(data.new_items) > 0

tests/test_tracing_errors_streamed.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,6 @@ async def test_tool_call_error():
168168
"children": [
169169
{
170170
"type": "agent",
171-
"error": {
172-
"message": "Error in agent run",
173-
"data": {"error": "Invalid JSON input for tool foo: bad_json"},
174-
},
175171
"data": {
176172
"name": "test_agent",
177173
"handoffs": [],

0 commit comments

Comments
 (0)