diff --git a/ccproxy/llms/formatters/openai_to_anthropic/requests.py b/ccproxy/llms/formatters/openai_to_anthropic/requests.py index 5540289c..0129e126 100644 --- a/ccproxy/llms/formatters/openai_to_anthropic/requests.py +++ b/ccproxy/llms/formatters/openai_to_anthropic/requests.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from collections import Counter from typing import Any from ccproxy.core.constants import DEFAULT_MAX_TOKENS @@ -15,83 +16,242 @@ def _sanitize_tool_results(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Remove orphaned tool_result blocks that don't have matching tool_use blocks. + """Remove orphaned tool blocks that don't have matching counterparts. - The Anthropic API requires that each tool_result block must have a corresponding - tool_use block in the immediately preceding assistant message. This function removes - tool_result blocks that don't meet this requirement, converting them to text to + The Anthropic API requires: + 1. Each tool_result block must have a corresponding tool_use in the preceding assistant message + 2. Each tool_use block must have a corresponding tool_result in the next user message + + This function removes orphaned blocks of both types, converting them to text to preserve information. Args: messages: List of Anthropic format messages Returns: - Sanitized messages with orphaned tool_results removed or converted to text + Sanitized messages with orphaned tool blocks removed or converted to text """ if not messages: return messages - sanitized = [] - for i, msg in enumerate(messages): - if msg.get("role") == "user" and isinstance(msg.get("content"), list): - # Find tool_use_ids from the immediately preceding assistant message - valid_tool_use_ids: set[str] = set() - if i > 0 and messages[i - 1].get("role") == "assistant": - prev_content = messages[i - 1].get("content", []) - if isinstance(prev_content, list): - for block in prev_content: - if isinstance(block, dict) and block.get("type") == "tool_use": - tool_id = block.get("id") - if tool_id: - valid_tool_use_ids.add(tool_id) - - # Filter content blocks - new_content = [] - orphaned_results = [] - for block in msg["content"]: - if isinstance(block, dict) and block.get("type") == "tool_result": - tool_use_id = block.get("tool_use_id") - if tool_use_id in valid_tool_use_ids: + def _iter_content_blocks(content: Any) -> list[Any]: + if isinstance(content, list): + return content + if isinstance(content, dict): + return [content] + return [] + + def _collect_tool_use_counts(content: Any) -> Counter[str]: + counts: Counter[str] = Counter() + for block in _iter_content_blocks(content): + if isinstance(block, dict) and block.get("type") == "tool_use": + tool_id = block.get("id") + if tool_id: + counts[str(tool_id)] += 1 + return counts + + def _collect_tool_result_counts(content: Any) -> Counter[str]: + counts: Counter[str] = Counter() + for block in _iter_content_blocks(content): + if isinstance(block, dict) and block.get("type") == "tool_result": + tool_use_id = block.get("tool_use_id") + if tool_use_id: + counts[str(tool_use_id)] += 1 + return counts + + def _sanitize_once( + current_messages: list[dict[str, Any]], + ) -> tuple[list[dict[str, Any]], bool]: + assistant_tool_use_counts: list[Counter[str]] = [ + Counter() for _ in current_messages + ] + user_tool_result_counts: list[Counter[str]] = [ + Counter() for _ in current_messages + ] + + for i, msg in enumerate(current_messages): + role = msg.get("role") + content = msg.get("content") + if role == 
"assistant": + assistant_tool_use_counts[i] = _collect_tool_use_counts(content) + elif role == "user": + user_tool_result_counts[i] = _collect_tool_result_counts(content) + + paired_counts_for_assistant: list[Counter[str]] = [ + Counter() for _ in current_messages + ] + paired_counts_for_user: list[Counter[str]] = [ + Counter() for _ in current_messages + ] + + for i, msg in enumerate(current_messages): + if msg.get("role") != "assistant": + continue + if ( + i + 1 < len(current_messages) + and current_messages[i + 1].get("role") == "user" + ): + paired: Counter[str] = Counter() + assistant_counts = assistant_tool_use_counts[i] + user_counts = user_tool_result_counts[i + 1] + for tool_id, tool_use_count in assistant_counts.items(): + if tool_id in user_counts: + paired[tool_id] = min(tool_use_count, user_counts[tool_id]) + paired_counts_for_assistant[i] = paired + paired_counts_for_user[i + 1] = paired + + sanitized: list[dict[str, Any]] = [] + changed = False + for i, msg in enumerate(current_messages): + role = msg.get("role") + content = msg.get("content") + content_blocks = _iter_content_blocks(content) + content_was_dict = isinstance(content, dict) + + # Handle assistant messages with tool_use blocks + if role == "assistant" and content_blocks: + valid_tool_use_counts = paired_counts_for_assistant[i] + kept_tool_use_counts: Counter[str] = Counter() + + new_content = [] + orphaned_tool_uses = [] + + for block in content_blocks: + if isinstance(block, dict) and block.get("type") == "tool_use": + raw_tool_id = block.get("id") + tool_use_id = str(raw_tool_id) if raw_tool_id else None + # Only keep tool_use when the *next* user message provides the result + if tool_use_id and kept_tool_use_counts[ + tool_use_id + ] < valid_tool_use_counts.get(tool_use_id, 0): + kept_tool_use_counts[tool_use_id] += 1 + new_content.append(block) + else: + orphaned_tool_uses.append(block) + changed = True + logger.warning( + "orphaned_tool_use_removed", + tool_use_id=tool_use_id, + tool_name=block.get("name"), + message_index=i, + category="message_sanitization", + ) + else: new_content.append(block) + + # Convert orphaned tool_use blocks to text + if orphaned_tool_uses: + orphan_text = ( + "[Tool calls from compacted history - results not available]\n" + ) + for orphan in orphaned_tool_uses: + tool_name = orphan.get("name", "unknown") + tool_input = orphan.get("input", {}) + input_str = str(tool_input) + if len(input_str) > 200: + input_str = input_str[:200] + "..." 
+ orphan_text += f"- Called {tool_name}: {input_str}\n" + + # Add text block at the beginning (or prepend to existing text) + if ( + new_content + and isinstance(new_content[0], dict) + and new_content[0].get("type") == "text" + ): + new_content[0] = { + **new_content[0], + "text": orphan_text + "\n" + new_content[0]["text"], + } + else: + new_content.insert(0, {"type": "text", "text": orphan_text}) + + if new_content: + if content_was_dict: + changed = True + sanitized.append({**msg, "content": new_content}) + else: + # If no content left, add minimal text to avoid empty assistant message + changed = True + sanitized.append( + {**msg, "content": "[Previous response content compacted]"} + ) + continue + + # Handle user messages with tool_result blocks + elif role == "user" and content_blocks: + # Find tool_use_ids from the immediately preceding assistant message + valid_tool_use_counts = paired_counts_for_user[i] + kept_tool_result_counts: Counter[str] = Counter() + + # Filter content blocks + new_content = [] + orphaned_results = [] + for block in content_blocks: + if isinstance(block, dict) and block.get("type") == "tool_result": + tool_use_id = block.get("tool_use_id") + tool_use_id = str(tool_use_id) if tool_use_id else None + if tool_use_id and kept_tool_result_counts[ + tool_use_id + ] < valid_tool_use_counts.get(tool_use_id, 0): + kept_tool_result_counts[tool_use_id] += 1 + new_content.append(block) + else: + # Track orphaned tool_result for conversion to text + orphaned_results.append(block) + changed = True + logger.warning( + "orphaned_tool_result_removed", + tool_use_id=tool_use_id, + valid_ids=list(valid_tool_use_counts.keys()), + message_index=i, + category="message_sanitization", + ) else: - # Track orphaned tool_result for conversion to text - orphaned_results.append(block) - logger.warning( - "orphaned_tool_result_removed", - tool_use_id=tool_use_id, - valid_ids=list(valid_tool_use_ids), - message_index=i, - category="message_sanitization", + new_content.append(block) + + # Convert orphaned results to text block to preserve information. + # Avoid injecting text into a valid tool_result reply message. + if orphaned_results and sum(kept_tool_result_counts.values()) == 0: + orphan_text = "[Previous tool results from compacted history]\n" + for orphan in orphaned_results: + result_content = orphan.get("content", "") + if isinstance(result_content, list): + text_parts = [] + for c in result_content: + if isinstance(c, dict) and c.get("type") == "text": + text_parts.append(c.get("text", "")) + result_content = "\n".join(text_parts) + # Truncate long content + content_str = str(result_content) + if len(content_str) > 500: + content_str = content_str[:500] + "..." 
+ orphan_text += ( + f"- Tool {orphan.get('tool_use_id', 'unknown')}: " + f"{content_str}\n" ) + + # Add as text block at the beginning + new_content.insert(0, {"type": "text", "text": orphan_text}) + + # Update message content (only if we have content left) + if new_content: + if content_was_dict: + changed = True + sanitized.append({**msg, "content": new_content}) else: - new_content.append(block) - - # Convert orphaned results to text block to preserve information - if orphaned_results: - orphan_text = "[Previous tool results from compacted history]\n" - for orphan in orphaned_results: - content = orphan.get("content", "") - if isinstance(content, list): - text_parts = [] - for c in content: - if isinstance(c, dict) and c.get("type") == "text": - text_parts.append(c.get("text", "")) - content = "\n".join(text_parts) - # Truncate long content - content_str = str(content) - if len(content_str) > 500: - content_str = content_str[:500] + "..." - orphan_text += f"- Tool {orphan.get('tool_use_id', 'unknown')}: {content_str}\n" - - # Add as text block at the beginning - new_content.insert(0, {"type": "text", "text": orphan_text}) - - # Update message content (only if we have content left) - if new_content: - sanitized.append({**msg, "content": new_content}) - # If no content left, skip this message entirely - else: - sanitized.append(msg) + # If no content left, skip this message entirely + changed = True + continue + else: + sanitized.append(msg) + + return sanitized, changed + + sanitized = messages + for _ in range(2): + sanitized, changed = _sanitize_once(sanitized) + if not changed: + break return sanitized diff --git a/ccproxy/llms/utils/__init__.py b/ccproxy/llms/utils/__init__.py new file mode 100644 index 00000000..4c2debbb --- /dev/null +++ b/ccproxy/llms/utils/__init__.py @@ -0,0 +1,18 @@ +"""LLM utility modules for token estimation and context management.""" + +from .context_truncation import truncate_to_fit +from .token_estimation import ( + estimate_messages_tokens, + estimate_request_tokens, + estimate_tokens, + get_max_input_tokens, +) + + +__all__ = [ + "estimate_tokens", + "estimate_messages_tokens", + "estimate_request_tokens", + "get_max_input_tokens", + "truncate_to_fit", +] diff --git a/ccproxy/llms/utils/context_truncation.py b/ccproxy/llms/utils/context_truncation.py new file mode 100644 index 00000000..43e0afec --- /dev/null +++ b/ccproxy/llms/utils/context_truncation.py @@ -0,0 +1,210 @@ +"""Context window truncation utilities.""" + +import copy +from typing import Any + +from ccproxy.core.logging import get_logger + +from .token_estimation import estimate_request_tokens + + +logger = get_logger(__name__) + + +# Maximum characters to keep for truncated content blocks +MAX_TRUNCATED_CONTENT_CHARS = 10000 + + +def _truncate_large_content_blocks( + messages: list[dict[str, Any]], + max_chars: int = MAX_TRUNCATED_CONTENT_CHARS, +) -> tuple[list[dict[str, Any]], int]: + """Truncate large content blocks within messages. + + This is a fallback when message-level truncation isn't enough. + Targets large tool_result blocks and text content. 
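+    As an illustrative example of the notice format produced below (assuming
+    the default MAX_TRUNCATED_CONTENT_CHARS of 10000): a tool_result whose
+    content is 50,000 characters is cut to its first 10,000 characters followed
+    by "[Tool result truncated - 40000 characters removed]".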
+ + Args: + messages: List of messages to process + max_chars: Maximum characters to keep per content block + + Returns: + Tuple of (modified_messages, blocks_truncated_count) + """ + truncated_count = 0 + modified_messages = [] + + for msg in messages: + msg_copy = copy.deepcopy(msg) + content = msg_copy.get("content") + + if isinstance(content, str) and len(content) > max_chars: + # Truncate large string content + msg_copy["content"] = ( + content[:max_chars] + + f"\n\n[Content truncated - {len(content) - max_chars} characters removed]" + ) + truncated_count += 1 + + elif isinstance(content, list): + # Process content blocks + new_content = [] + for block in content: + if isinstance(block, dict): + block_copy = copy.deepcopy(block) + + # Handle tool_result blocks with large content + if block_copy.get("type") == "tool_result": + tool_content = block_copy.get("content", "") + if ( + isinstance(tool_content, str) + and len(tool_content) > max_chars + ): + block_copy["content"] = ( + tool_content[:max_chars] + + f"\n\n[Tool result truncated - {len(tool_content) - max_chars} characters removed]" + ) + truncated_count += 1 + + # Handle text blocks with large content + elif block_copy.get("type") == "text": + text = block_copy.get("text", "") + if len(text) > max_chars: + block_copy["text"] = ( + text[:max_chars] + + f"\n\n[Text truncated - {len(text) - max_chars} characters removed]" + ) + truncated_count += 1 + + new_content.append(block_copy) + else: + new_content.append(block) + + msg_copy["content"] = new_content + + modified_messages.append(msg_copy) + + return modified_messages, truncated_count + + +def truncate_to_fit( + request_data: dict[str, Any], + max_input_tokens: int, + preserve_recent: int = 10, + safety_margin: float = 0.9, +) -> tuple[dict[str, Any], bool]: + """Truncate request to fit within token limit. + + Strategy: + 1. Always preserve system prompt and tools + 2. Try to preserve the last N messages (preserve_recent) + 3. Remove oldest messages first + 4. If too few messages to truncate, reduce preserve_recent dynamically + 5. As a last resort, truncate large content blocks within messages + 6. 
Add a truncation notice when content is removed + + Args: + request_data: The request payload + max_input_tokens: Model's max input token limit + preserve_recent: Number of recent messages to always keep + safety_margin: Target this fraction of max to allow for estimation error + + Returns: + Tuple of (modified_request_data, was_truncated) + """ + target_tokens = int(max_input_tokens * safety_margin) + + current_tokens = estimate_request_tokens(request_data) + if current_tokens <= target_tokens: + return request_data, False + + # Work on a copy + modified = copy.deepcopy(request_data) + messages = modified.get("messages", []) + + # If we have fewer messages than preserve_recent, reduce preserve_recent + # We need at least 1 message to be truncatable for this strategy to work + effective_preserve = min(preserve_recent, len(messages) - 1) + + # If we have 0 or 1 messages, we can't do message-level truncation + # Skip to content-level truncation + if effective_preserve < 0: + effective_preserve = 0 + + # Split into truncatable and preserved messages + if effective_preserve > 0: + truncatable = messages[:-effective_preserve] + preserved = messages[-effective_preserve:] + else: + truncatable = list(messages) + preserved = [] + + # Remove oldest messages until we're under the limit + removed_count = 0 + while truncatable and estimate_request_tokens(modified) > target_tokens: + truncatable.pop(0) + removed_count += 1 + modified["messages"] = truncatable + preserved + + # Check if we're still over the limit after removing all truncatable messages + if estimate_request_tokens(modified) > target_tokens: + logger.info( + "context_truncation_message_level_insufficient", + reason="still_over_limit_after_message_truncation", + message_count=len(modified.get("messages", [])), + current_tokens=estimate_request_tokens(modified), + target_tokens=target_tokens, + category="context_management", + ) + + # Fallback: truncate large content blocks within remaining messages + truncated_messages, blocks_truncated = _truncate_large_content_blocks( + modified.get("messages", []) + ) + modified["messages"] = truncated_messages + + if blocks_truncated > 0: + logger.info( + "context_truncation_content_level", + blocks_truncated=blocks_truncated, + current_tokens=estimate_request_tokens(modified), + target_tokens=target_tokens, + category="context_management", + ) + + # If still over limit after content truncation, log error + final_tokens = estimate_request_tokens(modified) + if final_tokens > target_tokens: + logger.error( + "context_truncation_failed", + reason="still_over_limit_after_all_truncation", + final_tokens=final_tokens, + target_tokens=target_tokens, + messages_removed=removed_count, + blocks_truncated=blocks_truncated, + category="context_management", + ) + # Still return the truncated version - it's better than nothing + # The API will return an error, but at least we tried + + # Add truncation notice as first user message if we removed content + if removed_count > 0: + notice = { + "role": "user", + "content": f"[Context truncated - {removed_count} earlier messages removed to fit context window]", + } + modified["messages"] = [notice] + modified["messages"] + + final_tokens = estimate_request_tokens(modified) + + logger.info( + "context_truncated", + original_tokens=current_tokens, + final_tokens=final_tokens, + messages_removed=removed_count, + target_tokens=target_tokens, + effective_preserve_recent=effective_preserve, + category="context_management", + ) + + return modified, True diff --git 
a/ccproxy/llms/utils/token_estimation.py b/ccproxy/llms/utils/token_estimation.py new file mode 100644 index 00000000..1c3abfdb --- /dev/null +++ b/ccproxy/llms/utils/token_estimation.py @@ -0,0 +1,195 @@ +"""Token estimation utilities for context window management.""" + +import json +from pathlib import Path +from typing import Any + + +# Cache for loaded token limits +_token_limits_cache: dict[str, int] | None = None + + +def estimate_tokens(content: Any) -> int: + """Estimate token count for content. + + Uses ~3 characters per token heuristic for English text. + This is a conservative estimate - actual may be lower. + + Args: + content: Message content (string, list of blocks, or dict) + + Returns: + Estimated token count + """ + if content is None: + return 0 + + if isinstance(content, str): + # ~3 chars per token for English, be conservative + return max(1, len(content) // 3) + + if isinstance(content, list): + total = 0 + for block in content: + if isinstance(block, dict): + block_type = block.get("type", "") + if block_type == "text": + total += estimate_tokens(block.get("text", "")) + elif block_type == "tool_use": + # Tool name + input + total += estimate_tokens(block.get("name", "")) + total += estimate_tokens(json.dumps(block.get("input", {}))) + elif block_type == "tool_result": + total += estimate_tokens(block.get("content", "")) + elif block_type == "image": + # Images are ~1600 tokens for typical size + total += 1600 + else: + # Generic block - serialize and estimate + total += estimate_tokens(json.dumps(block)) + else: + total += estimate_tokens(block) + return total + + if isinstance(content, dict): + return estimate_tokens(json.dumps(content)) + + return estimate_tokens(str(content)) + + +def estimate_messages_tokens(messages: list[dict[str, Any]]) -> int: + """Estimate total tokens for a list of messages. + + Args: + messages: List of message dicts with role and content + + Returns: + Estimated total token count + """ + total = 0 + for msg in messages: + # Role contributes ~2 tokens + total += 2 + total += estimate_tokens(msg.get("content")) + return total + + +def estimate_request_tokens(request_data: dict[str, Any]) -> int: + """Estimate total input tokens for a request. + + Includes messages, system prompt, and tool definitions. + + Args: + request_data: The request payload dictionary + + Returns: + Estimated total input token count + """ + total = 0 + + # Messages + messages = request_data.get("messages", []) + total += estimate_messages_tokens(messages) + + # System prompt + system = request_data.get("system") + if system: + total += estimate_tokens(system) + + # Tools + tools = request_data.get("tools", []) + if tools: + total += estimate_tokens(json.dumps(tools)) + + return total + + +def _load_token_limits() -> dict[str, int]: + """Load token limits from available sources. + + Loads from: + 1. Local token_limits.json in max_tokens plugin + 2. 
Pricing cache at ~/.cache/ccproxy/model_pricing.json + + Returns: + Dict mapping model names to max_input_tokens + """ + global _token_limits_cache + if _token_limits_cache is not None: + return _token_limits_cache + + _token_limits_cache = {} + + # Try local token_limits.json first + local_file = ( + Path(__file__).parent.parent.parent + / "plugins" + / "max_tokens" + / "token_limits.json" + ) + if local_file.exists(): + try: + with local_file.open("r", encoding="utf-8") as f: + data = json.load(f) + for model_name, model_data in data.items(): + if model_name.startswith("_"): + continue + if isinstance(model_data, dict): + max_input = model_data.get("max_input_tokens") + if isinstance(max_input, int): + _token_limits_cache[model_name] = max_input + except Exception: + pass # Fall through to pricing cache + + # Also try pricing cache for additional models + pricing_cache = Path.home() / ".cache" / "ccproxy" / "model_pricing.json" + if pricing_cache.exists(): + try: + with pricing_cache.open("r", encoding="utf-8") as f: + data = json.load(f) + for model_name, model_data in data.items(): + if model_name in _token_limits_cache: + continue # Local file takes precedence + if isinstance(model_data, dict): + max_input = model_data.get("max_input_tokens") + if isinstance(max_input, int): + _token_limits_cache[model_name] = max_input + except Exception: + pass + + return _token_limits_cache + + +def get_max_input_tokens(model: str) -> int | None: + """Get max input tokens for a model. + + Supports pattern matching for model variants: + - Exact match: "claude-opus-4-5-20251101" + - Prefix match: "claude-opus-4-5-*" matches "claude-opus-4-5-20251101" + + Args: + model: Model name or identifier + + Returns: + Max input tokens if known, None otherwise + """ + limits = _load_token_limits() + + # Try exact match first + if model in limits: + return limits[model] + + # Try prefix matching (for patterns like claude-opus-4-5-*) + for pattern, max_tokens in limits.items(): + if pattern.endswith("*"): + prefix = pattern[:-1] + if model.startswith(prefix): + return max_tokens + + # Try matching known model families + model_lower = model.lower() + for known_model, max_tokens in limits.items(): + if known_model.lower() in model_lower or model_lower in known_model.lower(): + return max_tokens + + return None diff --git a/ccproxy/plugins/claude_api/adapter.py b/ccproxy/plugins/claude_api/adapter.py index b73d88e2..80c71f95 100644 --- a/ccproxy/plugins/claude_api/adapter.py +++ b/ccproxy/plugins/claude_api/adapter.py @@ -12,6 +12,8 @@ DetectionServiceProtocol, TokenManagerProtocol, ) +from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results +from ccproxy.llms.utils import get_max_input_tokens, truncate_to_fit from ccproxy.services.adapters.http_adapter import BaseHTTPAdapter from ccproxy.utils.headers import ( extract_response_headers, @@ -57,6 +59,35 @@ async def prepare_provider_request( if body_data.get("temperature") is None: body_data.pop("temperature", None) + # Get model from request for context truncation + model = body_data.get("model", "") + + # Auto-truncate context if request exceeds model limits + # IMPORTANT: Truncation must happen BEFORE sanitization because truncation + # can create orphaned tool blocks by removing messages with tool_use while + # keeping messages with tool_result + max_input = get_max_input_tokens(model) + if max_input: + body_data, was_truncated = truncate_to_fit( + body_data, + max_input_tokens=max_input, + preserve_recent=getattr(self.config, 
"preserve_recent_messages", 10), + safety_margin=getattr(self.config, "context_safety_margin", 0.9), + ) + if was_truncated: + logger.info( + "request_truncated_for_context_limit", + model=model, + max_input_tokens=max_input, + category="context_management", + ) + + # Sanitize tool_result blocks to remove orphaned references + # This fixes "unexpected tool_use_id" errors from conversation compaction + # Must run AFTER truncation to catch orphans created by truncation + if "messages" in body_data: + body_data["messages"] = _sanitize_tool_results(body_data["messages"]) + # Anthropic API constraint: cannot accept both temperature and top_p # Prioritize temperature over top_p when both are present if "temperature" in body_data and "top_p" in body_data: diff --git a/ccproxy/plugins/claude_sdk/adapter.py b/ccproxy/plugins/claude_sdk/adapter.py index 75f7988f..74bed887 100644 --- a/ccproxy/plugins/claude_sdk/adapter.py +++ b/ccproxy/plugins/claude_sdk/adapter.py @@ -14,7 +14,9 @@ from ccproxy.config.utils import OPENAI_CHAT_COMPLETIONS_PATH from ccproxy.core.logging import get_plugin_logger from ccproxy.core.request_context import RequestContext +from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results from ccproxy.llms.streaming import OpenAIStreamProcessor +from ccproxy.llms.utils import get_max_input_tokens, truncate_to_fit from ccproxy.services.adapters.chain_composer import compose_from_chain from ccproxy.services.adapters.format_adapter import FormatAdapterProtocol from ccproxy.services.adapters.http_adapter import BaseHTTPAdapter @@ -234,6 +236,34 @@ async def handle_request( # Extract parameters for SDK handler messages = request_data.get("messages", []) model = request_data.get("model", "claude-3-opus-20240229") + + # Auto-truncate context if request exceeds model limits + # IMPORTANT: Truncation must happen BEFORE sanitization because truncation + # can create orphaned tool blocks by removing messages with tool_use while + # keeping messages with tool_result + max_input = get_max_input_tokens(model) + if max_input: + request_data, was_truncated = truncate_to_fit( + request_data, + max_input_tokens=max_input, + preserve_recent=getattr(self.config, "preserve_recent_messages", 10), + safety_margin=getattr(self.config, "context_safety_margin", 0.9), + ) + if was_truncated: + logger.info( + "request_truncated_for_context_limit", + model=model, + max_input_tokens=max_input, + category="context_management", + ) + # Update messages reference after truncation + messages = request_data.get("messages", []) + + # Sanitize tool_result blocks to remove orphaned references + # This fixes "unexpected tool_use_id" errors from conversation compaction + # Must run AFTER truncation to catch orphans created by truncation + messages = _sanitize_tool_results(messages) + request_data["messages"] = messages temperature = request_data.get("temperature") max_tokens = request_data.get("max_tokens") stream = request_data.get("stream", False) diff --git a/ccproxy/streaming/deferred.py b/ccproxy/streaming/deferred.py index b721a105..a9bbb498 100644 --- a/ccproxy/streaming/deferred.py +++ b/ccproxy/streaming/deferred.py @@ -13,6 +13,7 @@ import structlog from starlette.responses import JSONResponse, Response, StreamingResponse +from ccproxy.core.constants import FORMAT_ANTHROPIC_MESSAGES from ccproxy.core.plugins.hooks import HookEvent, HookManager from ccproxy.core.plugins.hooks.base import HookContext from ccproxy.llms.streaming.accumulators import StreamAccumulator @@ -233,6 +234,7 @@ 
async def body_generator() -> AsyncGenerator[bytes, None]: async def _emit_error_sse( error_obj: dict[str, Any], ) -> AsyncGenerator[bytes, None]: + error_obj = self._format_stream_error(error_obj) adapted: dict[str, Any] | None = None try: if self.handler_config and self.handler_config.response_adapter: @@ -840,6 +842,23 @@ async def _serialize_json_to_sse_stream( ): yield chunk + def _format_stream_error(self, error_obj: dict[str, Any]) -> dict[str, Any]: + """Normalize streaming error payloads for client-specific SSE schemas.""" + if isinstance(error_obj, dict) and error_obj.get("type"): + return error_obj + + format_chain = ( + self.request_context.format_chain + if self.request_context and self.request_context.format_chain + else [] + ) + client_format = format_chain[0] if format_chain else None + + if client_format == FORMAT_ANTHROPIC_MESSAGES: + return {"type": "error", "error": error_obj.get("error", error_obj)} + + return error_obj + def _record_tool_event(self, event_name: str, payload: Any) -> None: if not self._stream_accumulator or not isinstance(payload, dict): return diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..89fad0c5 --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1 @@ +"""Shared pytest fixture package.""" diff --git a/tests/fixtures/external_apis/anthropic_api.py b/tests/fixtures/external_apis/anthropic_api.py index 1e96060b..221042c9 100644 --- a/tests/fixtures/external_apis/anthropic_api.py +++ b/tests/fixtures/external_apis/anthropic_api.py @@ -11,6 +11,7 @@ import httpx import pytest from pytest_httpx import HTTPXMock + from tests.helpers.sample_loader import ( load_sample, response_content_from_sample, diff --git a/tests/fixtures/external_apis/codex_api.py b/tests/fixtures/external_apis/codex_api.py index dbbe7d23..7b623541 100644 --- a/tests/fixtures/external_apis/codex_api.py +++ b/tests/fixtures/external_apis/codex_api.py @@ -8,6 +8,7 @@ import httpx import pytest from pytest_httpx import HTTPXMock + from tests.helpers.sample_loader import ( load_sample, response_content_from_sample, diff --git a/tests/fixtures/external_apis/copilot_api.py b/tests/fixtures/external_apis/copilot_api.py index 1fb8177c..985b86e3 100644 --- a/tests/fixtures/external_apis/copilot_api.py +++ b/tests/fixtures/external_apis/copilot_api.py @@ -8,6 +8,7 @@ import httpx import pytest from pytest_httpx import HTTPXMock + from tests.helpers.sample_loader import ( load_sample, response_content_from_sample, diff --git a/tests/plugins/claude_api/unit/test_native_anthropic_sanitization.py b/tests/plugins/claude_api/unit/test_native_anthropic_sanitization.py new file mode 100644 index 00000000..6db9ca9f --- /dev/null +++ b/tests/plugins/claude_api/unit/test_native_anthropic_sanitization.py @@ -0,0 +1,591 @@ +"""Test native Anthropic request sanitization for orphaned tool blocks. + +This test module verifies that native Anthropic format requests (sent to /v1/messages) +properly sanitize orphaned tool blocks that don't have matching counterparts. + +Two types of orphaned blocks are handled: + +1. Orphaned tool_result blocks (tool_result without matching tool_use): + - Occurs when conversation is compacted, removing old tool_use blocks + - tool_result blocks remain without their corresponding tool_use blocks + - API rejects with: "unexpected tool_use_id found in tool_result blocks" + +2. 
Orphaned tool_use blocks (tool_use without matching tool_result): + - Occurs when conversation is compacted, removing tool_result blocks + - tool_use blocks remain without their corresponding tool_result blocks + - API rejects with: "tool_use ids were found without tool_result blocks immediately after" + +The fix applies _sanitize_tool_results() in the claude_api and claude_sdk adapters' +prepare_provider_request() method before forwarding to the Anthropic API. +""" + +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results + + +Message = dict[str, Any] + + +class TestNativeAnthropicSanitization: + """Test sanitization of native Anthropic requests with orphaned tool_result blocks.""" + + def test_orphaned_tool_result_removed_from_native_request(self): + """Native Anthropic request with orphaned tool_result should be sanitized. + + This reproduces the exact error reported: + "unexpected tool_use_id found in tool_result blocks: toolu_019M2sPZmfSNC57WBuV9NaRb" + """ + # Simulate a compacted conversation where the tool_use was summarized + # but the tool_result remains with its original ID + messages: list[Message] = [ + {"role": "user", "content": "Search for files matching *.py"}, + { + "role": "assistant", + "content": "I'll search for Python files.", # tool_use was compacted out + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_019M2sPZmfSNC57WBuV9NaRb", # orphaned! + "content": "Found 15 Python files", + }, + {"type": "text", "text": "Thanks, now please analyze them"}, + ], + }, + ] + + result = _sanitize_tool_results(messages) + + # The orphaned tool_result should be converted to text + assert len(result) == 3 + user_msg = result[2] + assert user_msg["role"] == "user" + + # Should have text blocks but no tool_result + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + assert len(tool_results) == 0 + + # Original text should be preserved + text_blocks = [b for b in user_msg["content"] if b.get("type") == "text"] + assert len(text_blocks) == 2 # original text + converted orphan + + # Check orphan was converted to informative text + orphan_text_block = text_blocks[0] # inserted at beginning + assert ( + "Previous tool results from compacted history" in orphan_text_block["text"] + ) + assert "toolu_019M2sPZmfSNC57WBuV9NaRb" in orphan_text_block["text"] + + def test_valid_tool_result_preserved_in_native_request(self): + """Native Anthropic request with valid tool_result should pass through unchanged.""" + messages: list[Message] = [ + {"role": "user", "content": "Search for files"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll search for files."}, + { + "type": "tool_use", + "id": "toolu_valid123", + "name": "glob", + "input": {"pattern": "*.py"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid123", # matches tool_use above + "content": "Found 15 files", + } + ], + }, + ] + + result = _sanitize_tool_results(messages) + + # Messages should be unchanged + assert len(result) == 3 + user_msg = result[2] + + # tool_result should be preserved + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + assert len(tool_results) == 1 + assert tool_results[0]["tool_use_id"] == "toolu_valid123" + + def test_mixed_valid_and_orphaned_tool_results(self): + 
"""Request with both valid and orphaned tool_results should keep valid ones.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_valid", + "name": "read", + "input": {"path": "file.py"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", # valid + "content": "file contents", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_orphan_from_compaction", # orphaned + "content": "old result", + }, + ], + }, + ] + + result = _sanitize_tool_results(messages) + + user_msg = result[1] + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + text_blocks = [b for b in user_msg["content"] if b.get("type") == "text"] + + # Only valid tool_result should remain + assert len(tool_results) == 1 + assert tool_results[0]["tool_use_id"] == "toolu_valid" + assert len(text_blocks) == 0 + + def test_superdesign_conversation_compaction_scenario(self): + """Reproduce the SuperDesign VS Code extension compaction scenario. + + SuperDesign uses @ai-sdk/anthropic which sends native Anthropic format. + When the conversation gets long, it compacts history, removing old messages + but sometimes leaving orphaned tool_result blocks. + """ + # Simulated compacted conversation from SuperDesign + messages: list[Message] = [ + # Earlier context was compacted into a summary + {"role": "user", "content": "Help me build a React component"}, + { + "role": "assistant", + "content": "[Summary: Previously searched for files and read component code]", + }, + # User message still has tool_results from before compaction + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_old_glob_call", + "content": "src/components/Button.tsx\nsrc/components/Modal.tsx", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_old_read_call", + "content": "export const Button = () => ", + }, + {"type": "text", "text": "Now create a new Header component"}, + ], + }, + ] + + result = _sanitize_tool_results(messages) + + # Orphaned tool_results should be converted to text + user_msg = result[2] + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + assert len(tool_results) == 0 + + # Content should be preserved as text + text_blocks = [b for b in user_msg["content"] if b.get("type") == "text"] + assert len(text_blocks) == 2 # original text + orphan summary + + # User's actual request should be there + assert any("Header component" in b["text"] for b in text_blocks) + + def test_empty_messages_handled(self): + """Empty messages list should return empty list.""" + assert _sanitize_tool_results([]) == [] + + def test_messages_without_tool_content_unchanged(self): + """Messages without any tool-related content pass through unchanged.""" + messages: list[Message] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + + result = _sanitize_tool_results(messages) + assert result == messages + + def test_long_orphan_content_truncated(self): + """Long orphaned tool_result content should be truncated to 500 chars.""" + long_content = "x" * 1000 + messages: list[Message] = [ + {"role": "assistant", "content": "No tool_use here"}, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_orphan", + "content": long_content, + } + ], + }, + ] + + result = _sanitize_tool_results(messages) + + user_msg = result[1] + 
text_blocks = [b for b in user_msg["content"] if b.get("type") == "text"] + orphan_text = text_blocks[0]["text"] + + # Should be truncated with "..." + assert "..." in orphan_text + # Should not contain full 1000 chars + assert len(orphan_text) < 1000 + + +class TestClaudeAPIAdapterSanitization: + """Test that the claude_api adapter properly applies sanitization.""" + + @pytest.mark.asyncio + async def test_adapter_sanitizes_native_anthropic_request(self): + """The adapter's prepare_provider_request should sanitize messages. + + This test directly verifies the sanitization is applied by checking + the code path rather than instantiating the full adapter. + """ + # Test the sanitization function directly with the message format + # that would come from a native Anthropic request + messages: list[Message] = [ + {"role": "assistant", "content": "Summary of previous work"}, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_orphan", + "content": "old result", + }, + {"type": "text", "text": "Continue please"}, + ], + }, + ] + + # Simulate what the adapter does + sanitized = _sanitize_tool_results(messages) + + # Check that orphaned tool_result was sanitized + user_msg = sanitized[1] + tool_results = [ + b + for b in user_msg["content"] + if isinstance(b, dict) and b.get("type") == "tool_result" + ] + assert len(tool_results) == 0 + + # Verify the fix is properly imported in the adapter module + from ccproxy.plugins.claude_api import adapter as api_adapter + + assert hasattr(api_adapter, "_sanitize_tool_results") + + +class TestClaudeSDKAdapterSanitization: + """Test that the claude_sdk adapter properly applies sanitization.""" + + def test_sdk_adapter_has_sanitization_import(self): + """Verify the SDK adapter imports the sanitization function.""" + from ccproxy.plugins.claude_sdk import adapter + + assert hasattr(adapter, "_sanitize_tool_results") + + +class TestOrphanedToolUseSanitization: + """Test sanitization of orphaned tool_use blocks (tool_use without matching tool_result). + + This addresses the error: + "tool_use ids were found without tool_result blocks immediately after: . + Each tool_use block must have a corresponding tool_result block in the next message." 
+ """ + + def test_orphaned_tool_use_removed_from_assistant_message(self): + """Assistant message with tool_use but no matching tool_result should be sanitized.""" + messages: list[Message] = [ + {"role": "user", "content": "Search for files"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll search for files."}, + { + "type": "tool_use", + "id": "toolu_orphan123", + "name": "glob", + "input": {"pattern": "*.py"}, + }, + ], + }, + { + "role": "user", + "content": "please continue", # No tool_result for the tool_use above + }, + ] + + result = _sanitize_tool_results(messages) + + # The tool_use should be converted to text + assistant_msg = result[1] + assert isinstance(assistant_msg["content"], list) + tool_uses = [b for b in assistant_msg["content"] if b.get("type") == "tool_use"] + assert len(tool_uses) == 0 + + # Should have text describing the orphaned tool call + text_blocks = [b for b in assistant_msg["content"] if b.get("type") == "text"] + assert len(text_blocks) >= 1 + combined_text = " ".join(b["text"] for b in text_blocks) + assert ( + "glob" in combined_text + or "Tool calls from compacted history" in combined_text + ) + + def test_valid_tool_use_with_result_preserved(self): + """tool_use with matching tool_result should be preserved.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll search."}, + { + "type": "tool_use", + "id": "toolu_valid", + "name": "glob", + "input": {"pattern": "*.py"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", + "content": "Found 10 files", + } + ], + }, + ] + + result = _sanitize_tool_results(messages) + + # Both should be preserved + assistant_msg = result[0] + tool_uses = [b for b in assistant_msg["content"] if b.get("type") == "tool_use"] + assert len(tool_uses) == 1 + assert tool_uses[0]["id"] == "toolu_valid" + + user_msg = result[1] + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + assert len(tool_results) == 1 + + def test_mixed_valid_and_orphaned_tool_uses(self): + """Message with both valid and orphaned tool_uses should keep only valid ones.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_valid", + "name": "read", + "input": {"path": "file.py"}, + }, + { + "type": "tool_use", + "id": "toolu_orphan", + "name": "write", + "input": {"path": "out.py"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", + "content": "file contents", + } + # No tool_result for toolu_orphan + ], + }, + ] + + result = _sanitize_tool_results(messages) + + assistant_msg = result[0] + tool_uses = [b for b in assistant_msg["content"] if b.get("type") == "tool_use"] + assert len(tool_uses) == 1 + assert tool_uses[0]["id"] == "toolu_valid" + + def test_superdesign_compaction_scenario_with_orphaned_tool_use(self): + """Reproduce SuperDesign compaction where tool_use remains but tool_result is lost. 
+ + This is the exact error reported: + "tool_use ids were found without tool_result blocks immediately after: toolu_01YJquBpATfUqskN381pdJdP" + """ + messages: list[Message] = [ + {"role": "user", "content": "Help me build a component"}, + { + "role": "assistant", + "content": "[Summary: Previously searched for files]", # Compacted + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me search for more files."}, + { + "type": "tool_use", + "id": "toolu_01YJquBpATfUqskN381pdJdP", # Exact ID from error + "name": "glob", + "input": {"pattern": "src/**/*.tsx"}, + }, + ], + }, + { + "role": "user", + "content": "please continue", # User message without tool_result + }, + ] + + result = _sanitize_tool_results(messages) + + # The orphaned tool_use should be removed/converted + for msg in result: + if msg.get("role") == "assistant" and isinstance(msg.get("content"), list): + tool_uses = [ + b + for b in msg["content"] + if isinstance(b, dict) and b.get("type") == "tool_use" + ] + assert len(tool_uses) == 0, f"Found orphaned tool_use: {tool_uses}" + + def test_tool_use_only_message_converted_to_text(self): + """Assistant message with only orphaned tool_use should become text.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_only_orphan", + "name": "write", + "input": {"path": "test.py", "content": "print('hello')"}, + } + ], + }, + {"role": "user", "content": "continue"}, + ] + + result = _sanitize_tool_results(messages) + + assistant_msg = result[0] + # Content should be text (either string or list with text block) + content = assistant_msg["content"] + if isinstance(content, list): + tool_uses = [b for b in content if b.get("type") == "tool_use"] + assert len(tool_uses) == 0 + text_blocks = [b for b in content if b.get("type") == "text"] + assert len(text_blocks) >= 1 + assert "write" in text_blocks[0]["text"] + else: + # String content is also acceptable + assert isinstance(content, str) + + def test_long_tool_input_truncated(self): + """Long orphaned tool_use input should be truncated to 200 chars.""" + long_input = {"content": "x" * 500} + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_long", + "name": "write", + "input": long_input, + } + ], + }, + {"role": "user", "content": "continue"}, + ] + + result = _sanitize_tool_results(messages) + + assistant_msg = result[0] + text_blocks = [b for b in assistant_msg["content"] if b.get("type") == "text"] + orphan_text = text_blocks[0]["text"] + + # Should be truncated with "..." + assert "..." 
in orphan_text + + def test_multiple_consecutive_orphaned_tool_uses(self): + """Multiple orphaned tool_use blocks should all be converted.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll do multiple operations."}, + { + "type": "tool_use", + "id": "toolu_1", + "name": "read", + "input": {"path": "a.py"}, + }, + { + "type": "tool_use", + "id": "toolu_2", + "name": "read", + "input": {"path": "b.py"}, + }, + { + "type": "tool_use", + "id": "toolu_3", + "name": "write", + "input": {"path": "c.py"}, + }, + ], + }, + {"role": "user", "content": "okay, what's next?"}, + ] + + result = _sanitize_tool_results(messages) + + assistant_msg = result[0] + tool_uses = [b for b in assistant_msg["content"] if b.get("type") == "tool_use"] + assert len(tool_uses) == 0 + + # All three should be mentioned in the text + text_blocks = [b for b in assistant_msg["content"] if b.get("type") == "text"] + combined_text = " ".join(b["text"] for b in text_blocks) + assert "read" in combined_text + assert "write" in combined_text diff --git a/tests/plugins/claude_sdk/integration/test_sdk_compaction_behavior.py b/tests/plugins/claude_sdk/integration/test_sdk_compaction_behavior.py new file mode 100644 index 00000000..a0e68fc3 --- /dev/null +++ b/tests/plugins/claude_sdk/integration/test_sdk_compaction_behavior.py @@ -0,0 +1,332 @@ +"""Integration tests for SDK compaction behavior with tool blocks. + +These tests verify whether the Claude Agent SDK's automatic context compaction +creates orphaned tool_use/tool_result blocks that cause API errors. + +Tests require: +- Claude CLI installed and authenticated +- Real API calls (uses actual tokens) +- Sufficient context to trigger auto-compaction + +Run with: RUN_SDK_INTEGRATION=1 pytest tests/plugins/claude_sdk/integration/test_sdk_compaction_behavior.py -v -s +""" + +import json +import os +from pathlib import Path +from typing import Any + +import pytest + + +# Mark all tests in this module as requiring SDK integration +pytestmark = [ + pytest.mark.integration, + pytest.mark.real_api, # Uses real_api marker for tests requiring actual SDK/API calls + pytest.mark.skipif( + not os.environ.get("RUN_SDK_INTEGRATION"), + reason="SDK integration tests require RUN_SDK_INTEGRATION=1", + ), +] + + +class TestSessionFileOrphanAnalysis: + """Analyze existing session JSONL files for orphaned tool blocks. + + This test doesn't make API calls - it inspects existing session files + to look for evidence of orphaned tool_use/tool_result pairs. 
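+
+    Illustrative example (block shapes inferred from the analyzer below): a
+    session line whose "content" list contains
+    {"type": "tool_result", "tool_use_id": "toolu_X", ...} is reported as an
+    orphan when no line in the same file carries a tool_use block with
+    "id": "toolu_X", and vice versa for a tool_use without a matching result.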
+ """ + + def _extract_tool_blocks( + self, + msg: dict[str, Any], + tool_uses: dict[str, dict[str, Any]], + tool_results: dict[str, dict[str, Any]], + ) -> None: + """Extract tool_use and tool_result blocks from a message.""" + content = msg.get("content", []) + if isinstance(content, str): + return + + if not isinstance(content, list): + return + + for block in content: + if isinstance(block, dict): + if block.get("type") == "tool_use": + tool_id = block.get("id") + if tool_id: + tool_uses[tool_id] = block + elif block.get("type") == "tool_result": + tool_use_id = block.get("tool_use_id") + if tool_use_id: + tool_results[tool_use_id] = block + + def _analyze_session_file(self, session_file: Path) -> dict[str, Any]: + """Analyze a single session file for tool block orphans.""" + tool_uses: dict[str, dict[str, Any]] = {} + tool_results: dict[str, dict[str, Any]] = {} + message_count = 0 + compaction_events = [] + + with session_file.open() as f: + for line_num, line in enumerate(f, 1): + if not line.strip(): + continue + try: + msg = json.loads(line) + message_count += 1 + + # Check for compaction boundary + if ( + msg.get("type") == "system" + and msg.get("subtype") == "compact_boundary" + ): + compaction_events.append( + { + "line": line_num, + "metadata": msg.get("compact_metadata", {}), + } + ) + + # Extract tool blocks + self._extract_tool_blocks(msg, tool_uses, tool_results) + + except json.JSONDecodeError: + continue + + # Find orphans + orphaned_results = set(tool_results.keys()) - set(tool_uses.keys()) + orphaned_uses = set(tool_uses.keys()) - set(tool_results.keys()) + + return { + "file": str(session_file), + "message_count": message_count, + "tool_use_count": len(tool_uses), + "tool_result_count": len(tool_results), + "compaction_events": compaction_events, + "orphaned_results": list(orphaned_results), + "orphaned_uses": list(orphaned_uses), + "has_orphans": bool(orphaned_results or orphaned_uses), + } + + @pytest.mark.unit + def test_inspect_session_history_for_orphans(self) -> None: + """Inspect session JSONL files for orphaned tool blocks. + + This test directly reads the session files to check for orphans + without making API calls. 
+ """ + claude_projects = Path.home() / ".claude" / "projects" + + if not claude_projects.exists(): + pytest.skip("No Claude projects directory found at ~/.claude/projects") + + # Find all session files recursively + session_files = list(claude_projects.rglob("*.jsonl")) + + if not session_files: + pytest.skip("No session files found in ~/.claude/projects") + + # Sort by modification time, analyze most recent + session_files.sort(key=lambda p: p.stat().st_mtime, reverse=True) + + # Analyze up to 10 most recent sessions + results = [] + sessions_with_orphans = [] + sessions_with_compaction = [] + + for session_file in session_files[:10]: + analysis = self._analyze_session_file(session_file) + results.append(analysis) + + if analysis["has_orphans"]: + sessions_with_orphans.append(analysis) + if analysis["compaction_events"]: + sessions_with_compaction.append(analysis) + + # Report findings + print(f"\n{'=' * 60}") + print("SESSION FILE ANALYSIS") + print(f"{'=' * 60}") + print(f"Total sessions analyzed: {len(results)}") + print(f"Sessions with compaction events: {len(sessions_with_compaction)}") + print(f"Sessions with orphaned tool blocks: {len(sessions_with_orphans)}") + print() + + for r in results: + status = "HAS ORPHANS" if r["has_orphans"] else "OK" + compaction = ( + f"({len(r['compaction_events'])} compactions)" + if r["compaction_events"] + else "" + ) + print( + f" [{status}] {Path(r['file']).name}: " + f"{r['tool_use_count']} uses, {r['tool_result_count']} results {compaction}" + ) + + if r["has_orphans"]: + if r["orphaned_results"]: + print( + f" - Orphaned tool_results (no matching tool_use): {r['orphaned_results'][:3]}..." + ) + if r["orphaned_uses"]: + print( + f" - Orphaned tool_uses (no matching tool_result): {r['orphaned_uses'][:3]}..." + ) + + print() + + # Correlation analysis + if sessions_with_compaction: + compaction_with_orphans = [ + s for s in sessions_with_compaction if s["has_orphans"] + ] + print( + f"Compaction correlation: {len(compaction_with_orphans)}/{len(sessions_with_compaction)} " + f"sessions with compaction also have orphans" + ) + + # The test passes regardless - this is investigative + # But we flag if we found concerning patterns + if sessions_with_orphans: + print( + f"\nWARNING: Found {len(sessions_with_orphans)} session(s) with orphaned tool blocks!" + ) + for s in sessions_with_orphans: + if s["compaction_events"]: + print( + f" - {Path(s['file']).name}: Compaction may have created orphans" + ) + + +class TestSDKCompactionTrigger: + """Tests that require actual SDK interactions to trigger compaction.""" + + @pytest.mark.asyncio + async def test_compaction_creates_orphans_hypothesis(self) -> None: + """Test hypothesis: SDK compaction creates orphaned tool blocks. + + This test: + 1. Creates a session with tool_use/tool_result exchanges + 2. Fills context until auto-compaction triggers + 3. Verifies whether orphaned tool blocks appear + + Note: This test is expensive (uses many tokens) and slow. 
+ """ + # Import SDK components + try: + from claude_agent_sdk import ClaudeAgentOptions, query + except ImportError: + pytest.skip("claude_agent_sdk not installed") + + import uuid + + # Claude CLI requires proper UUID format for --resume + session_id = str(uuid.uuid4()) + num_turns = 20 # Start with fewer turns, increase if needed + compaction_detected = False + error_messages: list[str] = [] + + print(f"\nStarting compaction test with session: {session_id}") + print(f"Will attempt {num_turns} turns to trigger compaction") + + for turn in range(num_turns): + # Create a prompt that generates substantial content + # Include mention of tools to encourage tool-like patterns in response + prompt = f"""Turn {turn + 1}/{num_turns}: + +Please provide a detailed technical analysis (at least 500 words) about: +- System architecture patterns for turn {turn} +- Include code examples with detailed comments +- Discuss trade-offs and implementation considerations + +This is for testing context window behavior.""" + + options = ClaudeAgentOptions( + resume=session_id if turn > 0 else None, + max_turns=1, + ) + + try: + messages_this_turn = [] + async for message in query(prompt=prompt, options=options): + messages_this_turn.append(message) + + # Check for compaction boundary + if hasattr(message, "type") and message.type == "system": + if ( + hasattr(message, "subtype") + and message.subtype == "compact_boundary" + ): + compaction_detected = True + print(f"\n*** COMPACTION DETECTED at turn {turn + 1} ***") + if hasattr(message, "compact_metadata"): + print(f" Metadata: {message.compact_metadata}") + + print( + f" Turn {turn + 1} completed: {len(messages_this_turn)} messages" + ) + + except Exception as e: + error_str = str(e) + error_messages.append(error_str) + + if "tool_use_id" in error_str.lower(): + print(f"\n*** ORPHAN ERROR DETECTED at turn {turn + 1} ***") + print(f" Error: {error_str}") + + # This confirms the hypothesis + pytest.fail( + f"SDK compaction created orphaned tool blocks at turn {turn + 1}: {error_str}" + ) + else: + print(f" Turn {turn + 1} error (non-orphan): {error_str[:100]}") + + # Summary + print(f"\n{'=' * 60}") + print("TEST SUMMARY") + print(f"{'=' * 60}") + print(f"Turns completed: {num_turns}") + print(f"Compaction detected: {compaction_detected}") + print(f"Errors encountered: {len(error_messages)}") + + if not compaction_detected: + print("\nNote: Compaction did not trigger. To trigger compaction:") + print(" - Increase num_turns (currently {num_turns})") + print(" - Use prompts that generate more output") + print(" - Or manually check ~/.claude/projects for session files") + + @pytest.mark.asyncio + async def test_manual_orphan_injection(self) -> None: + """Test how SDK handles a session file with injected orphans. + + This test: + 1. Creates a session file with orphaned tool blocks + 2. Attempts to resume the session + 3. Observes whether SDK/API errors occur + """ + try: + from claude_agent_sdk import ClaudeAgentOptions, query + except ImportError: + pytest.skip("claude_agent_sdk not installed") + + # This test would require manipulating session files directly + # which is risky and may not be reproducible + pytest.skip("Manual orphan injection test not yet implemented") + + +class TestCompareSDKvsAPI: + """Compare tool handling between SDK and direct API paths.""" + + @pytest.mark.asyncio + async def test_identical_request_different_path(self) -> None: + """Send identical requests through SDK and API, compare behavior. 
+ + This tests whether the SDK and API handle tool blocks identically + when given the same input. + """ + pytest.skip( + "Comparative test not yet implemented - requires both adapters running" + ) diff --git a/tests/test_tool_result_sanitization.py b/tests/test_tool_result_sanitization.py index d902a4ea..6b9a713f 100644 --- a/tests/test_tool_result_sanitization.py +++ b/tests/test_tool_result_sanitization.py @@ -1,4 +1,4 @@ -"""Test _sanitize_tool_results method for removing orphaned tool_result blocks. +"""Test _sanitize_tool_results method for orphaned tool blocks. This module tests the bug fix for orphaned tool_result blocks that occur when conversation history is compacted. When tool_use blocks are removed during @@ -9,8 +9,9 @@ The _sanitize_tool_results method fixes this by: 1. Removing orphaned tool_result blocks that don't have matching tool_use blocks in the immediately preceding assistant message -2. Converting orphaned results to text blocks to preserve information -3. Keeping valid tool_result blocks that have matching tool_use blocks +2. Removing tool_use blocks that don't have a tool_result in the *next* user message +3. Converting orphaned blocks to text to preserve information +4. Keeping valid tool_use/tool_result pairs Real-world scenario: - A long conversation with multiple tool calls gets compacted to stay within token limits @@ -109,7 +110,7 @@ def test_valid_tool_result_preserved(self, mock_logger: Mock) -> None: - User message with tool_result(tool_use_id="tool_123") Result: tool_result should be kept unchanged """ - messages = [ + messages: list[dict[str, Any]] = [ create_assistant_with_tool_use( "I'll help you with that.", [{"id": "tool_123", "name": "calculator", "input": {"x": 5}}], @@ -137,7 +138,7 @@ def test_orphaned_tool_result_removed(self, mock_logger: Mock) -> None: - NO preceding assistant with matching tool_use Result: tool_result should be removed and converted to text """ - messages = [ + messages: list[dict[str, Any]] = [ create_user_text_message("Hello"), create_user_with_tool_result( [{"tool_use_id": "orphan_456", "content": "orphaned result"}] @@ -163,9 +164,9 @@ def test_mixed_valid_and_orphaned(self, mock_logger: Mock) -> None: Scenario: Partial compaction - Assistant with tool_use(id="valid_1") - User with tool_result(tool_use_id="valid_1") AND tool_result(tool_use_id="orphan_2") - Result: valid_1 kept, orphan_2 converted to text + Result: valid_1 kept, orphan_2 dropped (no text injected) """ - messages = [ + messages: list[dict[str, Any]] = [ create_assistant_with_tool_use( "Let me check that.", [{"id": "valid_1", "name": "search", "input": {"query": "test"}}], @@ -183,17 +184,10 @@ def test_mixed_valid_and_orphaned(self, mock_logger: Mock) -> None: assert len(result) == 2 user_content = result[1]["content"] - # Should have text block (from orphaned) + valid tool_result - assert len(user_content) == 2 - - # First should be text block with orphaned info - assert user_content[0]["type"] == "text" - assert "Previous tool results" in user_content[0]["text"] - assert "orphan_2" in user_content[0]["text"] - - # Second should be the valid tool_result - assert user_content[1]["type"] == "tool_result" - assert user_content[1]["tool_use_id"] == "valid_1" + # Only the valid tool_result should remain + assert len(user_content) == 1 + assert user_content[0]["type"] == "tool_result" + assert user_content[0]["tool_use_id"] == "valid_1" # Should log warning about orphaned result mock_logger.warning.assert_called_once() @@ -206,7 +200,7 @@ def 
test_multiple_tool_uses_preserved(self, mock_logger: Mock) -> None: - User with tool_result for all three Result: all three should be preserved """ - messages = [ + messages: list[dict[str, Any]] = [ create_assistant_with_tool_use( "I'll use three tools.", [ @@ -238,6 +232,150 @@ def test_multiple_tool_uses_preserved(self, mock_logger: Mock) -> None: # No warnings should be logged mock_logger.warning.assert_not_called() + def test_orphaned_tool_use_removed_when_no_next_message( + self, mock_logger: Mock + ) -> None: + """Remove tool_use when no following user message exists.""" + messages: list[dict[str, Any]] = [ + create_assistant_with_tool_use( + "I'll need to run a tool.", + [{"id": "tool_1", "name": "helper", "input": {"q": "test"}}], + ) + ] + + result = _sanitize_tool_results(messages) + + assert len(result) == 1 + assert result[0]["role"] == "assistant" + content = result[0]["content"] + assert isinstance(content, list) + assert all(block.get("type") != "tool_use" for block in content) + assert "Tool calls from compacted history" in content[0]["text"] + mock_logger.warning.assert_called_once() + + def test_orphaned_tool_use_removed_when_next_user_missing_result( + self, mock_logger: Mock + ) -> None: + """Remove tool_use when the next user message lacks tool_result.""" + messages: list[dict[str, Any]] = [ + create_assistant_with_tool_use( + "Let me check.", + [{"id": "tool_1", "name": "search", "input": {"q": "test"}}], + ), + create_user_text_message("Thanks!"), + ] + + result = _sanitize_tool_results(messages) + + assert len(result) == 2 + assistant_content = result[0]["content"] + assert isinstance(assistant_content, list) + assert all(block.get("type") != "tool_use" for block in assistant_content) + assert "Tool calls from compacted history" in assistant_content[0]["text"] + assert result[1] == messages[1] + mock_logger.warning.assert_called_once() + + def test_tool_use_removed_when_next_user_only_has_orphaned_results( + self, mock_logger: Mock + ) -> None: + """Remove tool_use when next user only includes unrelated tool_results.""" + messages: list[dict[str, Any]] = [ + create_assistant_with_tool_use( + "Let me check.", + [{"id": "tool_keep", "name": "search", "input": {"q": "test"}}], + ), + create_user_with_tool_result( + [{"tool_use_id": "orphan_only", "content": "old result"}] + ), + ] + + result = _sanitize_tool_results(messages) + + assistant_content = result[0]["content"] + assert isinstance(assistant_content, list) + assert all(block.get("type") != "tool_use" for block in assistant_content) + assert "Tool calls from compacted history" in assistant_content[0]["text"] + + user_content = result[1]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Previous tool results" in user_content[0]["text"] + + assert mock_logger.warning.call_count == 2 + + def test_duplicate_tool_use_pruned_to_match_result_count( + self, mock_logger: Mock + ) -> None: + """Remove extra tool_use blocks when tool_result count is lower. + + Scenario: Assistant has duplicate tool_use IDs; user has a single result. + Result: Keep only one tool_use and one tool_result. 
+ """ + messages: list[dict[str, Any]] = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Calling tool twice."}, + { + "type": "tool_use", + "id": "dup_tool", + "name": "search", + "input": {}, + }, + { + "type": "tool_use", + "id": "dup_tool", + "name": "search", + "input": {}, + }, + ], + }, + create_user_with_tool_result( + [{"tool_use_id": "dup_tool", "content": "result once"}] + ), + ] + + result = _sanitize_tool_results(messages) + + assistant_content = result[0]["content"] + tool_use_blocks = [b for b in assistant_content if b.get("type") == "tool_use"] + assert len(tool_use_blocks) == 1 + + user_content = result[1]["content"] + tool_result_blocks = [b for b in user_content if b.get("type") == "tool_result"] + assert len(tool_result_blocks) == 1 + + mock_logger.warning.assert_called_once() + + def test_tool_use_removed_when_result_not_immediately_after( + self, mock_logger: Mock + ) -> None: + """Remove tool_use when tool_result is not in the next message.""" + messages = [ + create_assistant_with_tool_use( + "Calling a tool.", + [{"id": "tool_1", "name": "lookup", "input": {"q": "test"}}], + ), + create_assistant_text_message("Continuing without result."), + create_user_with_tool_result( + [{"tool_use_id": "tool_1", "content": "late result"}] + ), + ] + + result = _sanitize_tool_results(messages) + + assert len(result) == 3 + assistant_content = result[0]["content"] + assert isinstance(assistant_content, list) + assert all(block.get("type") != "tool_use" for block in assistant_content) + assert "Tool calls from compacted history" in assistant_content[0]["text"] + + user_content = result[2]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Previous tool results" in user_content[0]["text"] + mock_logger.warning.assert_called() + def test_conversation_compaction_scenario(self, mock_logger: Mock) -> None: """Test the real bug scenario: conversation compaction leaves orphaned results. 
@@ -275,16 +413,10 @@ def test_conversation_compaction_scenario(self, mock_logger: Mock) -> None: assert len(result) == 3 user_content = result[2]["content"] - # Should have text block (orphaned) + valid tool_result - assert len(user_content) == 2 - - # First is text with orphaned info - assert user_content[0]["type"] == "text" - assert "original_tool" in user_content[0]["text"] - - # Second is valid tool_result - assert user_content[1]["type"] == "tool_result" - assert user_content[1]["tool_use_id"] == "new_tool" + # Only the valid tool_result should remain + assert len(user_content) == 1 + assert user_content[0]["type"] == "tool_result" + assert user_content[0]["tool_use_id"] == "new_tool" # Should log warning mock_logger.warning.assert_called_once() @@ -521,7 +653,7 @@ def test_partial_match_orphaned(self, mock_logger: Mock) -> None: Scenario: Multiple results, only some have matching tool_use - Assistant with tool_use(id="valid_1") and tool_use(id="valid_2") - User with results for "valid_1", "orphan_3", "valid_2" - Result: valid_1 and valid_2 kept, orphan_3 converted to text + Result: valid_1 and valid_2 kept, orphan_3 dropped """ messages = [ create_assistant_with_tool_use( @@ -543,18 +675,12 @@ def test_partial_match_orphaned(self, mock_logger: Mock) -> None: result = _sanitize_tool_results(messages) user_content = result[1]["content"] - # Should have text block + 2 valid results - assert len(user_content) == 3 - - # First is text with orphaned info - assert user_content[0]["type"] == "text" - assert "orphan_3" in user_content[0]["text"] - - # Other two are valid results + # Only the valid tool_results should remain + assert len(user_content) == 2 + assert user_content[0]["type"] == "tool_result" + assert user_content[0]["tool_use_id"] == "valid_1" assert user_content[1]["type"] == "tool_result" - assert user_content[1]["tool_use_id"] == "valid_1" - assert user_content[2]["type"] == "tool_result" - assert user_content[2]["tool_use_id"] == "valid_2" + assert user_content[1]["tool_use_id"] == "valid_2" def test_assistant_with_string_content_no_tool_use(self, mock_logger: Mock) -> None: """Test assistant message with string content (no tool_use blocks). @@ -611,3 +737,63 @@ def test_multiple_orphaned_conversions(self, mock_logger: Mock) -> None: assert "result one" in text_block assert "result two" in text_block assert "result three" in text_block + + def test_assistant_dict_content_tool_use_removed(self, mock_logger: Mock) -> None: + """Handle assistant content supplied as a dict (single tool_use block). + + Scenario: Non-list content with tool_use and no following tool_result + Result: tool_use should be converted to text and content normalized to list + """ + messages: list[dict[str, Any]] = [ + { + "role": "assistant", + "content": { + "type": "tool_use", + "id": "tool_dict", + "name": "glob", + "input": {"pattern": "*.py"}, + }, + }, + {"role": "user", "content": "continue"}, + ] + + result = _sanitize_tool_results(messages) + + assistant_content = result[0]["content"] + assert isinstance(assistant_content, list) + assert all(block.get("type") != "tool_use" for block in assistant_content) + assert any( + block.get("type") == "text" + for block in assistant_content + if isinstance(block, dict) + ) + + mock_logger.warning.assert_called_once() + + def test_user_dict_content_tool_result_converted(self, mock_logger: Mock) -> None: + """Handle user content supplied as a dict (single tool_result block). 
+
+        Scenario: Orphaned tool_result provided as a dict
+        Result: tool_result should be converted to text and content normalized to list
+        """
+        messages: list[dict[str, Any]] = [
+            {"role": "assistant", "content": "No tools here"},
+            {
+                "role": "user",
+                "content": {
+                    "type": "tool_result",
+                    "tool_use_id": "orphan_dict",
+                    "content": "old result",
+                },
+            },
+        ]
+
+        result = _sanitize_tool_results(messages)
+
+        user_content = result[1]["content"]
+        assert isinstance(user_content, list)
+        assert len(user_content) == 1
+        assert user_content[0]["type"] == "text"
+        assert "orphan_dict" in user_content[0]["text"]
+
+        mock_logger.warning.assert_called_once()
diff --git a/tests/unit/llms/test_context_window_management.py b/tests/unit/llms/test_context_window_management.py
new file mode 100644
index 00000000..0b366ac0
--- /dev/null
+++ b/tests/unit/llms/test_context_window_management.py
@@ -0,0 +1,290 @@
+"""Tests for context window management utilities.
+
+This module tests token estimation and context truncation logic
+for managing requests that exceed model context limits.
+"""
+
+import pytest
+
+from ccproxy.llms.utils import (
+    estimate_messages_tokens,
+    estimate_request_tokens,
+    estimate_tokens,
+    get_max_input_tokens,
+    truncate_to_fit,
+)
+
+
+class TestTokenEstimation:
+    """Tests for token estimation functions."""
+
+    def test_estimate_tokens_string(self) -> None:
+        """Test token estimation for plain strings."""
+        # ~3 chars per token
+        text = "Hello, world!"  # 13 chars -> ~4 tokens
+        tokens = estimate_tokens(text)
+        assert tokens >= 1
+        assert tokens <= 10  # Reasonable upper bound
+
+    def test_estimate_tokens_empty_string(self) -> None:
+        """Test token estimation for empty string."""
+        tokens = estimate_tokens("")
+        assert tokens == 1  # Estimate is floored at 1 token
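The assertions above encode a rough character-per-token heuristic (about 3 characters per token, floored at 1) plus a flat 1600-token cost per image block. The sketch below is a hypothetical stand-in, not the actual estimate_tokens from ccproxy.llms.utils; it only illustrates the behaviour these tests pin down:

from typing import Any

IMAGE_TOKEN_ESTIMATE = 1600  # assumed flat cost per image block, per the test above


def rough_estimate_tokens(content: Any) -> int:
    """Illustrative only: ~3 chars per token, floored at 1; images are fixed-cost."""
    if content is None:
        return 0
    if isinstance(content, str):
        return max(1, len(content) // 3)
    total = 0
    for block in content if isinstance(content, list) else [content]:
        if not isinstance(block, dict):
            continue
        if block.get("type") == "image":
            total += IMAGE_TOKEN_ESTIMATE
        elif block.get("type") == "text":
            total += max(1, len(str(block.get("text", ""))) // 3)
        elif block.get("type") in ("tool_use", "tool_result"):
            total += max(1, len(str(block)) // 3)
    return total

# e.g. rough_estimate_tokens("Hello, world!") -> 4, rough_estimate_tokens("") -> 1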
+ + def test_estimate_tokens_none(self) -> None: + """Test token estimation for None.""" + tokens = estimate_tokens(None) + assert tokens == 0 + + def test_estimate_tokens_text_block(self) -> None: + """Test token estimation for text content block.""" + content = [ + {"type": "text", "text": "This is a test message with some content."} + ] + tokens = estimate_tokens(content) + assert tokens > 0 + + def test_estimate_tokens_tool_use_block(self) -> None: + """Test token estimation for tool_use content block.""" + content = [ + { + "type": "tool_use", + "id": "tool_123", + "name": "read_file", + "input": {"path": "/some/file/path.txt"}, + } + ] + tokens = estimate_tokens(content) + assert tokens > 0 + + def test_estimate_tokens_tool_result_block(self) -> None: + """Test token estimation for tool_result content block.""" + content = [ + { + "type": "tool_result", + "tool_use_id": "tool_123", + "content": "File contents here with some data.", + } + ] + tokens = estimate_tokens(content) + assert tokens > 0 + + def test_estimate_tokens_image_block(self) -> None: + """Test token estimation for image content block.""" + content = [{"type": "image", "source": {"type": "base64", "data": "..."}}] + tokens = estimate_tokens(content) + assert tokens == 1600 # Fixed estimate for images + + def test_estimate_tokens_mixed_content(self) -> None: + """Test token estimation for mixed content blocks.""" + content = [ + {"type": "text", "text": "Check this image:"}, + {"type": "image", "source": {"type": "url", "url": "https://..."}}, + {"type": "text", "text": "What do you see?"}, + ] + tokens = estimate_tokens(content) + assert tokens > 1600 # At least the image tokens + + def test_estimate_messages_tokens_single(self) -> None: + """Test token estimation for a single message.""" + messages = [{"role": "user", "content": "Hello, how are you?"}] + tokens = estimate_messages_tokens(messages) + assert tokens > 2 # At least role tokens + + def test_estimate_messages_tokens_conversation(self) -> None: + """Test token estimation for a conversation.""" + messages = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + {"role": "user", "content": "And 3+3?"}, + {"role": "assistant", "content": "3+3 equals 6."}, + ] + tokens = estimate_messages_tokens(messages) + assert tokens > 8 # At least role tokens (2 per message) + + def test_estimate_request_tokens_with_system(self) -> None: + """Test token estimation for request with system prompt.""" + request = { + "model": "claude-3-opus-20240229", + "system": "You are a helpful assistant.", + "messages": [{"role": "user", "content": "Hello"}], + } + tokens = estimate_request_tokens(request) + assert tokens > 0 + + def test_estimate_request_tokens_with_tools(self) -> None: + """Test token estimation for request with tools.""" + request = { + "model": "claude-3-opus-20240229", + "messages": [{"role": "user", "content": "Read a file"}], + "tools": [ + { + "name": "read_file", + "description": "Read contents of a file", + "input_schema": { + "type": "object", + "properties": {"path": {"type": "string"}}, + }, + } + ], + } + tokens = estimate_request_tokens(request) + assert tokens > 0 + + +class TestGetMaxInputTokens: + """Tests for max input tokens lookup.""" + + def test_get_max_input_tokens_known_model(self) -> None: + """Test getting max input tokens for a known model.""" + # This test may need adjustment based on what models are in the limits file + max_tokens = get_max_input_tokens("claude-3-opus-20240229") + # Should 
return a value if model is in limits + if max_tokens is not None: + assert max_tokens > 0 + + def test_get_max_input_tokens_unknown_model(self) -> None: + """Test getting max input tokens for an unknown model.""" + max_tokens = get_max_input_tokens("totally-unknown-model-xyz") + assert max_tokens is None + + +class TestTruncateToFit: + """Tests for context truncation.""" + + def test_truncate_no_truncation_needed(self) -> None: + """Test that small requests are not truncated.""" + request = { + "model": "claude-3-opus-20240229", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + } + result, was_truncated = truncate_to_fit( + request, max_input_tokens=200000, preserve_recent=10 + ) + assert was_truncated is False + assert result == request + + def test_truncate_removes_old_messages(self) -> None: + """Test that truncation removes oldest messages first.""" + # Create a request with many messages + messages = [] + for i in range(20): + messages.append({"role": "user", "content": f"Message {i} " * 100}) + messages.append({"role": "assistant", "content": f"Response {i} " * 100}) + + request = {"model": "claude-3-opus-20240229", "messages": messages} + + # Force truncation with a very low limit + result, was_truncated = truncate_to_fit( + request, max_input_tokens=1000, preserve_recent=4 + ) + + assert was_truncated is True + # Should have fewer messages + assert len(result["messages"]) < len(messages) + # Should preserve recent messages + assert len(result["messages"]) >= 4 + + def test_truncate_preserves_recent_messages(self) -> None: + """Test that recent messages are preserved during truncation.""" + messages = [ + {"role": "user", "content": "Old message " * 500}, + {"role": "assistant", "content": "Old response " * 500}, + {"role": "user", "content": "Recent message 1"}, + {"role": "assistant", "content": "Recent response 1"}, + {"role": "user", "content": "Recent message 2"}, + {"role": "assistant", "content": "Recent response 2"}, + ] + request = {"model": "claude-3-opus-20240229", "messages": messages} + + result, was_truncated = truncate_to_fit( + request, max_input_tokens=500, preserve_recent=4 + ) + + if was_truncated: + # Check that the last 4 messages are preserved + result_messages = result["messages"] + # Account for truncation notice being added + recent_messages = ( + result_messages[-4:] if len(result_messages) >= 4 else result_messages + ) + + # Verify recent content is in the preserved messages + recent_content = [m.get("content", "") for m in recent_messages] + assert any("Recent" in str(c) for c in recent_content) + + def test_truncate_adds_notice(self) -> None: + """Test that truncation adds a notice message.""" + messages = [] + for i in range(10): + messages.append({"role": "user", "content": f"Message {i} " * 200}) + messages.append({"role": "assistant", "content": f"Response {i} " * 200}) + + request = {"model": "claude-3-opus-20240229", "messages": messages} + + result, was_truncated = truncate_to_fit( + request, max_input_tokens=500, preserve_recent=2 + ) + + if was_truncated: + # First message should be the truncation notice + first_msg = result["messages"][0] + assert first_msg["role"] == "user" + assert "truncated" in first_msg["content"].lower() + + def test_truncate_not_enough_messages(self) -> None: + """Test truncation behavior when only one message exceeds the limit.""" + messages = [ + {"role": "user", "content": "Single message " * 1000}, + ] + request = {"model": "claude-3-opus-20240229", 
"messages": messages} + + # Try to truncate with preserve_recent > message count + result, was_truncated = truncate_to_fit( + request, max_input_tokens=100, preserve_recent=10 + ) + + # Should truncate and insert a notice if content can't fit + assert was_truncated is True + assert result["messages"][0]["role"] == "user" + assert "truncated" in result["messages"][0]["content"].lower() + + def test_truncate_preserves_system_and_tools(self) -> None: + """Test that system prompt and tools are preserved.""" + request = { + "model": "claude-3-opus-20240229", + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Old message " * 500}, + {"role": "assistant", "content": "Old response " * 500}, + {"role": "user", "content": "Recent message"}, + ], + "tools": [{"name": "test_tool", "description": "A test tool"}], + } + + result, was_truncated = truncate_to_fit( + request, max_input_tokens=500, preserve_recent=1 + ) + + # System and tools should be preserved regardless of truncation + assert result.get("system") == "You are a helpful assistant." + assert result.get("tools") == request["tools"] + + def test_truncate_safety_margin(self) -> None: + """Test that safety margin is applied correctly.""" + messages = [] + for i in range(5): + messages.append({"role": "user", "content": f"Message {i}"}) + + request = {"model": "claude-3-opus-20240229", "messages": messages} + + # With safety_margin=0.5, effective limit is 50000 + result, was_truncated = truncate_to_fit( + request, max_input_tokens=100000, preserve_recent=2, safety_margin=0.5 + ) + + # Should not truncate since content is small + assert was_truncated is False diff --git a/tests/unit/llms/test_truncation_creates_orphans.py b/tests/unit/llms/test_truncation_creates_orphans.py new file mode 100644 index 00000000..c7af4fd0 --- /dev/null +++ b/tests/unit/llms/test_truncation_creates_orphans.py @@ -0,0 +1,781 @@ +"""Tests reproducing the bug where truncation creates orphaned tool blocks. + +This module demonstrates the issue where: +1. Sanitization runs FIRST (removes existing orphaned tool blocks) +2. Truncation runs SECOND (can CREATE new orphaned blocks by removing messages) +3. The orphaned blocks hit the Anthropic API and cause errors like: + - "unexpected tool_use_id found in tool_result blocks" + - "tool_use ids were found without tool_result blocks" + +The fix should ensure sanitization runs AFTER truncation, or both before AND after. 
+""" + +from typing import Any + +import pytest + +from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results +from ccproxy.llms.utils import truncate_to_fit + + +Message = dict[str, Any] + + +class TestTruncationCreatesOrphanedBlocks: + """Tests demonstrating the bug where truncation creates orphaned tool blocks.""" + + def _create_tool_use_assistant_message( + self, + tool_id: str, + tool_name: str = "read_file", + input_data: dict[str, Any] | None = None, + ) -> Message: + """Create an assistant message with a tool_use block.""" + return { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": tool_id, + "name": tool_name, + "input": input_data or {"path": "/some/file.txt"}, + } + ], + } + + def _create_tool_result_user_message( + self, tool_use_id: str, result_content: str = "File contents here" + ) -> Message: + """Create a user message with a tool_result block.""" + return { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": result_content, + } + ], + } + + def _create_conversation_with_tool_calls(self, num_pairs: int) -> list[Message]: + """Create a conversation with multiple tool_use/tool_result pairs. + + Each pair consists of: + 1. Assistant message with tool_use block + 2. User message with tool_result block + """ + messages: list[Message] = [] + # Start with initial user message + messages.append( + {"role": "user", "content": "Please help me analyze some files."} + ) + + for i in range(num_pairs): + tool_id = f"tool_{i:03d}" + # Assistant calls a tool + messages.append( + self._create_tool_use_assistant_message( + tool_id=tool_id, + tool_name="read_file", + input_data={"path": f"/file_{i}.txt"}, + ) + ) + # User provides tool result with substantial content to use up tokens + messages.append( + self._create_tool_result_user_message( + tool_use_id=tool_id, + result_content=f"This is the content of file {i}. 
" + * 50, # Make it substantial + ) + ) + + # End with a final user message + messages.append({"role": "user", "content": "Now summarize all the files."}) + return messages + + def _count_orphaned_tool_results(self, messages: list[Message]) -> int: + """Count tool_result blocks that don't have a matching tool_use in preceding message.""" + orphan_count = 0 + + for i, msg in enumerate(messages): + if msg.get("role") != "user": + continue + content = msg.get("content") + if not isinstance(content, list): + continue + + # Get valid tool_use IDs from preceding assistant message + valid_ids = set() + if i > 0 and messages[i - 1].get("role") == "assistant": + prev_content = messages[i - 1].get("content", []) + if isinstance(prev_content, list): + for block in prev_content: + if isinstance(block, dict) and block.get("type") == "tool_use": + valid_ids.add(block.get("id")) + + # Check for orphaned tool_results + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + if block.get("tool_use_id") not in valid_ids: + orphan_count += 1 + + return orphan_count + + def _count_orphaned_tool_uses(self, messages: list[Message]) -> int: + """Count tool_use blocks that don't have a matching tool_result in any subsequent message.""" + # Collect all tool_result IDs + result_ids = set() + for msg in messages: + if msg.get("role") == "user" and isinstance(msg.get("content"), list): + for block in msg["content"]: + if isinstance(block, dict) and block.get("type") == "tool_result": + result_ids.add(block.get("tool_use_id")) + + # Count tool_use blocks without results + orphan_count = 0 + for msg in messages: + if msg.get("role") == "assistant" and isinstance(msg.get("content"), list): + for block in msg["content"]: + if isinstance(block, dict) and block.get("type") == "tool_use": + if block.get("id") not in result_ids: + orphan_count += 1 + + return orphan_count + + def test_current_order_creates_orphans(self) -> None: + """REPRODUCTION TEST: Current order (sanitize → truncate) creates orphaned blocks. + + This test demonstrates the bug: + 1. Create a conversation with many tool_use/tool_result pairs + 2. Apply sanitization first (no orphans to remove initially) + 3. Apply truncation with low token limit (removes old messages) + 4. VERIFY: Orphaned tool_result blocks now exist (BUG!) 
+ + The Anthropic API would reject this with: + "unexpected tool_use_id found in tool_result blocks" + """ + # Create conversation with 10 tool call pairs + messages = self._create_conversation_with_tool_calls(num_pairs=10) + + # Verify no orphans initially + initial_orphan_results = self._count_orphaned_tool_results(messages) + initial_orphan_uses = self._count_orphaned_tool_uses(messages) + assert initial_orphan_results == 0, "Should have no orphaned results initially" + assert initial_orphan_uses == 0, "Should have no orphaned uses initially" + + # STEP 1: Sanitization (current order - runs first) + sanitized_messages = _sanitize_tool_results(messages) + + # No orphans should be removed since there are none + assert self._count_orphaned_tool_results(sanitized_messages) == 0 + + # STEP 2: Truncation (current order - runs second) + request = { + "model": "claude-3-opus-20240229", + "messages": sanitized_messages, + } + + # Use a very low token limit to force significant truncation + truncated_request, was_truncated = truncate_to_fit( + request, + max_input_tokens=2000, # Very low to force truncation + preserve_recent=4, # Keep last 4 messages + safety_margin=0.9, + ) + + assert was_truncated, "Should have been truncated" + truncated_messages = truncated_request["messages"] + + # The bug: truncation removed assistant messages with tool_use blocks, + # but the corresponding tool_result blocks in user messages remain + orphan_count = self._count_orphaned_tool_results(truncated_messages) + + # This assertion documents the bug - we EXPECT orphans due to the wrong order + # When the fix is applied, this test should be updated + print( + f"\n[BUG REPRODUCTION] Found {orphan_count} orphaned tool_result blocks after truncation" + ) + print(f"Messages before truncation: {len(sanitized_messages)}") + print(f"Messages after truncation: {len(truncated_messages)}") + + # This is the problematic state - orphans exist after our current processing + # The Anthropic API would reject this request + if orphan_count > 0: + print( + "[CONFIRMED] Bug reproduced: truncation created orphaned tool_result blocks" + ) + # Mark as expected failure state + pytest.xfail( + f"BUG: Current order (sanitize → truncate) creates {orphan_count} orphaned " + "tool_result blocks. The Anthropic API would reject this with " + "'unexpected tool_use_id found in tool_result blocks'" + ) + + def test_correct_order_no_orphans(self) -> None: + """EXPECTED BEHAVIOR: Correct order (truncate → sanitize) leaves no orphans. + + This test shows what should happen: + 1. Create a conversation with many tool_use/tool_result pairs + 2. Apply truncation first (removes old messages, creates orphans) + 3. Apply sanitization second (removes the orphans) + 4. VERIFY: No orphaned blocks remain (CORRECT!) 
+ """ + # Create conversation with 10 tool call pairs + messages = self._create_conversation_with_tool_calls(num_pairs=10) + + # Verify no orphans initially + assert self._count_orphaned_tool_results(messages) == 0 + assert self._count_orphaned_tool_uses(messages) == 0 + + # STEP 1: Truncation FIRST (correct order) + request = { + "model": "claude-3-opus-20240229", + "messages": messages, + } + + truncated_request, was_truncated = truncate_to_fit( + request, + max_input_tokens=2000, + preserve_recent=4, + safety_margin=0.9, + ) + + assert was_truncated, "Should have been truncated" + truncated_messages = truncated_request["messages"] + + # After truncation, we might have orphans + orphans_after_truncation = self._count_orphaned_tool_results(truncated_messages) + print(f"\nOrphaned tool_results after truncation: {orphans_after_truncation}") + + # STEP 2: Sanitization SECOND (correct order) + final_messages = _sanitize_tool_results(truncated_messages) + + # After sanitization, no orphans should remain + final_orphan_count = self._count_orphaned_tool_results(final_messages) + + print(f"Orphaned tool_results after sanitization: {final_orphan_count}") + + assert final_orphan_count == 0, ( + f"After correct order (truncate → sanitize), should have 0 orphans " + f"but found {final_orphan_count}" + ) + + def test_superdesign_scenario_simulation(self) -> None: + """Simulate the SuperDesign scenario with massive tool calls. + + SuperDesign (VS Code extension using @ai-sdk/anthropic) sends native + Anthropic format requests. When a conversation has many tool calls, + the context exceeds Claude's limit. The client (or proxy) compacts + the conversation, which can leave orphaned tool blocks. + + This test simulates that scenario. + """ + # SuperDesign typically has MANY tool calls in a single session + # Simulate a session with 30 tool call pairs (like reading/writing many files) + messages: list[Message] = [] + messages.append( + { + "role": "user", + "content": "Implement a new feature across multiple files.", + } + ) + + # Simulate 30 tool calls (read/write operations) + for i in range(30): + tool_id = f"toolu_{i:05d}" + # Large tool input to simulate real file operations + messages.append( + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": tool_id, + "name": "Write" if i % 2 else "Read", + "input": { + "file_path": f"/project/src/component_{i}.tsx", + "content": "Large file content... " * 100 + if i % 2 + else None, + }, + } + ], + } + ) + # Tool result with file contents + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_id, + "content": f"{'File written successfully' if i % 2 else 'import React...' * 50}", + } + ], + } + ) + + # Final user message + messages.append({"role": "user", "content": "Great, now run the tests."}) + + print("\n[SuperDesign Simulation]") + print(f"Total messages: {len(messages)}") + print("Tool call pairs: 30") + + # Current (buggy) flow: sanitize → truncate + request = {"model": "claude-sonnet-4-20250514", "messages": messages} + + # Step 1: Sanitize first (no orphans yet) + sanitized = _sanitize_tool_results(messages) + + # Step 2: Truncate (creates orphans!) 
+ truncated_request, was_truncated = truncate_to_fit( + {"model": request["model"], "messages": sanitized}, + max_input_tokens=50000, # Realistic limit + preserve_recent=10, + safety_margin=0.9, + ) + + if was_truncated: + truncated_messages = truncated_request["messages"] + orphan_results = self._count_orphaned_tool_results(truncated_messages) + orphan_uses = self._count_orphaned_tool_uses(truncated_messages) + + print(f"Messages after truncation: {len(truncated_messages)}") + print(f"Orphaned tool_results: {orphan_results}") + print(f"Orphaned tool_uses: {orphan_uses}") + + if orphan_results > 0 or orphan_uses > 0: + print( + "[CONFIRMED] SuperDesign scenario reproduced - would fail at Anthropic API" + ) + pytest.xfail( + f"BUG: SuperDesign scenario creates {orphan_results} orphaned tool_results " + f"and {orphan_uses} orphaned tool_uses after truncation. " + "Anthropic API would reject this request." + ) + else: + print("Truncation not needed for this token limit") + + def test_full_pipeline_current_vs_fixed(self) -> None: + """Compare current (buggy) pipeline vs fixed pipeline. + + This test explicitly shows both pipelines side by side. + """ + messages = self._create_conversation_with_tool_calls(num_pairs=15) + + # === CURRENT (BUGGY) PIPELINE === + # This is what the adapters currently do + current_messages = list(messages) # Copy + + # Step 1: Sanitize + current_messages = _sanitize_tool_results(current_messages) + + # Step 2: Truncate + current_request, _ = truncate_to_fit( + {"model": "claude-3-opus-20240229", "messages": current_messages}, + max_input_tokens=3000, + preserve_recent=4, + safety_margin=0.9, + ) + current_final = current_request["messages"] + current_orphans = self._count_orphaned_tool_results(current_final) + + # === FIXED PIPELINE === + # This is what should happen + fixed_messages = list(messages) # Copy + + # Step 1: Truncate FIRST + fixed_request, _ = truncate_to_fit( + {"model": "claude-3-opus-20240229", "messages": fixed_messages}, + max_input_tokens=3000, + preserve_recent=4, + safety_margin=0.9, + ) + fixed_messages = fixed_request["messages"] + + # Step 2: Sanitize SECOND + fixed_final = _sanitize_tool_results(fixed_messages) + fixed_orphans = self._count_orphaned_tool_results(fixed_final) + + print("\n[Pipeline Comparison]") + print(f"Current pipeline (sanitize→truncate): {current_orphans} orphans") + print(f"Fixed pipeline (truncate→sanitize): {fixed_orphans} orphans") + + # Current pipeline has orphans (bug) + # Fixed pipeline has no orphans (correct) + assert fixed_orphans == 0, "Fixed pipeline should have no orphans" + + if current_orphans > 0: + print( + f"[BUG CONFIRMED] Current pipeline leaves {current_orphans} orphaned blocks" + ) + pytest.xfail( + f"Current pipeline creates {current_orphans} orphans while " + f"fixed pipeline creates {fixed_orphans}" + ) + + def test_preserve_recent_splits_tool_pair(self) -> None: + """CRITICAL: Test when preserve_recent boundary splits a tool_use/tool_result pair. + + The bug occurs when: + 1. preserve_recent is set to a value that splits a tool_use (assistant) from its tool_result (user) + 2. The assistant message with tool_use goes into truncatable and gets removed + 3. The user message with tool_result stays in preserved + 4. Result: orphaned tool_result with no matching tool_use + + This is the exact scenario causing "unexpected tool_use_id" errors. 
+ """ + # DESIGN: Make tool_use the ONLY message in truncatable + # So ANY truncation will remove it, leaving the tool_result orphaned + # + # Structure with preserve_recent=2: + # [0] assistant: tool_use (LARGE to force truncation) <- In truncatable, REMOVED + # [1] user: tool_result <- In preserved (ORPHAN!) + # [2] user: final message <- In preserved + + messages: list[Message] = [ + # Tool use that will be the ONLY item in truncatable - WILL BE REMOVED + # Make it large enough to force truncation + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "orphan_tool_123", + "name": "analyze_data", + "input": { + "dataset": "x" * 2000 + }, # Large input to exceed token limit + } + ], + }, + # Tool result in preserved - will become orphan when tool_use is removed + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "orphan_tool_123", + "content": "Analysis complete: found 42 items", + } + ], + }, + # Final message in preserved + {"role": "user", "content": "Thanks, now summarize the results."}, + ] + + print("\n[Targeted Split Test - Verify Fix]") + print(f"Total messages: {len(messages)}") + print("Using preserve_recent=2 to keep only last 2 messages") + print("truncatable will contain ONLY the tool_use message") + + # FIXED ORDER: Truncate FIRST, then Sanitize + # This is the correct order that prevents orphaned tool_result blocks + + # Step 1: Truncate with preserve_recent=2 + # - preserved = [msg[1], msg[2]] = [tool_result, final message] + # - truncatable = [msg[0]] = [tool_use ONLY] + # Truncation will remove the tool_use, potentially creating an orphan + request = {"model": "claude-3-opus-20240229", "messages": messages} + truncated_request, was_truncated = truncate_to_fit( + request, + max_input_tokens=200, # Low enough to force removal of the large tool_use + preserve_recent=2, # Keep last 2: tool_result + final message + safety_margin=0.9, + ) + + # Step 2: Sanitize AFTER truncation (the fix) + # This will clean up any orphaned tool_result blocks created by truncation + if was_truncated: + truncated_request["messages"] = _sanitize_tool_results( + truncated_request["messages"] + ) + + if was_truncated: + final_messages = truncated_request["messages"] + print(f"Messages after truncation: {len(final_messages)}") + + # Check what we have + for i, msg in enumerate(final_messages): + role = msg.get("role") + content = msg.get("content") + if isinstance(content, list): + types = [b.get("type") for b in content if isinstance(b, dict)] + ids = [ + b.get("id") or b.get("tool_use_id") + for b in content + if isinstance(b, dict) + ] + print(f" [{i}] {role}: {types} {ids}") + else: + content_preview = str(content)[:60] if content else "(empty)" + print(f" [{i}] {role}: {content_preview}...") + + orphan_count = self._count_orphaned_tool_results(final_messages) + print(f"Orphaned tool_results: {orphan_count}") + + # Check if we have tool_result without tool_use + has_tool_result = any( + isinstance(msg.get("content"), list) + and any( + b.get("type") == "tool_result" + for b in msg["content"] + if isinstance(b, dict) + ) + for msg in final_messages + ) + has_tool_use = any( + isinstance(msg.get("content"), list) + and any( + b.get("type") == "tool_use" + for b in msg["content"] + if isinstance(b, dict) + ) + for msg in final_messages + ) + + print(f"Has tool_result: {has_tool_result}, Has tool_use: {has_tool_use}") + + # After fix: truncation followed by sanitization should leave no orphans + # This test now verifies the CORRECT behavior + if 
orphan_count > 0: + # If orphans exist, the fix is not working properly + pytest.fail( + f"REGRESSION: truncation created {orphan_count} orphaned tool_result(s). " + "Sanitization should run AFTER truncation to clean up orphans." + ) + + # Verify that we have tool_result but no tool_use (expected after truncation) + # This is fine as long as there are no ORPHANED tool_results + if has_tool_result and not has_tool_use: + print( + "[CORRECT] Tool_use was truncated, tool_result was kept but sanitized" + ) + print( + "This is the expected behavior when sanitization runs AFTER truncation" + ) + + print( + "[FIX VERIFIED] No orphaned tool_result blocks after truncation + sanitization" + ) + else: + print("No truncation occurred - need lower token limit") + + def test_token_removal_splits_tool_pair(self) -> None: + """Test when token-based removal stops in the middle of a tool pair. + + Even if preserve_recent doesn't split a pair, the token-based removal might: + 1. Start removing oldest messages + 2. Remove the assistant message with tool_use + 3. Stop (under token limit) before removing the user message with tool_result + 4. Result: orphaned tool_result + """ + # Create messages where removal is likely to stop mid-pair + messages: list[Message] = [ + {"role": "user", "content": "Start"}, # Small - will be removed + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_X", + "name": "analyze", + "input": {"data": "x" * 500}, + } + ], + }, # Medium - might be removed + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tool_X", + "content": "Result X", + } + ], + }, # Small - might NOT be removed (leaving orphan) + {"role": "user", "content": "Middle message"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_Y", + "name": "process", + "input": {"data": "y"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tool_Y", + "content": "Result Y", + } + ], + }, + {"role": "user", "content": "Final message with important context"}, + ] + + print("\n[Token Removal Split Test]") + + # Current (buggy) pipeline + sanitized = _sanitize_tool_results(messages) + + # Use preserve_recent that includes all tool_Y messages but not tool_X + request = {"model": "claude-3-opus-20240229", "messages": sanitized} + truncated_request, was_truncated = truncate_to_fit( + request, + max_input_tokens=800, + preserve_recent=4, # Keep last 4 messages (includes tool_Y pair + final) + safety_margin=0.9, + ) + + if was_truncated: + final_messages = truncated_request["messages"] + orphan_count = self._count_orphaned_tool_results(final_messages) + + print(f"Messages after truncation: {len(final_messages)}") + print(f"Orphaned tool_results: {orphan_count}") + + if orphan_count > 0: + print("[BUG CONFIRMED] Token removal split a tool pair") + pytest.xfail(f"Token removal created {orphan_count} orphan(s)") + + def test_few_messages_massive_content(self) -> None: + """Test truncation when there are fewer messages than preserve_recent but massive content. + + This reproduces the real-world issue where: + - Only 9 messages exist + - preserve_recent=10 (more than message count) + - Total tokens ~500k (way over 200k limit) + - OLD BEHAVIOR: truncate_to_fit gave up and returned unchanged + - NEW BEHAVIOR: reduce preserve_recent dynamically & truncate content blocks + + This was causing "prompt is too long: 200661 tokens > 200000 maximum" errors. 
+ """ + # Create a few messages with MASSIVE content (simulating large tool results) + messages: list[Message] = [ + {"role": "user", "content": "Read the README file"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_readme", + "name": "Read", + "input": {"file_path": "/project/README.md"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tool_readme", + # Massive content - simulate 100k+ chars (like a big file) + "content": "README content line\n" * 20000, # ~400k chars + } + ], + }, + {"role": "user", "content": "Now read the package.json"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_package", + "name": "Read", + "input": {"file_path": "/project/package.json"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tool_package", + # Another massive tool result + "content": "package.json content line\n" * 15000, # ~375k chars + } + ], + }, + {"role": "user", "content": "Summarize both files for me."}, + ] + + print("\n[Few Messages, Massive Content Test]") + print(f"Total messages: {len(messages)}") + print("preserve_recent: 10 (more than message count)") + + # OLD behavior would give up here since len(messages)=7 < preserve_recent=10 + request = {"model": "claude-3-opus-20240229", "messages": messages} + truncated_request, was_truncated = truncate_to_fit( + request, + max_input_tokens=50000, # Target ~45k tokens + preserve_recent=10, # More than messages count (7) + safety_margin=0.9, + ) + + # NEW behavior should: + # 1. Reduce preserve_recent dynamically + # 2. Truncate large content blocks as fallback + + print(f"Was truncated: {was_truncated}") + + if was_truncated: + final_messages = truncated_request["messages"] + print(f"Messages after truncation: {len(final_messages)}") + + # Check total content size after truncation + total_chars = 0 + for msg in final_messages: + content = msg.get("content") + if isinstance(content, str): + total_chars += len(content) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict): + if block.get("type") == "tool_result": + total_chars += len(str(block.get("content", ""))) + elif block.get("type") == "text": + total_chars += len(str(block.get("text", ""))) + + print(f"Total content chars after truncation: {total_chars}") + + # Verify content was actually reduced + # Original was ~775k chars, should be significantly less now + assert total_chars < 100000, ( + f"Content should be significantly reduced (was {total_chars} chars)" + ) + + # Also verify no orphans were created + orphan_count = self._count_orphaned_tool_results(final_messages) + print(f"Orphaned tool_results: {orphan_count}") + + # Run sanitization as would happen in real pipeline + from ccproxy.llms.formatters.openai_to_anthropic.requests import ( + _sanitize_tool_results, + ) + + sanitized = _sanitize_tool_results(final_messages) + final_orphan_count = self._count_orphaned_tool_results(sanitized) + + assert final_orphan_count == 0, ( + f"After sanitization should have no orphans but found {final_orphan_count}" + ) + + print("[SUCCESS] Truncation handled few messages with massive content") + else: + # If not truncated, verify we're under the limit + from ccproxy.llms.utils.token_estimation import estimate_request_tokens + + tokens = estimate_request_tokens(request) + print(f"Request tokens: {tokens}") + print("[INFO] No truncation needed - content under limit") diff --git 
a/tests/unit/streaming/test_deferred_stream_errors.py b/tests/unit/streaming/test_deferred_stream_errors.py
new file mode 100644
index 00000000..3025c02f
--- /dev/null
+++ b/tests/unit/streaming/test_deferred_stream_errors.py
@@ -0,0 +1,66 @@
+"""Tests for streaming error payload normalization."""
+
+from __future__ import annotations
+
+import httpx
+import pytest
+
+from ccproxy.core.constants import FORMAT_ANTHROPIC_MESSAGES, FORMAT_OPENAI_CHAT
+from ccproxy.core.logging import get_logger
+from ccproxy.core.request_context import RequestContext
+from ccproxy.streaming.deferred import DeferredStreaming
+
+
+@pytest.mark.anyio
+async def test_anthropic_error_wrapped_with_type() -> None:
+    ctx = RequestContext(
+        request_id="req_1",
+        start_time=0.0,
+        logger=get_logger(__name__),
+    )
+    ctx.format_chain = [FORMAT_ANTHROPIC_MESSAGES]
+
+    client = httpx.AsyncClient()
+    stream = DeferredStreaming(
+        method="POST",
+        url="http://example.test/v1/messages",
+        headers={},
+        body=b"{}",
+        client=client,
+        request_context=ctx,
+    )
+
+    error_obj = {"error": {"type": "timeout_error", "message": "Request timeout"}}
+    formatted = stream._format_stream_error(error_obj)
+
+    assert formatted["type"] == "error"
+    assert formatted["error"]["type"] == "timeout_error"
+
+    await client.aclose()
+
+
+@pytest.mark.anyio
+async def test_non_anthropic_error_left_unchanged() -> None:
+    ctx = RequestContext(
+        request_id="req_2",
+        start_time=0.0,
+        logger=get_logger(__name__),
+    )
+    ctx.format_chain = [FORMAT_OPENAI_CHAT]
+
+    client = httpx.AsyncClient()
+    stream = DeferredStreaming(
+        method="POST",
+        url="http://example.test/v1/chat/completions",
+        headers={},
+        body=b"{}",
+        client=client,
+        request_context=ctx,
+    )
+
+    error_obj = {"error": {"type": "timeout_error", "message": "Request timeout"}}
+    formatted = stream._format_stream_error(error_obj)
+
+    assert formatted == error_obj
+
+    await client.aclose()
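For context, a minimal sketch of the wrapping behaviour these two streaming tests pin down. It is an assumption about the shape of the logic, not the actual DeferredStreaming._format_stream_error implementation, and the constant value below is a placeholder (the real one lives in ccproxy.core.constants):

from typing import Any

FORMAT_ANTHROPIC_MESSAGES = "anthropic_messages"  # placeholder value for illustration


def format_stream_error(error_obj: dict[str, Any], format_chain: list[str]) -> dict[str, Any]:
    """Illustrative sketch: wrap errors for Anthropic-format streams, pass others through."""
    if FORMAT_ANTHROPIC_MESSAGES in format_chain and "error" in error_obj:
        # Anthropic SSE error events carry a top-level {"type": "error", "error": {...}} shape
        return {"type": "error", "error": error_obj["error"]}
    return error_obj

# format_stream_error({"error": {"type": "timeout_error", "message": "Request timeout"}},
#                     [FORMAT_ANTHROPIC_MESSAGES])
# -> {"type": "error", "error": {"type": "timeout_error", "message": "Request timeout"}}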