284 changes: 222 additions & 62 deletions ccproxy/llms/formatters/openai_to_anthropic/requests.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import json
from collections import Counter
from typing import Any

from ccproxy.core.constants import DEFAULT_MAX_TOKENS
@@ -15,83 +16,242 @@


def _sanitize_tool_results(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Remove orphaned tool_result blocks that don't have matching tool_use blocks.
"""Remove orphaned tool blocks that don't have matching counterparts.

The Anthropic API requires that each tool_result block must have a corresponding
tool_use block in the immediately preceding assistant message. This function removes
tool_result blocks that don't meet this requirement, converting them to text to
The Anthropic API requires:
1. Each tool_result block must have a corresponding tool_use in the preceding assistant message
2. Each tool_use block must have a corresponding tool_result in the next user message

This function removes orphaned blocks of both types, converting them to text to
preserve information.

Args:
messages: List of Anthropic format messages

Returns:
Sanitized messages with orphaned tool blocks removed or converted to text
"""
if not messages:
return messages

def _iter_content_blocks(content: Any) -> list[Any]:
if isinstance(content, list):
return content
if isinstance(content, dict):
return [content]
return []

def _collect_tool_use_counts(content: Any) -> Counter[str]:
counts: Counter[str] = Counter()
for block in _iter_content_blocks(content):
if isinstance(block, dict) and block.get("type") == "tool_use":
tool_id = block.get("id")
if tool_id:
counts[str(tool_id)] += 1
return counts

def _collect_tool_result_counts(content: Any) -> Counter[str]:
counts: Counter[str] = Counter()
for block in _iter_content_blocks(content):
if isinstance(block, dict) and block.get("type") == "tool_result":
tool_use_id = block.get("tool_use_id")
if tool_use_id:
counts[str(tool_use_id)] += 1
return counts

def _sanitize_once(
current_messages: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], bool]:
assistant_tool_use_counts: list[Counter[str]] = [
Counter() for _ in current_messages
]
user_tool_result_counts: list[Counter[str]] = [
Counter() for _ in current_messages
]

for i, msg in enumerate(current_messages):
role = msg.get("role")
content = msg.get("content")
if role == "assistant":
assistant_tool_use_counts[i] = _collect_tool_use_counts(content)
elif role == "user":
user_tool_result_counts[i] = _collect_tool_result_counts(content)

paired_counts_for_assistant: list[Counter[str]] = [
Counter() for _ in current_messages
]
paired_counts_for_user: list[Counter[str]] = [
Counter() for _ in current_messages
]
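# Pairing pass (below): for each assistant/user adjacency, keep only the tool ids
# that appear both as a tool_use in the assistant message and as a tool_result in
# the immediately following user message, matched per occurrence via min().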

for i, msg in enumerate(current_messages):
if msg.get("role") != "assistant":
continue
if (
i + 1 < len(current_messages)
and current_messages[i + 1].get("role") == "user"
):
paired: Counter[str] = Counter()
assistant_counts = assistant_tool_use_counts[i]
user_counts = user_tool_result_counts[i + 1]
for tool_id, tool_use_count in assistant_counts.items():
if tool_id in user_counts:
paired[tool_id] = min(tool_use_count, user_counts[tool_id])
paired_counts_for_assistant[i] = paired
paired_counts_for_user[i + 1] = paired

sanitized: list[dict[str, Any]] = []
changed = False
for i, msg in enumerate(current_messages):
role = msg.get("role")
content = msg.get("content")
content_blocks = _iter_content_blocks(content)
content_was_dict = isinstance(content, dict)

# Handle assistant messages with tool_use blocks
if role == "assistant" and content_blocks:
valid_tool_use_counts = paired_counts_for_assistant[i]
kept_tool_use_counts: Counter[str] = Counter()

new_content = []
orphaned_tool_uses = []

for block in content_blocks:
if isinstance(block, dict) and block.get("type") == "tool_use":
raw_tool_id = block.get("id")
tool_use_id = str(raw_tool_id) if raw_tool_id else None
# Only keep tool_use when the *next* user message provides the result
if tool_use_id and kept_tool_use_counts[
tool_use_id
] < valid_tool_use_counts.get(tool_use_id, 0):
kept_tool_use_counts[tool_use_id] += 1
new_content.append(block)
else:
orphaned_tool_uses.append(block)
changed = True
logger.warning(
"orphaned_tool_use_removed",
tool_use_id=tool_use_id,
tool_name=block.get("name"),
message_index=i,
category="message_sanitization",
)
else:
new_content.append(block)

# Convert orphaned tool_use blocks to text
if orphaned_tool_uses:
orphan_text = (
"[Tool calls from compacted history - results not available]\n"
)
for orphan in orphaned_tool_uses:
tool_name = orphan.get("name", "unknown")
tool_input = orphan.get("input", {})
input_str = str(tool_input)
if len(input_str) > 200:
input_str = input_str[:200] + "..."
orphan_text += f"- Called {tool_name}: {input_str}\n"

# Add text block at the beginning (or prepend to existing text)
if (
new_content
and isinstance(new_content[0], dict)
and new_content[0].get("type") == "text"
):
new_content[0] = {
**new_content[0],
"text": orphan_text + "\n" + new_content[0]["text"],
}
else:
new_content.insert(0, {"type": "text", "text": orphan_text})

if new_content:
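# A single dict block was normalized into a list, which counts as a change.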
if content_was_dict:
changed = True
sanitized.append({**msg, "content": new_content})
else:
# If no content left, add minimal text to avoid empty assistant message
changed = True
sanitized.append(
{**msg, "content": "[Previous response content compacted]"}
)
continue

# Handle user messages with tool_result blocks
elif role == "user" and content_blocks:
# Find tool_use_ids from the immediately preceding assistant message
valid_tool_use_counts = paired_counts_for_user[i]
kept_tool_result_counts: Counter[str] = Counter()

# Filter content blocks
new_content = []
orphaned_results = []
for block in content_blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
tool_use_id = block.get("tool_use_id")
tool_use_id = str(tool_use_id) if tool_use_id else None
if tool_use_id and kept_tool_result_counts[
tool_use_id
] < valid_tool_use_counts.get(tool_use_id, 0):
kept_tool_result_counts[tool_use_id] += 1
new_content.append(block)
else:
# Track orphaned tool_result for conversion to text
orphaned_results.append(block)
changed = True
logger.warning(
"orphaned_tool_result_removed",
tool_use_id=tool_use_id,
valid_ids=list(valid_tool_use_counts.keys()),
message_index=i,
category="message_sanitization",
)
else:
new_content.append(block)

# Convert orphaned results to text block to preserve information.
# Avoid injecting text into a valid tool_result reply message.
if orphaned_results and sum(kept_tool_result_counts.values()) == 0:
orphan_text = "[Previous tool results from compacted history]\n"
for orphan in orphaned_results:
result_content = orphan.get("content", "")
if isinstance(result_content, list):
text_parts = []
for c in result_content:
if isinstance(c, dict) and c.get("type") == "text":
text_parts.append(c.get("text", ""))
result_content = "\n".join(text_parts)
# Truncate long content
content_str = str(result_content)
if len(content_str) > 500:
content_str = content_str[:500] + "..."
orphan_text += (
f"- Tool {orphan.get('tool_use_id', 'unknown')}: "
f"{content_str}\n"
)

# Add as text block at the beginning
new_content.insert(0, {"type": "text", "text": orphan_text})

# Update message content (only if we have content left)
if new_content:
if content_was_dict:
changed = True
sanitized.append({**msg, "content": new_content})
else:
# If no content left, skip this message entirely
changed = True
continue
else:
sanitized.append(msg)

return sanitized, changed

sanitized = messages
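# A single pass can drop blocks (or whole user messages) and thereby orphan
# counterparts elsewhere, so run a second pass whenever the first one changed anything.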
for _ in range(2):
sanitized, changed = _sanitize_once(sanitized)
if not changed:
break

return sanitized
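A minimal usage sketch of the new behaviour (illustrative only: the message payloads and tool names below are invented, and _sanitize_tool_results is a module-private helper):

from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results

messages = [
    {"role": "assistant", "content": [
        {"type": "tool_use", "id": "toolu_1", "name": "get_weather", "input": {"city": "Paris"}},
        {"type": "tool_use", "id": "toolu_2", "name": "get_time", "input": {}},
    ]},
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "toolu_1", "content": "Sunny"},
        # toolu_9 has no matching tool_use above, so it is an orphan.
        {"type": "tool_result", "tool_use_id": "toolu_9", "content": "stale"},
    ]},
]

sanitized = _sanitize_tool_results(messages)
# toolu_2's tool_use is dropped from the assistant message (no result follows)
# and summarized as a text block; the toolu_1 pair survives; the orphaned toolu_9
# result is removed and, because a valid result remains, it is not converted to text.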

18 changes: 18 additions & 0 deletions ccproxy/llms/utils/__init__.py
@@ -0,0 +1,18 @@
"""LLM utility modules for token estimation and context management."""

from .context_truncation import truncate_to_fit
from .token_estimation import (
estimate_messages_tokens,
estimate_request_tokens,
estimate_tokens,
get_max_input_tokens,
)


__all__ = [
"estimate_tokens",
"estimate_messages_tokens",
"estimate_request_tokens",
"get_max_input_tokens",
"truncate_to_fit",
]
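A hedged usage sketch of the new exports; only the names come from this diff, and the call shapes in the comments are assumptions for illustration:

from ccproxy.llms.utils import (
    estimate_messages_tokens,
    estimate_tokens,
    get_max_input_tokens,
    truncate_to_fit,
)

# Assumed shapes (not shown in this diff):
#   estimate_tokens(text: str) -> int
#   estimate_messages_tokens(messages: list[dict]) -> int
#   get_max_input_tokens(model: str) -> int
#   truncate_to_fit(messages: list[dict], max_tokens: int) -> list[dict]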