Skip to content

Commit fb77c74

Browse files
committed
Fix potential infinite tool call loop by resetting tool_choice after tool execution
1 parent cef3d53 commit fb77c74

File tree

2 files changed

+335
-0
lines changed

2 files changed

+335
-0
lines changed

src/agents/_run_impl.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
)
4848
from .lifecycle import RunHooks
4949
from .logger import logger
50+
from .model_settings import ModelSettings
5051
from .models.interface import ModelTracing
5152
from .run_context import RunContextWrapper, TContext
5253
from .stream_events import RunItemStreamEvent, StreamEvent
@@ -206,6 +207,37 @@ async def execute_tools_and_side_effects(
206207
new_step_items.extend([result.run_item for result in function_results])
207208
new_step_items.extend(computer_results)
208209

210+
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
211+
if (processed_response.functions or processed_response.computer_actions):
212+
# Reset agent's model_settings
213+
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
214+
# Create a new model_settings to avoid modifying the original shared instance
215+
agent.model_settings = ModelSettings(
216+
temperature=agent.model_settings.temperature,
217+
top_p=agent.model_settings.top_p,
218+
frequency_penalty=agent.model_settings.frequency_penalty,
219+
presence_penalty=agent.model_settings.presence_penalty,
220+
tool_choice="auto", # Reset to auto
221+
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
222+
truncation=agent.model_settings.truncation,
223+
max_tokens=agent.model_settings.max_tokens,
224+
)
225+
226+
# Also reset run_config's model_settings if it exists
227+
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
228+
isinstance(run_config.model_settings.tool_choice, str)):
229+
# Create a new model_settings for run_config
230+
run_config.model_settings = ModelSettings(
231+
temperature=run_config.model_settings.temperature,
232+
top_p=run_config.model_settings.top_p,
233+
frequency_penalty=run_config.model_settings.frequency_penalty,
234+
presence_penalty=run_config.model_settings.presence_penalty,
235+
tool_choice="auto", # Reset to auto
236+
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
237+
truncation=run_config.model_settings.truncation,
238+
max_tokens=run_config.model_settings.max_tokens,
239+
)
240+
209241
# Second, check if there are any handoffs
210242
if run_handoffs := processed_response.handoffs:
211243
return await cls.execute_handoffs(

tests/test_tool_choice_reset.py

Lines changed: 303 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,303 @@
1+
from unittest import mock
2+
import asyncio
3+
import json
4+
from typing import List
5+
6+
from agents import Agent, ModelSettings, RunConfig, function_tool, Runner
7+
from agents.models.interface import ModelResponse
8+
from agents.items import Usage
9+
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
10+
11+
12+
# Smallest possible tool for these tests: it hands its argument straight back.
# NOTE(review): presumably `function_tool` reads the docstring below as the tool
# description, so it is kept verbatim - confirm before rewording it.
@function_tool
def echo(text: str) -> str:
    """Echo the input text"""
    return text
16+
17+
18+
# Stand-in model for the tests: it requests the first available tool whenever
# tool_choice is "required" or "echo", and records how many times it was invoked.
class MockModel:
    """Fake model that keeps asking for a tool while tool_choice forces one.

    The ``tool_call_counter`` dict passed to the constructor is mutated on every
    ``get_response`` call: ``"count"`` is incremented, and
    ``"potential_infinite_loop"`` is set to True once the count exceeds five.
    """

    def __init__(self, tool_call_counter):
        # The caller keeps a reference to this dict and inspects it after the run.
        self.tool_call_counter = tool_call_counter

    async def get_response(self, **kwargs):
        """Build a ModelResponse containing zero or one function call.

        Only the ``tools`` and ``model_settings`` keyword arguments are read; any
        other keyword arguments are accepted and ignored.
        """
        stats = self.tool_call_counter

        # Track every invocation so the test can detect a runaway loop.
        stats["count"] += 1
        if stats["count"] > 5:
            stats["potential_infinite_loop"] = True

        available_tools = kwargs.get("tools", [])
        settings = kwargs.get("model_settings")

        # A tool call is emitted only when the settings force one and a tool exists.
        forces_tool = bool(
            settings
            and settings.tool_choice in ("required", "echo")
            and available_tools
        )

        output = []
        if forces_tool:
            first_tool = available_tools[0]
            output.append(
                ResponseFunctionToolCall(
                    id="call_1",
                    name=first_tool.name,
                    arguments=json.dumps({"text": "This is a test"}),
                    call_id="call_1",
                    type="function_call",
                )
            )

        return ModelResponse(
            output=output,
            referenceable_id="123",
            usage=Usage(input_tokens=10, output_tokens=10, total_tokens=20),
        )
55+
56+
57+
class TestToolChoiceReset:
    """Tests for resetting ``tool_choice`` to "auto" once a tool call has run.

    The first four tests exercise a local copy of the reset logic (kept in one
    place, ``_reset_tool_choice``) instead of the runner itself; only
    ``test_integration_prevents_infinite_loop`` drives ``Runner.run``.

    NOTE(review): every test is a coroutine, so collection presumably relies on
    an asyncio pytest plugin running in auto mode - confirm in the project config.
    """

    @staticmethod
    def _with_auto_tool_choice(settings):
        """Return a fresh ModelSettings copying ``settings`` with tool_choice="auto".

        A new instance is built so the object passed in is never mutated.
        """
        return ModelSettings(
            temperature=settings.temperature,
            top_p=settings.top_p,
            frequency_penalty=settings.frequency_penalty,
            presence_penalty=settings.presence_penalty,
            tool_choice="auto",  # Reset to auto
            parallel_tool_calls=settings.parallel_tool_calls,
            truncation=settings.truncation,
            max_tokens=settings.max_tokens,
        )

    @classmethod
    def _reset_tool_choice(cls, agent, run_config, processed_response):
        """Apply the tool_choice reset to ``agent`` and ``run_config`` in place.

        Nothing happens unless ``processed_response`` reports at least one
        function call or computer action. Any string ``tool_choice`` (which
        includes "required" and a specific tool name) is replaced by "auto";
        a ``tool_choice`` of None is left untouched.
        """
        if not (processed_response.functions or processed_response.computer_actions):
            return

        # `== "required"` is already covered by the str check, so one test suffices.
        if isinstance(agent.model_settings.tool_choice, str):
            agent.model_settings = cls._with_auto_tool_choice(agent.model_settings)

        # run_config.model_settings may be None, in which case there is nothing to do.
        run_settings = run_config.model_settings
        if run_settings and isinstance(run_settings.tool_choice, str):
            run_config.model_settings = cls._with_auto_tool_choice(run_settings)

    @staticmethod
    def _response_with_function_call():
        """Return a mocked processed response holding one function call."""
        processed_response = mock.MagicMock()
        processed_response.functions = [mock.MagicMock()]  # At least one function call
        processed_response.computer_actions = []
        return processed_response

    @staticmethod
    def _run_config_without_settings():
        """Return a mocked run config whose ``model_settings`` is None."""
        run_config = mock.MagicMock()
        run_config.model_settings = None
        return run_config

    async def test_tool_choice_resets_after_call(self):
        """Test that tool_choice is reset to 'auto' after tool call when set to 'required'"""
        agent = Agent(
            name="Test agent",
            tools=[echo],
            model_settings=ModelSettings(tool_choice="required"),
        )

        self._reset_tool_choice(
            agent,
            self._run_config_without_settings(),
            self._response_with_function_call(),
        )

        assert agent.model_settings.tool_choice == "auto"

    async def test_tool_choice_resets_from_specific_function(self):
        """Test tool_choice reset to 'auto' after call when set to specific function name"""
        agent = Agent(
            name="Test agent",
            instructions="You are a test agent",
            tools=[echo],
            model="gpt-4-0125-preview",
            model_settings=ModelSettings(tool_choice="echo"),
        )

        self._reset_tool_choice(
            agent,
            self._run_config_without_settings(),
            self._response_with_function_call(),
        )

        assert agent.model_settings.tool_choice == "auto"

    async def test_tool_choice_no_reset_when_auto(self):
        """Test that tool_choice is not changed when it's already set to 'auto'"""
        agent = Agent(
            name="Test agent",
            tools=[echo],
            model_settings=ModelSettings(tool_choice="auto"),
        )

        self._reset_tool_choice(
            agent,
            self._run_config_without_settings(),
            self._response_with_function_call(),
        )

        assert agent.model_settings.tool_choice == "auto"

    async def test_run_config_tool_choice_reset(self):
        """Test that run_config.model_settings.tool_choice is reset to 'auto'"""
        # The agent itself does not force a tool; only the run config does.
        agent = Agent(
            name="Test agent",
            tools=[echo],
            model_settings=ModelSettings(tool_choice=None),
        )
        run_config = RunConfig()
        run_config.model_settings = ModelSettings(tool_choice="required")

        self._reset_tool_choice(agent, run_config, self._response_with_function_call())

        assert run_config.model_settings.tool_choice == "auto"
        # A tool_choice of None on the agent is not a string, so it stays as it was.
        assert agent.model_settings.tool_choice is None

    @mock.patch("agents.run.Runner._get_model")
    async def test_integration_prevents_infinite_loop(self, mock_get_model):
        """Integration test to verify that tool_choice reset prevents infinite loops"""
        # Shared counter the mock model updates on every call.
        tool_call_counter = {"count": 0, "potential_infinite_loop": False}
        mock_get_model.return_value = MockModel(tool_call_counter)

        agent = Agent(
            name="Test agent",
            instructions="You are a test agent",
            tools=[echo],
            model_settings=ModelSettings(tool_choice="required"),
            # "run_llm_again" lets the LLM continue after tool calls, which would
            # loop forever if tool_choice stayed at "required".
            tool_use_behavior="run_llm_again",
        )

        # Only the awaited run sits inside the try, so the timeout handler cannot
        # mask an unrelated failure from the assertions below.
        try:
            result = await asyncio.wait_for(
                Runner.run(agent, input="Test input"), timeout=2.0
            )
        except asyncio.TimeoutError as err:
            # Raised explicitly rather than `assert False`, which is stripped under -O.
            raise AssertionError(
                "Timeout occurred, potential infinite loop detected"
            ) from err

        # Verify the agent ran successfully
        assert result is not None

        # The model must have been called, but not often enough to suggest a loop.
        assert tool_call_counter["count"] >= 1
        assert tool_call_counter["count"] < 5
        assert not tool_call_counter["potential_infinite_loop"]

0 commit comments

Comments
 (0)