Commit d8cafbc

Simplify the context truncation math
1 parent 82b282b commit d8cafbc
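
The gist of the simplification, paraphrased as a sketch from the diff of src/core/sliding-window/index.ts below (the names maxTokensBefore and maxTokensAfter are illustrative only, not part of the codebase):

// Before: helpers per model family, with a 40k-token floor on the buffer and an
// 80%-of-window lower bound on the allowed total.
const maxTokensBefore = (contextWindow: number, maxTokens?: number): number => {
	const buffer = maxTokens ? Math.max(40_000, maxTokens) : 40_000
	return Math.max(contextWindow - buffer, contextWindow * 0.8)
}

// After: a single helper; the buffer is the model's own maxTokens, or 20% of the
// context window when maxTokens is not set.
const maxTokensAfter = (contextWindow: number, maxTokens?: number): number =>
	contextWindow - (maxTokens || contextWindow * 0.2)

The truncation fraction, previously 0.5 for prompt-caching models and min(40000/contextWindow, 0.2) otherwise, becomes a fixed 0.5 for all models.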

File tree

2 files changed: 122 additions & 70 deletions

src/core/sliding-window/__tests__/sliding-window.test.ts

Lines changed: 115 additions & 14 deletions
@@ -5,6 +5,9 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { ModelInfo } from "../../../shared/api"
 import { truncateConversation, truncateConversationIfNeeded } from "../index"
 
+/**
+ * Tests for the truncateConversation function
+ */
 describe("truncateConversation", () => {
 	it("should retain the first message", () => {
 		const messages: Anthropic.Messages.MessageParam[] = [
@@ -91,6 +94,86 @@ describe("truncateConversation", () => {
 	})
 })
 
+/**
+ * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded)
+ */
+describe("getMaxTokens", () => {
+	// We'll test this indirectly through truncateConversationIfNeeded
+	const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
+		contextWindow,
+		supportsPromptCache: true, // Not relevant for getMaxTokens
+		maxTokens,
+	})
+
+	// Reuse across tests for consistency
+	const messages: Anthropic.Messages.MessageParam[] = [
+		{ role: "user", content: "First message" },
+		{ role: "assistant", content: "Second message" },
+		{ role: "user", content: "Third message" },
+		{ role: "assistant", content: "Fourth message" },
+		{ role: "user", content: "Fifth message" },
+	]
+
+	it("should use maxTokens as buffer when specified", () => {
+		const modelInfo = createModelInfo(100000, 50000)
+		// Max tokens = 100000 - 50000 = 50000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 49999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 50001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should use 20% of context window as buffer when maxTokens is undefined", () => {
+		const modelInfo = createModelInfo(100000, undefined)
+		// Max tokens = 100000 - (100000 * 0.2) = 80000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 79999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 80001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle small context windows appropriately", () => {
+		const modelInfo = createModelInfo(50000, 10000)
+		// Max tokens = 50000 - 10000 = 40000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 39999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 40001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle large context windows appropriately", () => {
+		const modelInfo = createModelInfo(200000, 30000)
+		// Max tokens = 200000 - 30000 = 170000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 169999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 170001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+})
+
+/**
+ * Tests for the truncateConversationIfNeeded function
+ */
 describe("truncateConversationIfNeeded", () => {
 	const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
 		contextWindow,
@@ -106,25 +189,43 @@ describe("truncateConversationIfNeeded", () => {
 		{ role: "user", content: "Fifth message" },
 	]
 
-	it("should not truncate if tokens are below threshold for prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, true, 50000)
-		const totalTokens = 100000 // Below threshold
+	it("should not truncate if tokens are below max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 69999 // Below threshold
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(messages) // No truncation occurs
 	})
 
-	it("should not truncate if tokens are below threshold for non-prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, false)
-		const totalTokens = 100000 // Below threshold
+	it("should truncate if tokens are above max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 70001 // Above threshold
+
+		// When truncating, always uses 0.5 fraction
+		// With 4 messages after the first, 0.5 fraction means remove 2 messages
+		const expectedResult = [messages[0], messages[3], messages[4]]
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(expectedResult)
 	})
 
-	it("should use 80% of context window as threshold if it's greater than (contextWindow - buffer)", () => {
-		const modelInfo = createModelInfo(50000, true) // Small context window
-		const totalTokens = 40001 // Above 80% threshold (40000)
-		const mockResult = [messages[0], messages[3], messages[4]]
-		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(mockResult)
+	it("should work with non-prompt caching models the same as prompt caching models", () => {
+		// The implementation no longer differentiates between prompt caching and non-prompt caching models
+		const modelInfo1 = createModelInfo(100000, true, 30000)
+		const modelInfo2 = createModelInfo(100000, false, 30000)
+
+		// Test below threshold
+		const belowThreshold = 69999
+		expect(truncateConversationIfNeeded(messages, belowThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, belowThreshold, modelInfo2),
+		)
+
+		// Test above threshold
+		const aboveThreshold = 70001
+		expect(truncateConversationIfNeeded(messages, aboveThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, aboveThreshold, modelInfo2),
+		)
 	})
 })
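
The threshold arithmetic these tests encode can be checked by hand. A minimal standalone sketch in TypeScript, mirroring the simplified buffer logic exercised above (the helpers maxAllowedTokens and wouldTruncate are hypothetical names used only for illustration, not part of the codebase):

interface ModelInfoLike {
	contextWindow: number
	maxTokens?: number
}

// Threshold: context window minus the model's maxTokens, or minus 20% of the
// window when maxTokens is undefined (mirrors the diff in index.ts below).
function maxAllowedTokens(model: ModelInfoLike): number {
	return model.contextWindow - (model.maxTokens || model.contextWindow * 0.2)
}

// Truncation triggers once totalTokens reaches the threshold.
function wouldTruncate(totalTokens: number, model: ModelInfoLike): boolean {
	return totalTokens >= maxAllowedTokens(model)
}

// Checked against the fixtures above:
wouldTruncate(49999, { contextWindow: 100000, maxTokens: 50000 }) // false (threshold 50000)
wouldTruncate(50001, { contextWindow: 100000, maxTokens: 50000 }) // true
wouldTruncate(79999, { contextWindow: 100000 }) // false (threshold 80000)
wouldTruncate(170001, { contextWindow: 200000, maxTokens: 30000 }) // true (threshold 170000)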

src/core/sliding-window/index.ts

Lines changed: 7 additions & 56 deletions
@@ -28,75 +28,26 @@ export function truncateConversation(
 /**
  * Conditionally truncates the conversation messages if the total token count exceeds the model's limit.
  *
- * Depending on whether the model supports prompt caching, different maximum token thresholds
- * and truncation fractions are used. If the current total tokens exceed the threshold,
- * the conversation is truncated using the appropriate fraction.
- *
  * @param {Anthropic.Messages.MessageParam[]} messages - The conversation messages.
  * @param {number} totalTokens - The total number of tokens in the conversation.
- * @param {ModelInfo} modelInfo - Model metadata including context window size and prompt cache support.
+ * @param {ModelInfo} modelInfo - Model metadata including context window size.
  * @returns {Anthropic.Messages.MessageParam[]} The original or truncated conversation messages.
  */
 export function truncateConversationIfNeeded(
 	messages: Anthropic.Messages.MessageParam[],
 	totalTokens: number,
 	modelInfo: ModelInfo,
 ): Anthropic.Messages.MessageParam[] {
-	if (modelInfo.supportsPromptCache) {
-		return totalTokens < getMaxTokensForPromptCachingModels(modelInfo)
-			? messages
-			: truncateConversation(messages, getTruncFractionForPromptCachingModels(modelInfo))
-	} else {
-		return totalTokens < getMaxTokensForNonPromptCachingModels(modelInfo)
-			? messages
-			: truncateConversation(messages, getTruncFractionForNonPromptCachingModels(modelInfo))
-	}
+	return totalTokens < getMaxTokens(modelInfo) ? messages : truncateConversation(messages, 0.5)
 }
 
 /**
- * Calculates the maximum allowed tokens for models that support prompt caching.
- *
- * The maximum is computed as the greater of (contextWindow - buffer) and 80% of the contextWindow.
+ * Calculates the maximum allowed tokens
  *
  * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed for prompt caching models.
- */
-function getMaxTokensForPromptCachingModels(modelInfo: ModelInfo): number {
-	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
-	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
-	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
-}
-
-/**
- * Provides the fraction of messages to remove for models that support prompt caching.
- *
- * @param {ModelInfo} modelInfo - The model information (unused in current implementation).
- * @returns {number} The truncation fraction for prompt caching models (fixed at 0.5).
- */
-function getTruncFractionForPromptCachingModels(modelInfo: ModelInfo): number {
-	return 0.5
-}
-
-/**
- * Calculates the maximum allowed tokens for models that do not support prompt caching.
- *
- * The maximum is computed as the greater of (contextWindow - 40000) and 80% of the contextWindow.
- *
- * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed for non-prompt caching models.
- */
-function getMaxTokensForNonPromptCachingModels(modelInfo: ModelInfo): number {
-	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
-	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
-	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
-}
-
-/**
- * Provides the fraction of messages to remove for models that do not support prompt caching.
- *
- * @param {ModelInfo} modelInfo - The model information.
- * @returns {number} The truncation fraction for non-prompt caching models (fixed at 0.1).
+ * @returns {number} The maximum number of tokens allowed
  */
-function getTruncFractionForNonPromptCachingModels(modelInfo: ModelInfo): number {
-	return Math.min(40_000 / modelInfo.contextWindow, 0.2)
+function getMaxTokens(modelInfo: ModelInfo): number {
+	// The buffer needs to be at least as large as `modelInfo.maxTokens`, or 20% of the context window if for some reason it's not set.
+	return modelInfo.contextWindow - Math.max(modelInfo.maxTokens || modelInfo.contextWindow * 0.2)
 }
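
truncateConversation itself is untouched by this commit and not shown in the diff. For context, a minimal sketch consistent with the tests above: the first message is always retained and, at the fixed 0.5 fraction, half of the remaining messages are dropped. Rounding the removed count down and keeping it even (so user/assistant alternation is preserved) is an assumption here, not something the diff confirms.

import { Anthropic } from "@anthropic-ai/sdk"

// Sketch only - behavior inferred from the tests, not copied from index.ts.
function truncateConversationSketch(
	messages: Anthropic.Messages.MessageParam[],
	fracToRemove: number,
): Anthropic.Messages.MessageParam[] {
	const rawToRemove = Math.floor((messages.length - 1) * fracToRemove)
	const toRemove = rawToRemove - (rawToRemove % 2) // even count: assumption
	// Keep the first message, drop the next `toRemove` messages.
	return [messages[0], ...messages.slice(1 + toRemove)]
}

// With the 5-message fixture and the fixed 0.5 fraction:
// (5 - 1) * 0.5 = 2 messages removed -> [first, fourth, fifth], i.e. length 3,
// matching expect(result2.length).toBe(3) and [messages[0], messages[3], messages[4]].

With the simplified getMaxTokens above, truncateConversationIfNeeded reduces to a single threshold comparison followed by this call.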
