|
1 | 1 | import { Anthropic } from "@anthropic-ai/sdk" |
| 2 | +import { ModelInfo } from "../../shared/api" |
2 | 3 |
|
3 | | -/* |
4 | | -We can't implement a dynamically updating sliding window as it would break prompt cache |
5 | | -every time. To maintain the benefits of caching, we need to keep conversation history |
6 | | -static. This operation should be performed as infrequently as possible. If a user reaches |
7 | | -a 200k context, we can assume that the first half is likely irrelevant to their current task. |
8 | | -Therefore, this function should only be called when absolutely necessary to fit within |
9 | | -context limits, not as a continuous process. |
10 | | -*/ |
11 | | -export function truncateHalfConversation( |
| 4 | +/** |
| 5 | + * Truncates a conversation by removing a fraction of the messages. |
| 6 | + * |
| 7 | + * The first message is always retained, and a specified fraction (rounded down to an even number)
| 8 | + * of messages from the beginning (excluding the first) is removed. |
| 9 | + * |
| 10 | + * @param {Anthropic.Messages.MessageParam[]} messages - The conversation messages. |
| 11 | + * @param {number} fracToRemove - The fraction (between 0 and 1) of messages (excluding the first) to remove. |
| 12 | + * @returns {Anthropic.Messages.MessageParam[]} The truncated conversation messages. |
| 13 | + */ |
| 14 | +export function truncateConversation( |
12 | 15 | messages: Anthropic.Messages.MessageParam[], |
| 16 | + fracToRemove: number, |
13 | 17 | ): Anthropic.Messages.MessageParam[] { |
14 | | - // API expects messages to be in user-assistant order, and tool use messages must be followed by tool results. We need to maintain this structure while truncating. |
15 | | - |
16 | | - // Always keep the first Task message (this includes the project's file structure in environment_details) |
17 | 18 | const truncatedMessages = [messages[0]] |
18 | | - |
19 | | - // Remove half of user-assistant pairs |
20 | | - const messagesToRemove = Math.floor(messages.length / 4) * 2 // has to be even number |
21 | | - |
22 | | - const remainingMessages = messages.slice(messagesToRemove + 1) // has to start with assistant message since tool result cannot follow assistant message with no tool use |
| 19 | + const rawMessagesToRemove = Math.floor((messages.length - 1) * fracToRemove) |
| 20 | + const messagesToRemove = rawMessagesToRemove - (rawMessagesToRemove % 2) |
| 21 | + const remainingMessages = messages.slice(messagesToRemove + 1) |
23 | 22 | truncatedMessages.push(...remainingMessages) |
24 | 23 |
|
25 | 24 | return truncatedMessages |
26 | 25 | } |
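To sanity-check the rounding behavior, here is a small usage sketch (not part of the diff; the sample conversation is made up). With seven messages and fracToRemove = 0.5, floor((7 - 1) * 0.5) = 3 is rounded down to the even count 2, so the two oldest messages after the first are dropped:

	const conversation: Anthropic.Messages.MessageParam[] = [
		{ role: "user", content: "task" }, // always kept
		{ role: "assistant", content: "a1" },
		{ role: "user", content: "u1" },
		{ role: "assistant", content: "a2" },
		{ role: "user", content: "u2" },
		{ role: "assistant", content: "a3" },
		{ role: "user", content: "u3" },
	]

	const truncated = truncateConversation(conversation, 0.5)
	// truncated === [conversation[0], ...conversation.slice(3)] (5 messages),
	// and the remainder still starts with an assistant message, preserving the
	// user/assistant alternation the API expects.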
| 26 | + |
| 27 | +/** |
| 28 | + * Conditionally truncates the conversation messages if the total token count exceeds the model's limit. |
| 29 | + * |
| 30 | + * Depending on whether the model supports prompt caching, different maximum token thresholds |
| 31 | + * and truncation fractions are used. If the current total tokens meet or exceed the threshold,
| 32 | + * the conversation is truncated using the appropriate fraction. |
| 33 | + * |
| 34 | + * @param {Anthropic.Messages.MessageParam[]} messages - The conversation messages. |
| 35 | + * @param {number} totalTokens - The total number of tokens in the conversation. |
| 36 | + * @param {ModelInfo} modelInfo - Model metadata including context window size and prompt cache support. |
| 37 | + * @returns {Anthropic.Messages.MessageParam[]} The original or truncated conversation messages. |
| 38 | + */ |
| 39 | +export function truncateConversationIfNeeded( |
| 40 | + messages: Anthropic.Messages.MessageParam[], |
| 41 | + totalTokens: number, |
| 42 | + modelInfo: ModelInfo, |
| 43 | +): Anthropic.Messages.MessageParam[] { |
| 44 | + if (modelInfo.supportsPromptCache) { |
| 45 | + return totalTokens < getMaxTokensForPromptCachingModels(modelInfo) |
| 46 | + ? messages |
| 47 | + : truncateConversation(messages, getTruncFractionForPromptCachingModels(modelInfo)) |
| 48 | + } else { |
| 49 | + return totalTokens < getMaxTokensForNonPromptCachingModels(modelInfo) |
| 50 | + ? messages |
| 51 | + : truncateConversation(messages, getTruncFractionForNonPromptCachingModels(modelInfo)) |
| 52 | + } |
| 53 | +} |
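For example, with a prompt-caching model and a 200k context window the threshold works out to 160k tokens, so truncation only kicks in at or above that point (a sketch; the modelInfo literal is hypothetical and only fills in the fields read here):

	const modelInfo = { contextWindow: 200_000, supportsPromptCache: true } as ModelInfo

	// Threshold: max(200_000 - 40_000, 200_000 * 0.8) = 160_000 tokens.
	truncateConversationIfNeeded(messages, 150_000, modelInfo) // below the threshold: returned as-is
	truncateConversationIfNeeded(messages, 170_000, modelInfo) // at or above the threshold: truncated with fraction 0.5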
| 54 | + |
| 55 | +/** |
| 56 | + * Calculates the maximum allowed tokens for models that support prompt caching. |
| 57 | + * |
| 58 | + * The maximum is computed as the greater of (contextWindow - 40000) and 80% of the contextWindow. |
| 59 | + * |
| 60 | + * @param {ModelInfo} modelInfo - The model information containing the context window size. |
| 61 | + * @returns {number} The maximum number of tokens allowed for prompt caching models. |
| 62 | + */ |
| 63 | +function getMaxTokensForPromptCachingModels(modelInfo: ModelInfo): number { |
| 64 | + return Math.max(modelInfo.contextWindow - 40_000, modelInfo.contextWindow * 0.8) |
| 65 | +} |
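A few worked values for this threshold (illustrative window sizes only):

	// contextWindow = 200_000: max(160_000, 160_000) = 160_000
	// contextWindow = 128_000: max( 88_000, 102_400) = 102_400
	// contextWindow =  64_000: max( 24_000,  51_200) =  51_200 (the 80% floor dominates)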
| 66 | + |
| 67 | +/** |
| 68 | + * Provides the fraction of messages to remove for models that support prompt caching. |
| 69 | + * |
| 70 | + * @param {ModelInfo} modelInfo - The model information (unused in current implementation). |
| 71 | + * @returns {number} The truncation fraction for prompt caching models (fixed at 0.5). |
| 72 | + */ |
| 73 | +function getTruncFractionForPromptCachingModels(modelInfo: ModelInfo): number { |
| 74 | + return 0.5 |
| 75 | +} |
| 76 | + |
| 77 | +/** |
| 78 | + * Calculates the maximum allowed tokens for models that do not support prompt caching. |
| 79 | + * |
| 80 | + * The maximum is computed as the greater of (contextWindow - 40000) and 80% of the contextWindow. |
| 81 | + * |
| 82 | + * @param {ModelInfo} modelInfo - The model information containing the context window size. |
| 83 | + * @returns {number} The maximum number of tokens allowed for non-prompt caching models. |
| 84 | + */ |
| 85 | +function getMaxTokensForNonPromptCachingModels(modelInfo: ModelInfo): number { |
| 86 | + return Math.max(modelInfo.contextWindow - 40_000, modelInfo.contextWindow * 0.8) |
| 87 | +} |
| 88 | + |
| 89 | +/** |
| 90 | + * Provides the fraction of messages to remove for models that do not support prompt caching. |
| 91 | + * |
| 92 | + * @param {ModelInfo} modelInfo - The model information containing the context window size.
| 93 | + * @returns {number} The truncation fraction for non-prompt caching models, computed as min(40000 / contextWindow, 0.2).
| 94 | + */ |
| 95 | +function getTruncFractionForNonPromptCachingModels(modelInfo: ModelInfo): number { |
| 96 | + return Math.min(40_000 / modelInfo.contextWindow, 0.2) |
| 97 | +} |
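Worked values for this fraction (illustrative window sizes only): smaller context windows are capped at removing 20% of the non-first messages, while very large windows remove proportionally less:

	// contextWindow =   128_000: min(40_000 / 128_000, 0.2) = min(0.3125, 0.2) = 0.2
	// contextWindow = 1_000_000: min(40_000 / 1_000_000, 0.2) = 0.04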