Skip to content

Commit 78550bd

Browse files
author
smhanan
committed
fix: tool result sanitization and context truncation
1 parent 5174b1a commit 78550bd

File tree

17 files changed

+3016
-103
lines changed

17 files changed

+3016
-103
lines changed

ccproxy/llms/formatters/openai_to_anthropic/requests.py

Lines changed: 222 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import json
6+
from collections import Counter
67
from typing import Any
78

89
from ccproxy.core.constants import DEFAULT_MAX_TOKENS
@@ -15,83 +16,242 @@
1516

1617

1718
def _sanitize_tool_results(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
18-
"""Remove orphaned tool_result blocks that don't have matching tool_use blocks.
19+
"""Remove orphaned tool blocks that don't have matching counterparts.
1920
20-
The Anthropic API requires that each tool_result block must have a corresponding
21-
tool_use block in the immediately preceding assistant message. This function removes
22-
tool_result blocks that don't meet this requirement, converting them to text to
21+
The Anthropic API requires:
22+
1. Each tool_result block must have a corresponding tool_use in the preceding assistant message
23+
2. Each tool_use block must have a corresponding tool_result in the next user message
24+
25+
This function removes orphaned blocks of both types, converting them to text to
2326
preserve information.
2427
2528
Args:
2629
messages: List of Anthropic format messages
2730
2831
Returns:
29-
Sanitized messages with orphaned tool_results removed or converted to text
32+
Sanitized messages with orphaned tool blocks removed or converted to text
3033
"""
3134
if not messages:
3235
return messages
3336

34-
sanitized = []
35-
for i, msg in enumerate(messages):
36-
if msg.get("role") == "user" and isinstance(msg.get("content"), list):
37-
# Find tool_use_ids from the immediately preceding assistant message
38-
valid_tool_use_ids: set[str] = set()
39-
if i > 0 and messages[i - 1].get("role") == "assistant":
40-
prev_content = messages[i - 1].get("content", [])
41-
if isinstance(prev_content, list):
42-
for block in prev_content:
43-
if isinstance(block, dict) and block.get("type") == "tool_use":
44-
tool_id = block.get("id")
45-
if tool_id:
46-
valid_tool_use_ids.add(tool_id)
47-
48-
# Filter content blocks
49-
new_content = []
50-
orphaned_results = []
51-
for block in msg["content"]:
52-
if isinstance(block, dict) and block.get("type") == "tool_result":
53-
tool_use_id = block.get("tool_use_id")
54-
if tool_use_id in valid_tool_use_ids:
37+
def _iter_content_blocks(content: Any) -> list[Any]:
38+
if isinstance(content, list):
39+
return content
40+
if isinstance(content, dict):
41+
return [content]
42+
return []
43+
44+
def _collect_tool_use_counts(content: Any) -> Counter[str]:
45+
counts: Counter[str] = Counter()
46+
for block in _iter_content_blocks(content):
47+
if isinstance(block, dict) and block.get("type") == "tool_use":
48+
tool_id = block.get("id")
49+
if tool_id:
50+
counts[str(tool_id)] += 1
51+
return counts
52+
53+
def _collect_tool_result_counts(content: Any) -> Counter[str]:
54+
counts: Counter[str] = Counter()
55+
for block in _iter_content_blocks(content):
56+
if isinstance(block, dict) and block.get("type") == "tool_result":
57+
tool_use_id = block.get("tool_use_id")
58+
if tool_use_id:
59+
counts[str(tool_use_id)] += 1
60+
return counts
61+
62+
def _sanitize_once(
63+
current_messages: list[dict[str, Any]],
64+
) -> tuple[list[dict[str, Any]], bool]:
65+
assistant_tool_use_counts: list[Counter[str]] = [
66+
Counter() for _ in current_messages
67+
]
68+
user_tool_result_counts: list[Counter[str]] = [
69+
Counter() for _ in current_messages
70+
]
71+
72+
for i, msg in enumerate(current_messages):
73+
role = msg.get("role")
74+
content = msg.get("content")
75+
if role == "assistant":
76+
assistant_tool_use_counts[i] = _collect_tool_use_counts(content)
77+
elif role == "user":
78+
user_tool_result_counts[i] = _collect_tool_result_counts(content)
79+
80+
paired_counts_for_assistant: list[Counter[str]] = [
81+
Counter() for _ in current_messages
82+
]
83+
paired_counts_for_user: list[Counter[str]] = [
84+
Counter() for _ in current_messages
85+
]
86+
87+
for i, msg in enumerate(current_messages):
88+
if msg.get("role") != "assistant":
89+
continue
90+
if (
91+
i + 1 < len(current_messages)
92+
and current_messages[i + 1].get("role") == "user"
93+
):
94+
paired: Counter[str] = Counter()
95+
assistant_counts = assistant_tool_use_counts[i]
96+
user_counts = user_tool_result_counts[i + 1]
97+
for tool_id, tool_use_count in assistant_counts.items():
98+
if tool_id in user_counts:
99+
paired[tool_id] = min(tool_use_count, user_counts[tool_id])
100+
paired_counts_for_assistant[i] = paired
101+
paired_counts_for_user[i + 1] = paired
102+
103+
sanitized: list[dict[str, Any]] = []
104+
changed = False
105+
for i, msg in enumerate(current_messages):
106+
role = msg.get("role")
107+
content = msg.get("content")
108+
content_blocks = _iter_content_blocks(content)
109+
content_was_dict = isinstance(content, dict)
110+
111+
# Handle assistant messages with tool_use blocks
112+
if role == "assistant" and content_blocks:
113+
valid_tool_use_counts = paired_counts_for_assistant[i]
114+
kept_tool_use_counts: Counter[str] = Counter()
115+
116+
new_content = []
117+
orphaned_tool_uses = []
118+
119+
for block in content_blocks:
120+
if isinstance(block, dict) and block.get("type") == "tool_use":
121+
raw_tool_id = block.get("id")
122+
tool_use_id = str(raw_tool_id) if raw_tool_id else None
123+
# Only keep tool_use when the *next* user message provides the result
124+
if tool_use_id and kept_tool_use_counts[
125+
tool_use_id
126+
] < valid_tool_use_counts.get(tool_use_id, 0):
127+
kept_tool_use_counts[tool_use_id] += 1
128+
new_content.append(block)
129+
else:
130+
orphaned_tool_uses.append(block)
131+
changed = True
132+
logger.warning(
133+
"orphaned_tool_use_removed",
134+
tool_use_id=tool_use_id,
135+
tool_name=block.get("name"),
136+
message_index=i,
137+
category="message_sanitization",
138+
)
139+
else:
55140
new_content.append(block)
141+
142+
# Convert orphaned tool_use blocks to text
143+
if orphaned_tool_uses:
144+
orphan_text = (
145+
"[Tool calls from compacted history - results not available]\n"
146+
)
147+
for orphan in orphaned_tool_uses:
148+
tool_name = orphan.get("name", "unknown")
149+
tool_input = orphan.get("input", {})
150+
input_str = str(tool_input)
151+
if len(input_str) > 200:
152+
input_str = input_str[:200] + "..."
153+
orphan_text += f"- Called {tool_name}: {input_str}\n"
154+
155+
# Add text block at the beginning (or prepend to existing text)
156+
if (
157+
new_content
158+
and isinstance(new_content[0], dict)
159+
and new_content[0].get("type") == "text"
160+
):
161+
new_content[0] = {
162+
**new_content[0],
163+
"text": orphan_text + "\n" + new_content[0]["text"],
164+
}
165+
else:
166+
new_content.insert(0, {"type": "text", "text": orphan_text})
167+
168+
if new_content:
169+
if content_was_dict:
170+
changed = True
171+
sanitized.append({**msg, "content": new_content})
172+
else:
173+
# If no content left, add minimal text to avoid empty assistant message
174+
changed = True
175+
sanitized.append(
176+
{**msg, "content": "[Previous response content compacted]"}
177+
)
178+
continue
179+
180+
# Handle user messages with tool_result blocks
181+
elif role == "user" and content_blocks:
182+
# Find tool_use_ids from the immediately preceding assistant message
183+
valid_tool_use_counts = paired_counts_for_user[i]
184+
kept_tool_result_counts: Counter[str] = Counter()
185+
186+
# Filter content blocks
187+
new_content = []
188+
orphaned_results = []
189+
for block in content_blocks:
190+
if isinstance(block, dict) and block.get("type") == "tool_result":
191+
tool_use_id = block.get("tool_use_id")
192+
tool_use_id = str(tool_use_id) if tool_use_id else None
193+
if tool_use_id and kept_tool_result_counts[
194+
tool_use_id
195+
] < valid_tool_use_counts.get(tool_use_id, 0):
196+
kept_tool_result_counts[tool_use_id] += 1
197+
new_content.append(block)
198+
else:
199+
# Track orphaned tool_result for conversion to text
200+
orphaned_results.append(block)
201+
changed = True
202+
logger.warning(
203+
"orphaned_tool_result_removed",
204+
tool_use_id=tool_use_id,
205+
valid_ids=list(valid_tool_use_counts.keys()),
206+
message_index=i,
207+
category="message_sanitization",
208+
)
56209
else:
57-
# Track orphaned tool_result for conversion to text
58-
orphaned_results.append(block)
59-
logger.warning(
60-
"orphaned_tool_result_removed",
61-
tool_use_id=tool_use_id,
62-
valid_ids=list(valid_tool_use_ids),
63-
message_index=i,
64-
category="message_sanitization",
210+
new_content.append(block)
211+
212+
# Convert orphaned results to text block to preserve information.
213+
# Avoid injecting text into a valid tool_result reply message.
214+
if orphaned_results and sum(kept_tool_result_counts.values()) == 0:
215+
orphan_text = "[Previous tool results from compacted history]\n"
216+
for orphan in orphaned_results:
217+
result_content = orphan.get("content", "")
218+
if isinstance(result_content, list):
219+
text_parts = []
220+
for c in result_content:
221+
if isinstance(c, dict) and c.get("type") == "text":
222+
text_parts.append(c.get("text", ""))
223+
result_content = "\n".join(text_parts)
224+
# Truncate long content
225+
content_str = str(result_content)
226+
if len(content_str) > 500:
227+
content_str = content_str[:500] + "..."
228+
orphan_text += (
229+
f"- Tool {orphan.get('tool_use_id', 'unknown')}: "
230+
f"{content_str}\n"
65231
)
232+
233+
# Add as text block at the beginning
234+
new_content.insert(0, {"type": "text", "text": orphan_text})
235+
236+
# Update message content (only if we have content left)
237+
if new_content:
238+
if content_was_dict:
239+
changed = True
240+
sanitized.append({**msg, "content": new_content})
66241
else:
67-
new_content.append(block)
68-
69-
# Convert orphaned results to text block to preserve information
70-
if orphaned_results:
71-
orphan_text = "[Previous tool results from compacted history]\n"
72-
for orphan in orphaned_results:
73-
content = orphan.get("content", "")
74-
if isinstance(content, list):
75-
text_parts = []
76-
for c in content:
77-
if isinstance(c, dict) and c.get("type") == "text":
78-
text_parts.append(c.get("text", ""))
79-
content = "\n".join(text_parts)
80-
# Truncate long content
81-
content_str = str(content)
82-
if len(content_str) > 500:
83-
content_str = content_str[:500] + "..."
84-
orphan_text += f"- Tool {orphan.get('tool_use_id', 'unknown')}: {content_str}\n"
85-
86-
# Add as text block at the beginning
87-
new_content.insert(0, {"type": "text", "text": orphan_text})
88-
89-
# Update message content (only if we have content left)
90-
if new_content:
91-
sanitized.append({**msg, "content": new_content})
92-
# If no content left, skip this message entirely
93-
else:
94-
sanitized.append(msg)
242+
# If no content left, skip this message entirely
243+
changed = True
244+
continue
245+
else:
246+
sanitized.append(msg)
247+
248+
return sanitized, changed
249+
250+
sanitized = messages
251+
for _ in range(2):
252+
sanitized, changed = _sanitize_once(sanitized)
253+
if not changed:
254+
break
95255

96256
return sanitized
97257

ccproxy/llms/utils/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""LLM utility modules for token estimation and context management."""
2+
3+
from .context_truncation import truncate_to_fit
4+
from .token_estimation import (
5+
estimate_messages_tokens,
6+
estimate_request_tokens,
7+
estimate_tokens,
8+
get_max_input_tokens,
9+
)
10+
11+
12+
__all__ = [
13+
"estimate_tokens",
14+
"estimate_messages_tokens",
15+
"estimate_request_tokens",
16+
"get_max_input_tokens",
17+
"truncate_to_fit",
18+
]

0 commit comments

Comments
 (0)