|  | 
| 23 | 23 |     ChatCompletionChunk, | 
| 24 | 24 |     ChatCompletionMessageCustomToolCall, | 
| 25 | 25 |     ChatCompletionMessageFunctionToolCall, | 
|  | 26 | +    ChatCompletionMessageParam, | 
| 26 | 27 | ) | 
| 27 | 28 | from openai.types.chat.chat_completion_message import ( | 
| 28 | 29 |     Annotation, | 
| @@ -267,6 +268,10 @@ async def _fetch_response( | 
| 267 | 268 |             input, preserve_thinking_blocks=preserve_thinking_blocks | 
| 268 | 269 |         ) | 
| 269 | 270 | 
 | 
|  | 271 | +        # Fix for interleaved thinking bug: reorder messages to ensure tool_use comes before tool_result  # noqa: E501 | 
|  | 272 | +        if preserve_thinking_blocks: | 
|  | 273 | +            converted_messages = self._fix_tool_message_ordering(converted_messages) | 
|  | 274 | + | 
| 270 | 275 |         if system_instructions: | 
| 271 | 276 |             converted_messages.insert( | 
| 272 | 277 |                 0, | 
| @@ -379,6 +384,121 @@ async def _fetch_response( | 
| 379 | 384 |         ) | 
| 380 | 385 |         return response, ret | 
| 381 | 386 | 
 | 
|  | 387 | +    def _fix_tool_message_ordering( | 
|  | 388 | +        self, messages: list[ChatCompletionMessageParam] | 
|  | 389 | +    ) -> list[ChatCompletionMessageParam]: | 
|  | 390 | +        """ | 
|  | 391 | +        Fix the ordering of tool messages to ensure tool_use messages come before tool_result messages. | 
|  | 392 | +
 | 
|  | 393 | +        This addresses the interleaved thinking bug where conversation histories may contain | 
|  | 394 | +        tool results before their corresponding tool calls, causing Anthropic API to reject the request. | 
|  | 395 | +        """  # noqa: E501 | 
|  | 396 | +        if not messages: | 
|  | 397 | +            return messages | 
|  | 398 | + | 
|  | 399 | +        # Collect all tool calls and tool results | 
|  | 400 | +        tool_call_messages = {}  # tool_id -> (index, message) | 
|  | 401 | +        tool_result_messages = {}  # tool_id -> (index, message) | 
|  | 402 | +        other_messages = []  # (index, message) for non-tool messages | 
|  | 403 | + | 
|  | 404 | +        for i, message in enumerate(messages): | 
|  | 405 | +            if not isinstance(message, dict): | 
|  | 406 | +                other_messages.append((i, message)) | 
|  | 407 | +                continue | 
|  | 408 | + | 
|  | 409 | +            role = message.get("role") | 
|  | 410 | + | 
|  | 411 | +            if role == "assistant" and message.get("tool_calls"): | 
|  | 412 | +                # Extract tool calls from this assistant message | 
|  | 413 | +                tool_calls = message.get("tool_calls", []) | 
|  | 414 | +                if isinstance(tool_calls, list): | 
|  | 415 | +                    for tool_call in tool_calls: | 
|  | 416 | +                        if isinstance(tool_call, dict): | 
|  | 417 | +                            tool_id = tool_call.get("id") | 
|  | 418 | +                            if tool_id: | 
|  | 419 | +                                # Create a separate assistant message for each tool call | 
|  | 420 | +                                single_tool_msg = cast(dict[str, Any], message.copy()) | 
|  | 421 | +                                single_tool_msg["tool_calls"] = [tool_call] | 
|  | 422 | +                                tool_call_messages[tool_id] = ( | 
|  | 423 | +                                    i, | 
|  | 424 | +                                    cast(ChatCompletionMessageParam, single_tool_msg), | 
|  | 425 | +                                ) | 
|  | 426 | + | 
|  | 427 | +            elif role == "tool": | 
|  | 428 | +                tool_call_id = message.get("tool_call_id") | 
|  | 429 | +                if tool_call_id: | 
|  | 430 | +                    tool_result_messages[tool_call_id] = (i, message) | 
|  | 431 | +                else: | 
|  | 432 | +                    other_messages.append((i, message)) | 
|  | 433 | +            else: | 
|  | 434 | +                other_messages.append((i, message)) | 
|  | 435 | + | 
|  | 436 | +        # First, identify which tool results will be paired to avoid duplicates | 
|  | 437 | +        paired_tool_result_indices = set() | 
|  | 438 | +        for tool_id in tool_call_messages: | 
|  | 439 | +            if tool_id in tool_result_messages: | 
|  | 440 | +                tool_result_idx, _ = tool_result_messages[tool_id] | 
|  | 441 | +                paired_tool_result_indices.add(tool_result_idx) | 
|  | 442 | + | 
|  | 443 | +        # Create the fixed message sequence | 
|  | 444 | +        fixed_messages: list[ChatCompletionMessageParam] = [] | 
|  | 445 | +        used_indices = set() | 
|  | 446 | + | 
|  | 447 | +        # Add messages in their original order, but ensure tool_use → tool_result pairing | 
|  | 448 | +        for i, original_message in enumerate(messages): | 
|  | 449 | +            if i in used_indices: | 
|  | 450 | +                continue | 
|  | 451 | + | 
|  | 452 | +            if not isinstance(original_message, dict): | 
|  | 453 | +                fixed_messages.append(original_message) | 
|  | 454 | +                used_indices.add(i) | 
|  | 455 | +                continue | 
|  | 456 | + | 
|  | 457 | +            role = original_message.get("role") | 
|  | 458 | + | 
|  | 459 | +            if role == "assistant" and original_message.get("tool_calls"): | 
|  | 460 | +                # Process each tool call in this assistant message | 
|  | 461 | +                tool_calls = original_message.get("tool_calls", []) | 
|  | 462 | +                if isinstance(tool_calls, list): | 
|  | 463 | +                    for tool_call in tool_calls: | 
|  | 464 | +                        if isinstance(tool_call, dict): | 
|  | 465 | +                            tool_id = tool_call.get("id") | 
|  | 466 | +                            if ( | 
|  | 467 | +                                tool_id | 
|  | 468 | +                                and tool_id in tool_call_messages | 
|  | 469 | +                                and tool_id in tool_result_messages | 
|  | 470 | +                            ): | 
|  | 471 | +                                # Add tool_use → tool_result pair | 
|  | 472 | +                                _, tool_call_msg = tool_call_messages[tool_id] | 
|  | 473 | +                                tool_result_idx, tool_result_msg = tool_result_messages[tool_id] | 
|  | 474 | + | 
|  | 475 | +                                fixed_messages.append(tool_call_msg) | 
|  | 476 | +                                fixed_messages.append(tool_result_msg) | 
|  | 477 | + | 
|  | 478 | +                                # Mark both as used | 
|  | 479 | +                                used_indices.add(tool_call_messages[tool_id][0]) | 
|  | 480 | +                                used_indices.add(tool_result_idx) | 
|  | 481 | +                            elif tool_id and tool_id in tool_call_messages: | 
|  | 482 | +                                # Tool call without result - add just the tool call | 
|  | 483 | +                                _, tool_call_msg = tool_call_messages[tool_id] | 
|  | 484 | +                                fixed_messages.append(tool_call_msg) | 
|  | 485 | +                                used_indices.add(tool_call_messages[tool_id][0]) | 
|  | 486 | + | 
|  | 487 | +                used_indices.add(i)  # Mark original multi-tool message as used | 
|  | 488 | + | 
|  | 489 | +            elif role == "tool": | 
|  | 490 | +                # Only preserve unmatched tool results to avoid duplicates | 
|  | 491 | +                if i not in paired_tool_result_indices: | 
|  | 492 | +                    fixed_messages.append(original_message) | 
|  | 493 | +                used_indices.add(i) | 
|  | 494 | + | 
|  | 495 | +            else: | 
|  | 496 | +                # Regular message - add it normally | 
|  | 497 | +                fixed_messages.append(original_message) | 
|  | 498 | +                used_indices.add(i) | 
|  | 499 | + | 
|  | 500 | +        return fixed_messages | 
|  | 501 | + | 
| 382 | 502 |     def _remove_not_given(self, value: Any) -> Any: | 
| 383 | 503 |         if isinstance(value, NotGiven): | 
| 384 | 504 |             return None | 
|  | 
0 commit comments