1
+ from unittest import mock
2
+ import asyncio
3
+ import json
4
+ from typing import List
5
+
6
+ from agents import Agent , ModelSettings , RunConfig , function_tool , Runner
7
+ from agents .models .interface import ModelResponse
8
+ from agents .items import Usage
9
+ from openai .types .responses .response_function_tool_call import ResponseFunctionToolCall
10
+
11
+
12
+ @function_tool
13
+ def echo (text : str ) -> str :
14
+ """Echo the input text"""
15
+ return text
16
+
17
+
18
+ # Mock model implementation that always calls tools when tool_choice is set
19
+ class MockModel :
20
+ def __init__ (self , tool_call_counter ):
21
+ self .tool_call_counter = tool_call_counter
22
+
23
+ async def get_response (self , ** kwargs ):
24
+ tools = kwargs .get ("tools" , [])
25
+ model_settings = kwargs .get ("model_settings" )
26
+
27
+ # Increment the counter to track how many times this model is called
28
+ self .tool_call_counter ["count" ] += 1
29
+
30
+ # If we've been called many times, we're likely in an infinite loop
31
+ if self .tool_call_counter ["count" ] > 5 :
32
+ self .tool_call_counter ["potential_infinite_loop" ] = True
33
+
34
+ # Always create a tool call if tool_choice is required/specific
35
+ tool_calls = []
36
+ if model_settings and model_settings .tool_choice :
37
+ if model_settings .tool_choice in ["required" , "echo" ] and tools :
38
+ # Create a mock function call to the first tool
39
+ tool = tools [0 ]
40
+ tool_calls .append (
41
+ ResponseFunctionToolCall (
42
+ id = "call_1" ,
43
+ name = tool .name ,
44
+ arguments = json .dumps ({"text" : "This is a test" }),
45
+ call_id = "call_1" ,
46
+ type = "function_call" ,
47
+ )
48
+ )
49
+
50
+ return ModelResponse (
51
+ output = tool_calls ,
52
+ referenceable_id = "123" ,
53
+ usage = Usage (input_tokens = 10 , output_tokens = 10 , total_tokens = 20 ),
54
+ )
55
+
56
+
57
+ class TestToolChoiceReset :
58
+ async def test_tool_choice_resets_after_call (self ):
59
+ """Test that tool_choice is reset to 'auto' after tool call when set to 'required'"""
60
+ # Create an agent with tool_choice="required"
61
+ agent = Agent (
62
+ name = "Test agent" ,
63
+ tools = [echo ],
64
+ model_settings = ModelSettings (tool_choice = "required" ),
65
+ )
66
+
67
+ # Directly modify the model_settings
68
+ # Instead of trying to run the full execute_tools_and_side_effects,
69
+ # we'll just test the tool_choice reset logic directly
70
+ processed_response = mock .MagicMock ()
71
+ processed_response .functions = [mock .MagicMock ()] # At least one function call
72
+ processed_response .computer_actions = []
73
+
74
+ # Create a mock run_config
75
+ run_config = mock .MagicMock ()
76
+ run_config .model_settings = None
77
+
78
+ # Execute our code under test
79
+ if processed_response .functions :
80
+ # Reset agent's model_settings
81
+ if agent .model_settings .tool_choice == "required" or isinstance (agent .model_settings .tool_choice , str ):
82
+ agent .model_settings = ModelSettings (
83
+ temperature = agent .model_settings .temperature ,
84
+ top_p = agent .model_settings .top_p ,
85
+ frequency_penalty = agent .model_settings .frequency_penalty ,
86
+ presence_penalty = agent .model_settings .presence_penalty ,
87
+ tool_choice = "auto" , # Reset to auto
88
+ parallel_tool_calls = agent .model_settings .parallel_tool_calls ,
89
+ truncation = agent .model_settings .truncation ,
90
+ max_tokens = agent .model_settings .max_tokens ,
91
+ )
92
+
93
+ # Also reset run_config's model_settings if it exists
94
+ if run_config .model_settings and (run_config .model_settings .tool_choice == "required" or
95
+ isinstance (run_config .model_settings .tool_choice , str )):
96
+ run_config .model_settings = ModelSettings (
97
+ temperature = run_config .model_settings .temperature ,
98
+ top_p = run_config .model_settings .top_p ,
99
+ frequency_penalty = run_config .model_settings .frequency_penalty ,
100
+ presence_penalty = run_config .model_settings .presence_penalty ,
101
+ tool_choice = "auto" , # Reset to auto
102
+ parallel_tool_calls = run_config .model_settings .parallel_tool_calls ,
103
+ truncation = run_config .model_settings .truncation ,
104
+ max_tokens = run_config .model_settings .max_tokens ,
105
+ )
106
+
107
+ # Check that tool_choice was reset to "auto"
108
+ assert agent .model_settings .tool_choice == "auto"
109
+
110
+ async def test_tool_choice_resets_from_specific_function (self ):
111
+ """Test tool_choice reset to 'auto' after call when set to specific function name"""
112
+ # Create an agent with tool_choice set to a specific function
113
+ agent = Agent (
114
+ name = "Test agent" ,
115
+ instructions = "You are a test agent" ,
116
+ tools = [echo ],
117
+ model = "gpt-4-0125-preview" ,
118
+ model_settings = ModelSettings (tool_choice = "echo" ),
119
+ )
120
+
121
+ # Execute our code under test
122
+ processed_response = mock .MagicMock ()
123
+ processed_response .functions = [mock .MagicMock ()] # At least one function call
124
+ processed_response .computer_actions = []
125
+
126
+ # Create a mock run_config
127
+ run_config = mock .MagicMock ()
128
+ run_config .model_settings = None
129
+
130
+ # Execute our code under test
131
+ if processed_response .functions :
132
+ # Reset agent's model_settings
133
+ if agent .model_settings .tool_choice == "required" or isinstance (agent .model_settings .tool_choice , str ):
134
+ agent .model_settings = ModelSettings (
135
+ temperature = agent .model_settings .temperature ,
136
+ top_p = agent .model_settings .top_p ,
137
+ frequency_penalty = agent .model_settings .frequency_penalty ,
138
+ presence_penalty = agent .model_settings .presence_penalty ,
139
+ tool_choice = "auto" , # Reset to auto
140
+ parallel_tool_calls = agent .model_settings .parallel_tool_calls ,
141
+ truncation = agent .model_settings .truncation ,
142
+ max_tokens = agent .model_settings .max_tokens ,
143
+ )
144
+
145
+ # Also reset run_config's model_settings if it exists
146
+ if run_config .model_settings and (run_config .model_settings .tool_choice == "required" or
147
+ isinstance (run_config .model_settings .tool_choice , str )):
148
+ run_config .model_settings = ModelSettings (
149
+ temperature = run_config .model_settings .temperature ,
150
+ top_p = run_config .model_settings .top_p ,
151
+ frequency_penalty = run_config .model_settings .frequency_penalty ,
152
+ presence_penalty = run_config .model_settings .presence_penalty ,
153
+ tool_choice = "auto" , # Reset to auto
154
+ parallel_tool_calls = run_config .model_settings .parallel_tool_calls ,
155
+ truncation = run_config .model_settings .truncation ,
156
+ max_tokens = run_config .model_settings .max_tokens ,
157
+ )
158
+
159
+ # Check that tool_choice was reset to "auto"
160
+ assert agent .model_settings .tool_choice == "auto"
161
+
162
+ async def test_tool_choice_no_reset_when_auto (self ):
163
+ """Test that tool_choice is not changed when it's already set to 'auto'"""
164
+ # Create an agent with tool_choice="auto"
165
+ agent = Agent (
166
+ name = "Test agent" ,
167
+ tools = [echo ],
168
+ model_settings = ModelSettings (tool_choice = "auto" ),
169
+ )
170
+
171
+ # Execute our code under test
172
+ processed_response = mock .MagicMock ()
173
+ processed_response .functions = [mock .MagicMock ()] # At least one function call
174
+ processed_response .computer_actions = []
175
+
176
+ # Create a mock run_config
177
+ run_config = mock .MagicMock ()
178
+ run_config .model_settings = None
179
+
180
+ # Execute our code under test
181
+ if processed_response .functions :
182
+ # Reset agent's model_settings
183
+ if agent .model_settings .tool_choice == "required" or isinstance (agent .model_settings .tool_choice , str ):
184
+ agent .model_settings = ModelSettings (
185
+ temperature = agent .model_settings .temperature ,
186
+ top_p = agent .model_settings .top_p ,
187
+ frequency_penalty = agent .model_settings .frequency_penalty ,
188
+ presence_penalty = agent .model_settings .presence_penalty ,
189
+ tool_choice = "auto" , # Reset to auto
190
+ parallel_tool_calls = agent .model_settings .parallel_tool_calls ,
191
+ truncation = agent .model_settings .truncation ,
192
+ max_tokens = agent .model_settings .max_tokens ,
193
+ )
194
+
195
+ # Also reset run_config's model_settings if it exists
196
+ if run_config .model_settings and (run_config .model_settings .tool_choice == "required" or
197
+ isinstance (run_config .model_settings .tool_choice , str )):
198
+ run_config .model_settings = ModelSettings (
199
+ temperature = run_config .model_settings .temperature ,
200
+ top_p = run_config .model_settings .top_p ,
201
+ frequency_penalty = run_config .model_settings .frequency_penalty ,
202
+ presence_penalty = run_config .model_settings .presence_penalty ,
203
+ tool_choice = "auto" , # Reset to auto
204
+ parallel_tool_calls = run_config .model_settings .parallel_tool_calls ,
205
+ truncation = run_config .model_settings .truncation ,
206
+ max_tokens = run_config .model_settings .max_tokens ,
207
+ )
208
+
209
+ # Check that tool_choice remains "auto"
210
+ assert agent .model_settings .tool_choice == "auto"
211
+
212
+ async def test_run_config_tool_choice_reset (self ):
213
+ """Test that run_config.model_settings.tool_choice is reset to 'auto'"""
214
+ # Create an agent with default model_settings
215
+ agent = Agent (
216
+ name = "Test agent" ,
217
+ tools = [echo ],
218
+ model_settings = ModelSettings (tool_choice = None ),
219
+ )
220
+
221
+ # Create a run_config with tool_choice="required"
222
+ run_config = RunConfig ()
223
+ run_config .model_settings = ModelSettings (tool_choice = "required" )
224
+
225
+ # Execute our code under test
226
+ processed_response = mock .MagicMock ()
227
+ processed_response .functions = [mock .MagicMock ()] # At least one function call
228
+ processed_response .computer_actions = []
229
+
230
+ # Execute our code under test
231
+ if processed_response .functions :
232
+ # Reset agent's model_settings
233
+ if agent .model_settings .tool_choice == "required" or isinstance (agent .model_settings .tool_choice , str ):
234
+ agent .model_settings = ModelSettings (
235
+ temperature = agent .model_settings .temperature ,
236
+ top_p = agent .model_settings .top_p ,
237
+ frequency_penalty = agent .model_settings .frequency_penalty ,
238
+ presence_penalty = agent .model_settings .presence_penalty ,
239
+ tool_choice = "auto" , # Reset to auto
240
+ parallel_tool_calls = agent .model_settings .parallel_tool_calls ,
241
+ truncation = agent .model_settings .truncation ,
242
+ max_tokens = agent .model_settings .max_tokens ,
243
+ )
244
+
245
+ # Also reset run_config's model_settings if it exists
246
+ if run_config .model_settings and (run_config .model_settings .tool_choice == "required" or
247
+ isinstance (run_config .model_settings .tool_choice , str )):
248
+ run_config .model_settings = ModelSettings (
249
+ temperature = run_config .model_settings .temperature ,
250
+ top_p = run_config .model_settings .top_p ,
251
+ frequency_penalty = run_config .model_settings .frequency_penalty ,
252
+ presence_penalty = run_config .model_settings .presence_penalty ,
253
+ tool_choice = "auto" , # Reset to auto
254
+ parallel_tool_calls = run_config .model_settings .parallel_tool_calls ,
255
+ truncation = run_config .model_settings .truncation ,
256
+ max_tokens = run_config .model_settings .max_tokens ,
257
+ )
258
+
259
+ # Check that run_config's tool_choice was reset to "auto"
260
+ assert run_config .model_settings .tool_choice == "auto"
261
+
262
+ @mock .patch ("agents.run.Runner._get_model" )
263
+ async def test_integration_prevents_infinite_loop (self , mock_get_model ):
264
+ """Integration test to verify that tool_choice reset prevents infinite loops"""
265
+ # Create a counter to track model calls and detect potential infinite loops
266
+ tool_call_counter = {"count" : 0 , "potential_infinite_loop" : False }
267
+
268
+ # Set up our mock model that will always use tools when tool_choice is set
269
+ mock_model_instance = MockModel (tool_call_counter )
270
+ # Return our mock model directly
271
+ mock_get_model .return_value = mock_model_instance
272
+
273
+ # Create an agent with tool_choice="required" to force tool usage
274
+ agent = Agent (
275
+ name = "Test agent" ,
276
+ instructions = "You are a test agent" ,
277
+ tools = [echo ],
278
+ model_settings = ModelSettings (tool_choice = "required" ),
279
+ # Use "run_llm_again" to allow LLM to continue after tool calls
280
+ # This would cause infinite loops without the tool_choice reset
281
+ tool_use_behavior = "run_llm_again" ,
282
+ )
283
+
284
+ # Set a timeout to catch potential infinite loops that our fix doesn't address
285
+ try :
286
+ # Run the agent with a timeout
287
+ async def run_with_timeout ():
288
+ return await Runner .run (agent , input = "Test input" )
289
+
290
+ result = await asyncio .wait_for (run_with_timeout (), timeout = 2.0 )
291
+
292
+ # Verify the agent ran successfully
293
+ assert result is not None
294
+
295
+ # Verify the tool was called at least once but not too many times
296
+ # (indicating no infinite loop)
297
+ assert tool_call_counter ["count" ] >= 1
298
+ assert tool_call_counter ["count" ] < 5
299
+ assert not tool_call_counter ["potential_infinite_loop" ]
300
+
301
+ except asyncio .TimeoutError :
302
+ # If we hit a timeout, the test failed - we likely have an infinite loop
303
+ assert False , "Timeout occurred, potential infinite loop detected"
0 commit comments