Skip to content

Commit bb7a3a4

Browse files
committed
add test for extended thinking message reordering
1 parent 6ed8cef commit bb7a3a4

File tree

1 file changed

+293
-0
lines changed

1 file changed

+293
-0
lines changed
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
"""Tests for the extended thinking message order bug fix in LitellmModel."""
2+
3+
from __future__ import annotations
4+
5+
from openai.types.chat import ChatCompletionMessageParam
6+
7+
from agents.extensions.models.litellm_model import LitellmModel
8+
9+
10+
class TestExtendedThinkingMessageOrder:
11+
"""Test the _fix_tool_message_ordering method."""
12+
13+
def test_basic_reordering_tool_result_before_call(self):
14+
"""Test that a tool result appearing before its tool call gets reordered correctly."""
15+
messages: list[ChatCompletionMessageParam] = [
16+
{"role": "user", "content": "Hello"},
17+
{"role": "tool", "tool_call_id": "call_123", "content": "Result for call_123"},
18+
{
19+
"role": "assistant",
20+
"tool_calls": [
21+
{
22+
"id": "call_123",
23+
"type": "function",
24+
"function": {"name": "test", "arguments": "{}"},
25+
}
26+
],
27+
},
28+
{"role": "user", "content": "Thanks"},
29+
]
30+
31+
model = LitellmModel("test-model")
32+
result = model._fix_tool_message_ordering(messages)
33+
34+
# Should reorder to: user, assistant+tool_call, tool_result, user
35+
assert len(result) == 4
36+
assert result[0]["role"] == "user"
37+
assert result[1]["role"] == "assistant"
38+
assert result[1]["tool_calls"][0]["id"] == "call_123" # type: ignore
39+
assert result[2]["role"] == "tool"
40+
assert result[2]["tool_call_id"] == "call_123"
41+
assert result[3]["role"] == "user"
42+
43+
def test_consecutive_tool_calls_get_separated(self):
44+
"""Test that consecutive assistant messages with tool calls get properly paired with results.""" # noqa: E501
45+
messages: list[ChatCompletionMessageParam] = [
46+
{"role": "user", "content": "Hello"},
47+
{
48+
"role": "assistant",
49+
"tool_calls": [
50+
{
51+
"id": "call_1",
52+
"type": "function",
53+
"function": {"name": "test1", "arguments": "{}"},
54+
}
55+
],
56+
},
57+
{
58+
"role": "assistant",
59+
"tool_calls": [
60+
{
61+
"id": "call_2",
62+
"type": "function",
63+
"function": {"name": "test2", "arguments": "{}"},
64+
}
65+
],
66+
},
67+
{"role": "tool", "tool_call_id": "call_1", "content": "Result 1"},
68+
{"role": "tool", "tool_call_id": "call_2", "content": "Result 2"},
69+
]
70+
71+
model = LitellmModel("test-model")
72+
result = model._fix_tool_message_ordering(messages)
73+
74+
# Should pair each tool call with its result immediately
75+
assert len(result) == 5
76+
assert result[0]["role"] == "user"
77+
assert result[1]["role"] == "assistant"
78+
assert result[1]["tool_calls"][0]["id"] == "call_1" # type: ignore
79+
assert result[2]["role"] == "tool"
80+
assert result[2]["tool_call_id"] == "call_1"
81+
assert result[3]["role"] == "assistant"
82+
assert result[3]["tool_calls"][0]["id"] == "call_2" # type: ignore
83+
assert result[4]["role"] == "tool"
84+
assert result[4]["tool_call_id"] == "call_2"
85+
86+
def test_unmatched_tool_results_preserved(self):
87+
"""Test that tool results without matching tool calls are preserved."""
88+
messages: list[ChatCompletionMessageParam] = [
89+
{"role": "user", "content": "Hello"},
90+
{
91+
"role": "assistant",
92+
"tool_calls": [
93+
{
94+
"id": "call_1",
95+
"type": "function",
96+
"function": {"name": "test", "arguments": "{}"},
97+
}
98+
],
99+
},
100+
{"role": "tool", "tool_call_id": "call_1", "content": "Matched result"},
101+
{"role": "tool", "tool_call_id": "call_orphan", "content": "Orphaned result"},
102+
{"role": "user", "content": "End"},
103+
]
104+
105+
model = LitellmModel("test-model")
106+
result = model._fix_tool_message_ordering(messages)
107+
108+
# Should preserve the orphaned tool result
109+
assert len(result) == 5
110+
assert result[0]["role"] == "user"
111+
assert result[1]["role"] == "assistant"
112+
assert result[2]["role"] == "tool"
113+
assert result[2]["tool_call_id"] == "call_1"
114+
assert result[3]["role"] == "tool" # Orphaned result preserved
115+
assert result[3]["tool_call_id"] == "call_orphan"
116+
assert result[4]["role"] == "user"
117+
118+
def test_tool_calls_without_results_preserved(self):
119+
"""Test that tool calls without results are still included."""
120+
messages: list[ChatCompletionMessageParam] = [
121+
{"role": "user", "content": "Hello"},
122+
{
123+
"role": "assistant",
124+
"tool_calls": [
125+
{
126+
"id": "call_1",
127+
"type": "function",
128+
"function": {"name": "test", "arguments": "{}"},
129+
}
130+
],
131+
},
132+
{"role": "user", "content": "End"},
133+
]
134+
135+
model = LitellmModel("test-model")
136+
result = model._fix_tool_message_ordering(messages)
137+
138+
# Should preserve the tool call even without a result
139+
assert len(result) == 3
140+
assert result[0]["role"] == "user"
141+
assert result[1]["role"] == "assistant"
142+
assert result[1]["tool_calls"][0]["id"] == "call_1" # type: ignore
143+
assert result[2]["role"] == "user"
144+
145+
def test_correctly_ordered_messages_unchanged(self):
146+
"""Test that correctly ordered messages remain in the same order."""
147+
messages: list[ChatCompletionMessageParam] = [
148+
{"role": "user", "content": "Hello"},
149+
{
150+
"role": "assistant",
151+
"tool_calls": [
152+
{
153+
"id": "call_1",
154+
"type": "function",
155+
"function": {"name": "test", "arguments": "{}"},
156+
}
157+
],
158+
},
159+
{"role": "tool", "tool_call_id": "call_1", "content": "Result"},
160+
{"role": "assistant", "content": "Done"},
161+
]
162+
163+
model = LitellmModel("test-model")
164+
result = model._fix_tool_message_ordering(messages)
165+
166+
# Should remain exactly the same
167+
assert len(result) == 4
168+
assert result[0]["role"] == "user"
169+
assert result[1]["role"] == "assistant"
170+
assert result[1]["tool_calls"][0]["id"] == "call_1" # type: ignore
171+
assert result[2]["role"] == "tool"
172+
assert result[2]["tool_call_id"] == "call_1"
173+
assert result[3]["role"] == "assistant"
174+
175+
def test_multiple_tool_calls_single_message(self):
176+
"""Test assistant message with multiple tool calls gets split properly."""
177+
messages: list[ChatCompletionMessageParam] = [
178+
{"role": "user", "content": "Hello"},
179+
{
180+
"role": "assistant",
181+
"tool_calls": [
182+
{
183+
"id": "call_1",
184+
"type": "function",
185+
"function": {"name": "test1", "arguments": "{}"},
186+
},
187+
{
188+
"id": "call_2",
189+
"type": "function",
190+
"function": {"name": "test2", "arguments": "{}"},
191+
},
192+
],
193+
},
194+
{"role": "tool", "tool_call_id": "call_1", "content": "Result 1"},
195+
{"role": "tool", "tool_call_id": "call_2", "content": "Result 2"},
196+
]
197+
198+
model = LitellmModel("test-model")
199+
result = model._fix_tool_message_ordering(messages)
200+
201+
# Should split the multi-tool message and pair each properly
202+
assert len(result) == 5
203+
assert result[0]["role"] == "user"
204+
assert result[1]["role"] == "assistant"
205+
assert len(result[1]["tool_calls"]) == 1 # type: ignore
206+
assert result[1]["tool_calls"][0]["id"] == "call_1" # type: ignore
207+
assert result[2]["role"] == "tool"
208+
assert result[2]["tool_call_id"] == "call_1"
209+
assert result[3]["role"] == "assistant"
210+
assert len(result[3]["tool_calls"]) == 1 # type: ignore
211+
assert result[3]["tool_calls"][0]["id"] == "call_2" # type: ignore
212+
assert result[4]["role"] == "tool"
213+
assert result[4]["tool_call_id"] == "call_2"
214+
215+
def test_empty_messages_list(self):
216+
"""Test that empty message list is handled correctly."""
217+
messages: list[ChatCompletionMessageParam] = []
218+
219+
model = LitellmModel("test-model")
220+
result = model._fix_tool_message_ordering(messages)
221+
222+
assert result == []
223+
224+
def test_no_tool_messages(self):
225+
"""Test that messages without tool calls are left unchanged."""
226+
messages: list[ChatCompletionMessageParam] = [
227+
{"role": "user", "content": "Hello"},
228+
{"role": "assistant", "content": "Hi there"},
229+
{"role": "user", "content": "How are you?"},
230+
]
231+
232+
model = LitellmModel("test-model")
233+
result = model._fix_tool_message_ordering(messages)
234+
235+
assert result == messages
236+
237+
def test_complex_mixed_scenario(self):
238+
"""Test a complex scenario with various message types and orderings."""
239+
messages: list[ChatCompletionMessageParam] = [
240+
{"role": "user", "content": "Start"},
241+
{
242+
"role": "tool",
243+
"tool_call_id": "call_out_of_order",
244+
"content": "Out of order result",
245+
}, # This comes before its call
246+
{"role": "assistant", "content": "Regular response"},
247+
{
248+
"role": "assistant",
249+
"tool_calls": [
250+
{
251+
"id": "call_out_of_order",
252+
"type": "function",
253+
"function": {"name": "test", "arguments": "{}"},
254+
}
255+
],
256+
},
257+
{
258+
"role": "assistant",
259+
"tool_calls": [
260+
{
261+
"id": "call_normal",
262+
"type": "function",
263+
"function": {"name": "test2", "arguments": "{}"},
264+
}
265+
],
266+
},
267+
{"role": "tool", "tool_call_id": "call_normal", "content": "Normal result"},
268+
{
269+
"role": "tool",
270+
"tool_call_id": "call_orphan",
271+
"content": "Orphaned result",
272+
}, # No matching call
273+
{"role": "user", "content": "End"},
274+
]
275+
276+
model = LitellmModel("test-model")
277+
result = model._fix_tool_message_ordering(messages)
278+
279+
# Should reorder properly while preserving all messages
280+
assert len(result) == 8
281+
assert result[0]["role"] == "user" # Start
282+
assert result[1]["role"] == "assistant" # Regular response
283+
assert result[2]["role"] == "assistant" # call_out_of_order
284+
assert result[2]["tool_calls"][0]["id"] == "call_out_of_order" # type: ignore
285+
assert result[3]["role"] == "tool" # Out of order result (now properly paired)
286+
assert result[3]["tool_call_id"] == "call_out_of_order"
287+
assert result[4]["role"] == "assistant" # call_normal
288+
assert result[4]["tool_calls"][0]["id"] == "call_normal" # type: ignore
289+
assert result[5]["role"] == "tool" # Normal result
290+
assert result[5]["tool_call_id"] == "call_normal"
291+
assert result[6]["role"] == "tool" # Orphaned result (preserved)
292+
assert result[6]["tool_call_id"] == "call_orphan"
293+
assert result[7]["role"] == "user" # End

0 commit comments

Comments
 (0)