Skip to content

Commit 518914a

Browse files
draft
1 parent 0c4f2b9 commit 518914a

File tree

7 files changed

+218
-8
lines changed

7 files changed

+218
-8
lines changed

src/agents/exceptions.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass
4-
from typing import TYPE_CHECKING, Any
3+
from dataclasses import dataclass, replace
4+
from textwrap import dedent
5+
from typing import TYPE_CHECKING, Any, Optional
56

67
if TYPE_CHECKING:
78
from .agent import Agent
89
from .guardrail import InputGuardrailResult, OutputGuardrailResult
910
from .items import ModelResponse, RunItem, TResponseInputItem
1011
from .run_context import RunContextWrapper
12+
from .run import RunConfig
13+
from .result import RunResult
1114
from .tool_guardrails import (
1215
ToolGuardrailFunctionOutput,
1316
ToolInputGuardrail,
@@ -28,6 +31,7 @@ class RunErrorDetails:
2831
context_wrapper: RunContextWrapper[Any]
2932
input_guardrail_results: list[InputGuardrailResult]
3033
output_guardrail_results: list[OutputGuardrailResult]
34+
run_config: RunConfig
3135

3236
def __str__(self) -> str:
3337
return pretty_print_run_error_details(self)
@@ -48,10 +52,99 @@ class MaxTurnsExceeded(AgentsException):
4852

4953
message: str
5054

55+
_DEFAULT_RESUME_PROMPT = """
56+
You reached the maximum number of turns.
57+
Return a final answer to the query using ONLY the information already gathered in the conversation so far.
58+
"""
59+
5160
def __init__(self, message: str):
5261
self.message = message
5362
super().__init__(message)
5463

64+
def resume(self, prompt: Optional[str] = _DEFAULT_RESUME_PROMPT) -> RunResult:
65+
"""Resume the failed run synchronously with a final, tool-free turn.
66+
67+
Args:
68+
prompt: Optional user instruction to append before rerunning the final turn.
69+
Pass ``None`` to skip injecting an extra message; defaults to a reminder
70+
to produce a final answer from existing context.
71+
"""
72+
run_data = self._require_run_data()
73+
inputs, run_config = self._prepare_resume_arguments(run_data, prompt)
74+
75+
from .run import Runner
76+
77+
return Runner.run_sync(
78+
starting_agent=run_data.last_agent,
79+
input=inputs,
80+
context=run_data.context_wrapper.context,
81+
max_turns=1,
82+
run_config=run_config,
83+
)
84+
85+
async def resume_async(self, prompt: Optional[str] = _DEFAULT_RESUME_PROMPT) -> RunResult:
86+
"""Resume the failed run asynchronously with a final, tool-free turn.
87+
88+
Args:
89+
prompt: Optional user instruction to append before rerunning the final turn.
90+
Pass ``None`` to skip injecting an extra message; defaults to a reminder
91+
to produce a final answer from existing context.
92+
"""
93+
run_data = self._require_run_data()
94+
inputs, run_config = self._prepare_resume_arguments(run_data, prompt)
95+
96+
from .run import Runner
97+
98+
return await Runner.run(
99+
starting_agent=run_data.last_agent,
100+
input=inputs,
101+
context=run_data.context_wrapper.context,
102+
max_turns=1,
103+
run_config=run_config,
104+
)
105+
106+
def _prepare_resume_arguments(
107+
self,
108+
run_data: RunErrorDetails,
109+
prompt: Optional[str] = None,
110+
) -> tuple[list[TResponseInputItem], RunConfig]:
111+
from .items import ItemHelpers
112+
from .model_settings import ModelSettings
113+
114+
history: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(run_data.input)
115+
for item in run_data.new_items:
116+
history.append(item.to_input_item())
117+
118+
normalized_prompt = self._normalize_resume_prompt(prompt)
119+
if normalized_prompt is not None:
120+
history.append({"content": normalized_prompt, "role": "user"})
121+
122+
run_config = replace(run_data.run_config)
123+
if run_config.model_settings is None:
124+
run_config.model_settings = ModelSettings(tool_choice="none")
125+
else:
126+
run_config.model_settings = run_config.model_settings.resolve(
127+
ModelSettings(tool_choice="none")
128+
)
129+
130+
return (
131+
history,
132+
run_config,
133+
)
134+
135+
def _normalize_resume_prompt(self, prompt: Optional[str]) -> Optional[str]:
136+
if prompt is None:
137+
return None
138+
normalized = dedent(prompt).strip()
139+
return normalized or None
140+
141+
def _require_run_data(self) -> RunErrorDetails:
142+
if self.run_data is None:
143+
raise RuntimeError(
144+
"Run data is not available; resume() can only be called on exceptions raised by Runner."
145+
)
146+
return self.run_data
147+
55148

56149
class ModelBehaviorError(AgentsException):
57150
"""Exception raised when the model does something unexpected, e.g. calling a tool that doesn't

src/agents/result.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import abc
44
import asyncio
55
from collections.abc import AsyncIterator
6-
from dataclasses import dataclass, field
6+
from dataclasses import dataclass, field, replace
77
from typing import TYPE_CHECKING, Any, Literal, cast
88

99
from typing_extensions import TypeVar
@@ -31,6 +31,7 @@
3131
if TYPE_CHECKING:
3232
from ._run_impl import QueueCompleteSentinel
3333
from .agent import Agent
34+
from .run import RunConfig
3435
from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
3536

3637
T = TypeVar("T")
@@ -69,6 +70,9 @@ class RunResultBase(abc.ABC):
6970
context_wrapper: RunContextWrapper[Any]
7071
"""The context wrapper for the agent run."""
7172

73+
run_config: "RunConfig"
74+
"""The run configuration that was used for the agent run."""
75+
7276
@property
7377
@abc.abstractmethod
7478
def last_agent(self) -> Agent[Any]:
@@ -279,6 +283,7 @@ def _create_error_details(self) -> RunErrorDetails:
279283
context_wrapper=self.context_wrapper,
280284
input_guardrail_results=self.input_guardrail_results,
281285
output_guardrail_results=self.output_guardrail_results,
286+
run_config=replace(self.run_config),
282287
)
283288

284289
def _check_errors(self):

src/agents/run.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import inspect
66
import os
77
import warnings
8-
from dataclasses import dataclass, field
8+
from dataclasses import dataclass, field, replace
99
from typing import Any, Callable, Generic, cast, get_args
1010

1111
from openai.types.responses import (
@@ -665,6 +665,7 @@ async def run(
665665
output_guardrail_results=output_guardrail_results,
666666
tool_input_guardrail_results=tool_input_guardrail_results,
667667
tool_output_guardrail_results=tool_output_guardrail_results,
668+
run_config=replace(run_config),
668669
context_wrapper=context_wrapper,
669670
)
670671
if not any(
@@ -702,6 +703,7 @@ async def run(
702703
context_wrapper=context_wrapper,
703704
input_guardrail_results=input_guardrail_results,
704705
output_guardrail_results=[],
706+
run_config=replace(run_config),
705707
)
706708
raise
707709
finally:
@@ -837,6 +839,7 @@ def run_streamed(
837839
output_guardrail_results=[],
838840
tool_input_guardrail_results=[],
839841
tool_output_guardrail_results=[],
842+
run_config=replace(run_config),
840843
_current_agent_output_schema=output_schema,
841844
trace=new_trace,
842845
context_wrapper=context_wrapper,
@@ -1174,6 +1177,7 @@ async def _start_streaming(
11741177
context_wrapper=context_wrapper,
11751178
input_guardrail_results=streamed_result.input_guardrail_results,
11761179
output_guardrail_results=streamed_result.output_guardrail_results,
1180+
run_config=replace(run_config),
11771181
)
11781182
raise
11791183
except Exception as e:

tests/extensions/memory/test_advanced_sqlite_session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed
88
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
99

10-
from agents import Agent, Runner, TResponseInputItem, function_tool
10+
from agents import Agent, RunConfig, Runner, TResponseInputItem, function_tool
1111
from agents.extensions.memory import AdvancedSQLiteSession
1212
from agents.result import RunResult
1313
from agents.run_context import RunContextWrapper
@@ -74,6 +74,7 @@ def create_mock_run_result(
7474
tool_output_guardrail_results=[],
7575
context_wrapper=context_wrapper,
7676
_last_agent=agent,
77+
run_config=RunConfig(),
7778
)
7879

7980

tests/test_max_turns.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from typing_extensions import TypedDict
77

8-
from agents import Agent, MaxTurnsExceeded, Runner
8+
from agents import Agent, MaxTurnsExceeded, ModelSettings, RunConfig, Runner
99

1010
from .fake_model import FakeModel
1111
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
@@ -75,6 +75,110 @@ async def test_streamed_max_turns():
7575
pass
7676

7777

78+
@pytest.mark.asyncio
79+
async def test_max_turns_resume_async_runs_final_turn():
80+
model = FakeModel()
81+
agent = Agent(
82+
name="test_1",
83+
model=model,
84+
tools=[get_function_tool("some_function", "result")],
85+
)
86+
87+
func_output = json.dumps({"a": "b"})
88+
final_answer = "final answer"
89+
90+
model.add_multiple_turn_outputs(
91+
[
92+
[get_text_message("1"), get_function_tool_call("some_function", func_output)],
93+
[get_text_message("2"), get_function_tool_call("some_function", func_output)],
94+
[get_text_message(final_answer)],
95+
]
96+
)
97+
98+
with pytest.raises(MaxTurnsExceeded) as exc_info:
99+
await Runner.run(agent, input="user_message", max_turns=2)
100+
101+
result = await exc_info.value.resume_async("Finish without tools.")
102+
103+
assert result.final_output == final_answer
104+
resume_input = model.last_turn_args["input"]
105+
assert resume_input[0]["content"] == "user_message"
106+
assert resume_input[-1] == {"content": "Finish without tools.", "role": "user"}
107+
assert any(item.get("type") == "function_call_output" for item in resume_input)
108+
assert model.last_turn_args["model_settings"].tool_choice == "none"
109+
110+
111+
def test_max_turns_resume_sync_uses_default_prompt():
112+
model = FakeModel()
113+
agent = Agent(
114+
name="test_1",
115+
model=model,
116+
tools=[get_function_tool("some_function", "result")],
117+
)
118+
119+
func_output = json.dumps({"a": "b"})
120+
final_answer = "final answer"
121+
122+
model.add_multiple_turn_outputs(
123+
[
124+
[get_text_message("1"), get_function_tool_call("some_function", func_output)],
125+
[get_text_message("2"), get_function_tool_call("some_function", func_output)],
126+
[get_text_message(final_answer)],
127+
]
128+
)
129+
130+
with pytest.raises(MaxTurnsExceeded) as exc_info:
131+
Runner.run_sync(agent, input="user_message", max_turns=2)
132+
133+
result = exc_info.value.resume()
134+
135+
assert result.final_output == final_answer
136+
resume_input = model.last_turn_args["input"]
137+
expected_prompt = (
138+
"You reached the maximum number of turns.\n"
139+
"Return a final answer to the query using ONLY the information already gathered in the conversation so far."
140+
)
141+
assert resume_input[-1] == {"content": expected_prompt, "role": "user"}
142+
assert model.last_turn_args["model_settings"].tool_choice == "none"
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_resume_preserves_run_config_settings():
147+
model = FakeModel()
148+
agent = Agent(
149+
name="test_1",
150+
model=model,
151+
tools=[get_function_tool("some_function", "result")],
152+
)
153+
154+
func_output = json.dumps({"a": "b"})
155+
final_answer = "final answer"
156+
157+
model.add_multiple_turn_outputs(
158+
[
159+
[get_text_message("1"), get_function_tool_call("some_function", func_output)],
160+
[get_text_message("2"), get_function_tool_call("some_function", func_output)],
161+
[get_text_message(final_answer)],
162+
]
163+
)
164+
165+
run_config = RunConfig(model_settings=ModelSettings(temperature=0.25, tool_choice="auto"))
166+
167+
with pytest.raises(MaxTurnsExceeded) as exc_info:
168+
await Runner.run(agent, input="user_message", max_turns=2, run_config=run_config)
169+
170+
await exc_info.value.resume_async("Finish without tools.")
171+
172+
final_settings = model.last_turn_args["model_settings"]
173+
assert final_settings.temperature == 0.25
174+
assert final_settings.tool_choice == "none"
175+
176+
stored_settings = exc_info.value.run_data.run_config.model_settings
177+
assert stored_settings is not None
178+
assert stored_settings.temperature == 0.25
179+
assert stored_settings.tool_choice == "auto"
180+
181+
78182
class Foo(TypedDict):
79183
a: str
80184

tests/test_result_cast.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from pydantic import BaseModel
55

6-
from agents import Agent, RunContextWrapper, RunResult
6+
from agents import Agent, RunConfig, RunContextWrapper, RunResult
77

88

99
def create_run_result(final_output: Any) -> RunResult:
@@ -18,6 +18,7 @@ def create_run_result(final_output: Any) -> RunResult:
1818
tool_output_guardrail_results=[],
1919
_last_agent=Agent(name="test"),
2020
context_wrapper=RunContextWrapper(context=None),
21+
run_config=RunConfig(),
2122
)
2223

2324

tests/test_run_error_details.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner
5+
from agents import Agent, MaxTurnsExceeded, RunConfig, RunErrorDetails, Runner
66

77
from .fake_model import FakeModel
88
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
@@ -25,6 +25,7 @@ async def test_run_error_includes_data():
2525
assert data.last_agent == agent
2626
assert len(data.raw_responses) == 1
2727
assert len(data.new_items) > 0
28+
assert isinstance(data.run_config, RunConfig)
2829

2930

3031
@pytest.mark.asyncio
@@ -46,3 +47,4 @@ async def test_streamed_run_error_includes_data():
4647
assert data.last_agent == agent
4748
assert len(data.raw_responses) == 1
4849
assert len(data.new_items) > 0
50+
assert isinstance(data.run_config, RunConfig)

0 commit comments

Comments
 (0)