120 changes: 120 additions & 0 deletions src/agents/extensions/models/litellm_model.py
@@ -23,6 +23,7 @@
ChatCompletionChunk,
ChatCompletionMessageCustomToolCall,
ChatCompletionMessageFunctionToolCall,
ChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_message import (
Annotation,
@@ -267,6 +268,10 @@ async def _fetch_response(
input, preserve_thinking_blocks=preserve_thinking_blocks
)

# Fix for interleaved thinking bug: reorder messages to ensure tool_use comes before tool_result # noqa: E501
if preserve_thinking_blocks:
converted_messages = self._fix_tool_message_ordering(converted_messages)

if system_instructions:
converted_messages.insert(
0,
@@ -379,6 +384,121 @@ async def _fetch_response(
)
return response, ret

def _fix_tool_message_ordering(
self, messages: list[ChatCompletionMessageParam]
) -> list[ChatCompletionMessageParam]:
"""
Fix the ordering of tool messages to ensure tool_use messages come before tool_result messages.

This addresses the interleaved thinking bug where conversation histories may contain
tool results before their corresponding tool calls, causing the Anthropic API to reject the request.
""" # noqa: E501
if not messages:
return messages

# Collect all tool calls and tool results
tool_call_messages = {} # tool_id -> (index, message)
tool_result_messages = {} # tool_id -> (index, message)
other_messages = [] # (index, message) for non-tool messages

for i, message in enumerate(messages):
if not isinstance(message, dict):
other_messages.append((i, message))
continue

role = message.get("role")

if role == "assistant" and message.get("tool_calls"):
# Extract tool calls from this assistant message
tool_calls = message.get("tool_calls", [])
if isinstance(tool_calls, list):
for tool_call in tool_calls:
if isinstance(tool_call, dict):
tool_id = tool_call.get("id")
if tool_id:
# Create a separate assistant message for each tool call
single_tool_msg = cast(dict[str, Any], message.copy())
single_tool_msg["tool_calls"] = [tool_call]
tool_call_messages[tool_id] = (
i,
cast(ChatCompletionMessageParam, single_tool_msg),
)

elif role == "tool":
tool_call_id = message.get("tool_call_id")
if tool_call_id:
tool_result_messages[tool_call_id] = (i, message)
else:
other_messages.append((i, message))
else:
other_messages.append((i, message))

# First, identify which tool results will be paired to avoid duplicates
paired_tool_result_indices = set()
for tool_id in tool_call_messages:
if tool_id in tool_result_messages:
tool_result_idx, _ = tool_result_messages[tool_id]
paired_tool_result_indices.add(tool_result_idx)

# Create the fixed message sequence
fixed_messages: list[ChatCompletionMessageParam] = []
used_indices = set()

# Add messages in their original order, but ensure tool_use → tool_result pairing
for i, original_message in enumerate(messages):
if i in used_indices:
continue

if not isinstance(original_message, dict):
fixed_messages.append(original_message)
used_indices.add(i)
continue

role = original_message.get("role")

if role == "assistant" and original_message.get("tool_calls"):
# Process each tool call in this assistant message
tool_calls = original_message.get("tool_calls", [])
if isinstance(tool_calls, list):
for tool_call in tool_calls:
if isinstance(tool_call, dict):
tool_id = tool_call.get("id")
if (
tool_id
and tool_id in tool_call_messages
and tool_id in tool_result_messages
):
# Add tool_use → tool_result pair
_, tool_call_msg = tool_call_messages[tool_id]
tool_result_idx, tool_result_msg = tool_result_messages[tool_id]

fixed_messages.append(tool_call_msg)
fixed_messages.append(tool_result_msg)

# Mark both as used
used_indices.add(tool_call_messages[tool_id][0])
used_indices.add(tool_result_idx)
elif tool_id and tool_id in tool_call_messages:
# Tool call without result - add just the tool call
_, tool_call_msg = tool_call_messages[tool_id]
fixed_messages.append(tool_call_msg)
used_indices.add(tool_call_messages[tool_id][0])

used_indices.add(i) # Mark original multi-tool message as used

elif role == "tool":
# Only preserve unmatched tool results to avoid duplicates
if i not in paired_tool_result_indices:
fixed_messages.append(original_message)
used_indices.add(i)

else:
# Regular message - add it normally
fixed_messages.append(original_message)
used_indices.add(i)

return fixed_messages

def _remove_not_given(self, value: Any) -> Any:
if isinstance(value, NotGiven):
return None
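For reference, a minimal sketch of the reordering this method is meant to guarantee. The message dicts below are illustrative stand-ins, not taken from the diff (identifiers like "call_1" and "get_weather" are invented); the real method operates on ChatCompletionMessageParam values:

# Hypothetical out-of-order history: the tool result appears before
# its corresponding tool call.
messages = [
    {"role": "user", "content": "What's the weather?"},
    {"role": "tool", "tool_call_id": "call_1", "content": "72F and sunny"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": "{}"},
            }
        ],
    },
]

# _fix_tool_message_ordering(messages) re-emits the pair as
# tool_use -> tool_result, the order the Anthropic API requires:
#   {"role": "user", ...}
#   {"role": "assistant", "tool_calls": [... "call_1" ...]}
#   {"role": "tool", "tool_call_id": "call_1", ...}

A multi-tool assistant message is split into one single-call assistant message per tool call, so each result can be placed immediately after its call.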
8 changes: 6 additions & 2 deletions src/agents/models/chatcmpl_converter.py
@@ -533,7 +533,7 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:

if content_items and preserve_thinking_blocks:
# Reconstruct thinking blocks from content and signature
pending_thinking_blocks = []
reconstructed_thinking_blocks = []
for content_item in content_items:
if (
isinstance(content_item, dict)
@@ -546,7 +546,11 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
# Add signatures if available
if signatures:
thinking_block["signature"] = signatures.pop(0)
pending_thinking_blocks.append(thinking_block)
reconstructed_thinking_blocks.append(thinking_block)

# Store thinking blocks as pending for the next assistant message
# This preserves the original behavior
pending_thinking_blocks = reconstructed_thinking_blocks

# 8) If we haven't recognized it => fail or ignore
else:
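For context, a hedged sketch of the thinking-block shape this hunk reconstructs, based on Anthropic's extended-thinking format (the field values here are invented, not from the diff):

# Illustrative thinking block as reconstructed by the converter; the
# "signature" key is attached only when a signature was captured upstream.
reconstructed_thinking_blocks = [
    {
        "type": "thinking",
        "thinking": "First check the tool output, then ...",  # invented text
        "signature": "sig_abc123",  # hypothetical signature value
    }
]

# Held as pending so the next assistant message can emit the thinking
# blocks ahead of its visible content, as the comment in the diff notes.
pending_thinking_blocks = reconstructed_thinking_blocks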