
Commit 46576e0

Merge pull request #1194 from RooVetGit/fix_context_window_truncation_math
Fix context window truncation math
2 parents 82b282b + 91ef9fb commit 46576e0

7 files changed: +186 -128 lines changed

src/api/providers/glama.ts

Lines changed: 2 additions & 2 deletions

@@ -69,7 +69,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
         let maxTokens: number | undefined

         if (this.getModel().id.startsWith("anthropic/")) {
-            maxTokens = 8_192
+            maxTokens = this.getModel().info.maxTokens
         }

         const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
@@ -177,7 +177,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
         }

         if (this.getModel().id.startsWith("anthropic/")) {
-            requestOptions.max_tokens = 8192
+            requestOptions.max_tokens = this.getModel().info.maxTokens
         }

         const response = await this.client.chat.completions.create(requestOptions)
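This is the provider-side half of the fix: instead of pinning Anthropic requests to a hardcoded 8_192 output tokens, the handler reads the limit from the model's metadata. A minimal sketch of the pattern, assuming a model selection shaped like `{ id, info }` as in the diff; the `buildRequest` helper and the metadata value are hypothetical, for illustration only:

```ts
// Hypothetical helper illustrating the pattern used in glama.ts and unbound.ts:
// the output-token cap comes from the model's metadata, not a constant.
interface ModelSelection {
	id: string
	info: { maxTokens?: number }
}

function buildRequest(model: ModelSelection): { max_tokens?: number } {
	const request: { max_tokens?: number } = {}
	if (model.id.startsWith("anthropic/")) {
		// Previously: request.max_tokens = 8_192
		request.max_tokens = model.info.maxTokens
	}
	return request
}

// Example (made-up metadata value): a Claude model advertising a 16k output limit.
console.log(buildRequest({ id: "anthropic/claude-3.7-sonnet", info: { maxTokens: 16_384 } })) // { max_tokens: 16384 }
```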

src/api/providers/openrouter.ts

Lines changed: 3 additions & 32 deletions

@@ -54,20 +54,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {

         // prompt caching: https://openrouter.ai/docs/prompt-caching
         // this is specifically for claude models (some models may 'support prompt caching' automatically without this)
-        switch (this.getModel().id) {
-            case "anthropic/claude-3.7-sonnet":
-            case "anthropic/claude-3.5-sonnet":
-            case "anthropic/claude-3.5-sonnet:beta":
-            case "anthropic/claude-3.5-sonnet-20240620":
-            case "anthropic/claude-3.5-sonnet-20240620:beta":
-            case "anthropic/claude-3-5-haiku":
-            case "anthropic/claude-3-5-haiku:beta":
-            case "anthropic/claude-3-5-haiku-20241022":
-            case "anthropic/claude-3-5-haiku-20241022:beta":
-            case "anthropic/claude-3-haiku":
-            case "anthropic/claude-3-haiku:beta":
-            case "anthropic/claude-3-opus":
-            case "anthropic/claude-3-opus:beta":
+        switch (true) {
+            case this.getModel().id.startsWith("anthropic/"):
                 openAiMessages[0] = {
                     role: "system",
                     content: [
@@ -103,23 +91,6 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
                 break
         }

-        // Not sure how openrouter defaults max tokens when no value is provided, but the anthropic api requires this value and since they offer both 4096 and 8192 variants, we should ensure 8192.
-        // (models usually default to max tokens allowed)
-        let maxTokens: number | undefined
-        switch (this.getModel().id) {
-            case "anthropic/claude-3.7-sonnet":
-            case "anthropic/claude-3.5-sonnet":
-            case "anthropic/claude-3.5-sonnet:beta":
-            case "anthropic/claude-3.5-sonnet-20240620":
-            case "anthropic/claude-3.5-sonnet-20240620:beta":
-            case "anthropic/claude-3-5-haiku":
-            case "anthropic/claude-3-5-haiku:beta":
-            case "anthropic/claude-3-5-haiku-20241022":
-            case "anthropic/claude-3-5-haiku-20241022:beta":
-                maxTokens = 8_192
-                break
-        }
-
         let defaultTemperature = OPENROUTER_DEFAULT_TEMPERATURE
         let topP: number | undefined = undefined

@@ -140,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
         let fullResponseText = ""
         const stream = await this.client.chat.completions.create({
             model: this.getModel().id,
-            max_tokens: maxTokens,
+            max_tokens: this.getModel().info.maxTokens,
             temperature: this.options.modelTemperature ?? defaultTemperature,
             top_p: topP,
             messages: openAiMessages,
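Collapsing the per-model `case` list into a single prefix check means new Anthropic model IDs get the prompt-caching treatment and the metadata-driven `max_tokens` without further edits to the handler. A rough sketch of the matching logic, under the same assumption the diff makes, namely that every Claude ID exposed through OpenRouter starts with `anthropic/`:

```ts
// Sketch of the prefix-based matching that replaces the exhaustive case list.
// Sample IDs are taken from the old switch statement.
const anthropicIds = [
	"anthropic/claude-3.7-sonnet",
	"anthropic/claude-3.5-sonnet:beta",
	"anthropic/claude-3-5-haiku-20241022",
]

function isAnthropic(modelId: string): boolean {
	return modelId.startsWith("anthropic/")
}

// Every ID from the old case list still matches...
console.log(anthropicIds.every(isAnthropic)) // true
// ...and so would a future ID that was never listed (hypothetical example).
console.log(isAnthropic("anthropic/claude-4-sonnet")) // true
```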

src/api/providers/unbound.ts

Lines changed: 2 additions & 2 deletions

@@ -71,7 +71,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
         let maxTokens: number | undefined

         if (this.getModel().id.startsWith("anthropic/")) {
-            maxTokens = 8_192
+            maxTokens = this.getModel().info.maxTokens
         }

         const { data: completion, response } = await this.client.chat.completions
@@ -150,7 +150,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
         }

         if (this.getModel().id.startsWith("anthropic/")) {
-            requestOptions.max_tokens = 8192
+            requestOptions.max_tokens = this.getModel().info.maxTokens
         }

         const response = await this.client.chat.completions.create(requestOptions)

src/core/sliding-window/__tests__/sliding-window.test.ts

Lines changed: 115 additions & 14 deletions

@@ -5,6 +5,9 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { ModelInfo } from "../../../shared/api"
 import { truncateConversation, truncateConversationIfNeeded } from "../index"

+/**
+ * Tests for the truncateConversation function
+ */
 describe("truncateConversation", () => {
     it("should retain the first message", () => {
         const messages: Anthropic.Messages.MessageParam[] = [
@@ -91,6 +94,86 @@ describe("truncateConversation", () => {
     })
 })

+/**
+ * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded)
+ */
+describe("getMaxTokens", () => {
+    // We'll test this indirectly through truncateConversationIfNeeded
+    const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
+        contextWindow,
+        supportsPromptCache: true, // Not relevant for getMaxTokens
+        maxTokens,
+    })
+
+    // Reuse across tests for consistency
+    const messages: Anthropic.Messages.MessageParam[] = [
+        { role: "user", content: "First message" },
+        { role: "assistant", content: "Second message" },
+        { role: "user", content: "Third message" },
+        { role: "assistant", content: "Fourth message" },
+        { role: "user", content: "Fifth message" },
+    ]
+
+    it("should use maxTokens as buffer when specified", () => {
+        const modelInfo = createModelInfo(100000, 50000)
+        // Max tokens = 100000 - 50000 = 50000
+
+        // Below max tokens - no truncation
+        const result1 = truncateConversationIfNeeded(messages, 49999, modelInfo)
+        expect(result1).toEqual(messages)
+
+        // Above max tokens - truncate
+        const result2 = truncateConversationIfNeeded(messages, 50001, modelInfo)
+        expect(result2).not.toEqual(messages)
+        expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+    })
+
+    it("should use 20% of context window as buffer when maxTokens is undefined", () => {
+        const modelInfo = createModelInfo(100000, undefined)
+        // Max tokens = 100000 - (100000 * 0.2) = 80000
+
+        // Below max tokens - no truncation
+        const result1 = truncateConversationIfNeeded(messages, 79999, modelInfo)
+        expect(result1).toEqual(messages)
+
+        // Above max tokens - truncate
+        const result2 = truncateConversationIfNeeded(messages, 80001, modelInfo)
+        expect(result2).not.toEqual(messages)
+        expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+    })
+
+    it("should handle small context windows appropriately", () => {
+        const modelInfo = createModelInfo(50000, 10000)
+        // Max tokens = 50000 - 10000 = 40000
+
+        // Below max tokens - no truncation
+        const result1 = truncateConversationIfNeeded(messages, 39999, modelInfo)
+        expect(result1).toEqual(messages)
+
+        // Above max tokens - truncate
+        const result2 = truncateConversationIfNeeded(messages, 40001, modelInfo)
+        expect(result2).not.toEqual(messages)
+        expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+    })
+
+    it("should handle large context windows appropriately", () => {
+        const modelInfo = createModelInfo(200000, 30000)
+        // Max tokens = 200000 - 30000 = 170000
+
+        // Below max tokens - no truncation
+        const result1 = truncateConversationIfNeeded(messages, 169999, modelInfo)
+        expect(result1).toEqual(messages)
+
+        // Above max tokens - truncate
+        const result2 = truncateConversationIfNeeded(messages, 170001, modelInfo)
+        expect(result2).not.toEqual(messages)
+        expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+    })
+})
+
+/**
+ * Tests for the truncateConversationIfNeeded function
+ */
 describe("truncateConversationIfNeeded", () => {
     const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
         contextWindow,
@@ -106,25 +189,43 @@ describe("truncateConversationIfNeeded", () => {
         { role: "user", content: "Fifth message" },
     ]

-    it("should not truncate if tokens are below threshold for prompt caching models", () => {
-        const modelInfo = createModelInfo(200000, true, 50000)
-        const totalTokens = 100000 // Below threshold
+    it("should not truncate if tokens are below max tokens threshold", () => {
+        const modelInfo = createModelInfo(100000, true, 30000)
+        const maxTokens = 100000 - 30000 // 70000
+        const totalTokens = 69999 // Below threshold
+
         const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-        expect(result).toEqual(messages)
+        expect(result).toEqual(messages) // No truncation occurs
     })

-    it("should not truncate if tokens are below threshold for non-prompt caching models", () => {
-        const modelInfo = createModelInfo(200000, false)
-        const totalTokens = 100000 // Below threshold
+    it("should truncate if tokens are above max tokens threshold", () => {
+        const modelInfo = createModelInfo(100000, true, 30000)
+        const maxTokens = 100000 - 30000 // 70000
+        const totalTokens = 70001 // Above threshold
+
+        // When truncating, always uses 0.5 fraction
+        // With 4 messages after the first, 0.5 fraction means remove 2 messages
+        const expectedResult = [messages[0], messages[3], messages[4]]
+
         const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-        expect(result).toEqual(messages)
+        expect(result).toEqual(expectedResult)
     })

-    it("should use 80% of context window as threshold if it's greater than (contextWindow - buffer)", () => {
-        const modelInfo = createModelInfo(50000, true) // Small context window
-        const totalTokens = 40001 // Above 80% threshold (40000)
-        const mockResult = [messages[0], messages[3], messages[4]]
-        const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-        expect(result).toEqual(mockResult)
+    it("should work with non-prompt caching models the same as prompt caching models", () => {
+        // The implementation no longer differentiates between prompt caching and non-prompt caching models
+        const modelInfo1 = createModelInfo(100000, true, 30000)
+        const modelInfo2 = createModelInfo(100000, false, 30000)
+
+        // Test below threshold
+        const belowThreshold = 69999
+        expect(truncateConversationIfNeeded(messages, belowThreshold, modelInfo1)).toEqual(
+            truncateConversationIfNeeded(messages, belowThreshold, modelInfo2),
+        )
+
+        // Test above threshold
+        const aboveThreshold = 70001
+        expect(truncateConversationIfNeeded(messages, aboveThreshold, modelInfo1)).toEqual(
+            truncateConversationIfNeeded(messages, aboveThreshold, modelInfo2),
+        )
     })
 })
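The boundary values in these tests follow directly from the revised threshold: with a 100_000-token window and maxTokens of 30_000, truncation kicks in above 70_000 tokens, and the 0.5 fraction drops half of the messages after the first. A quick arithmetic check of that reasoning (plain numbers, no test framework; the variable names are local to this sketch):

```ts
// Arithmetic behind the test expectations above.
const contextWindow = 100_000
const maxTokens = 30_000
const threshold = contextWindow - maxTokens // 70_000

console.log(69_999 < threshold) // true  -> conversation returned unchanged
console.log(70_001 < threshold) // false -> truncateConversation(messages, 0.5) runs

// 5 messages: the first is always kept, and half of the remaining 4 are removed.
const remaining = 5 - 1
const removed = remaining * 0.5 // 2
console.log(1 + remaining - removed) // 3, matching expect(result2.length).toBe(3)
```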

src/core/sliding-window/index.ts

Lines changed: 7 additions & 56 deletions

@@ -28,75 +28,26 @@ export function truncateConversation(
 /**
  * Conditionally truncates the conversation messages if the total token count exceeds the model's limit.
  *
- * Depending on whether the model supports prompt caching, different maximum token thresholds
- * and truncation fractions are used. If the current total tokens exceed the threshold,
- * the conversation is truncated using the appropriate fraction.
- *
  * @param {Anthropic.Messages.MessageParam[]} messages - The conversation messages.
  * @param {number} totalTokens - The total number of tokens in the conversation.
- * @param {ModelInfo} modelInfo - Model metadata including context window size and prompt cache support.
+ * @param {ModelInfo} modelInfo - Model metadata including context window size.
  * @returns {Anthropic.Messages.MessageParam[]} The original or truncated conversation messages.
  */
 export function truncateConversationIfNeeded(
     messages: Anthropic.Messages.MessageParam[],
     totalTokens: number,
     modelInfo: ModelInfo,
 ): Anthropic.Messages.MessageParam[] {
-    if (modelInfo.supportsPromptCache) {
-        return totalTokens < getMaxTokensForPromptCachingModels(modelInfo)
-            ? messages
-            : truncateConversation(messages, getTruncFractionForPromptCachingModels(modelInfo))
-    } else {
-        return totalTokens < getMaxTokensForNonPromptCachingModels(modelInfo)
-            ? messages
-            : truncateConversation(messages, getTruncFractionForNonPromptCachingModels(modelInfo))
-    }
+    return totalTokens < getMaxTokens(modelInfo) ? messages : truncateConversation(messages, 0.5)
 }

 /**
- * Calculates the maximum allowed tokens for models that support prompt caching.
- *
- * The maximum is computed as the greater of (contextWindow - buffer) and 80% of the contextWindow.
+ * Calculates the maximum allowed tokens
  *
  * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed for prompt caching models.
- */
-function getMaxTokensForPromptCachingModels(modelInfo: ModelInfo): number {
-    // The buffer needs to be at least as large as `modelInfo.maxTokens`.
-    const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
-    return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
-}
-
-/**
- * Provides the fraction of messages to remove for models that support prompt caching.
- *
- * @param {ModelInfo} modelInfo - The model information (unused in current implementation).
- * @returns {number} The truncation fraction for prompt caching models (fixed at 0.5).
- */
-function getTruncFractionForPromptCachingModels(modelInfo: ModelInfo): number {
-    return 0.5
-}
-
-/**
- * Calculates the maximum allowed tokens for models that do not support prompt caching.
- *
- * The maximum is computed as the greater of (contextWindow - 40000) and 80% of the contextWindow.
- *
- * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed for non-prompt caching models.
- */
-function getMaxTokensForNonPromptCachingModels(modelInfo: ModelInfo): number {
-    // The buffer needs to be at least as large as `modelInfo.maxTokens`.
-    const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
-    return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
-}
-
-/**
- * Provides the fraction of messages to remove for models that do not support prompt caching.
- *
- * @param {ModelInfo} modelInfo - The model information.
- * @returns {number} The truncation fraction for non-prompt caching models (fixed at 0.1).
+ * @returns {number} The maximum number of tokens allowed
  */
-function getTruncFractionForNonPromptCachingModels(modelInfo: ModelInfo): number {
-    return Math.min(40_000 / modelInfo.contextWindow, 0.2)
+function getMaxTokens(modelInfo: ModelInfo): number {
+    // The buffer needs to be at least as large as `modelInfo.maxTokens`, or 20% of the context window if for some reason it's not set.
+    return modelInfo.contextWindow - Math.max(modelInfo.maxTokens || modelInfo.contextWindow * 0.2)
 }
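Net effect of the sliding-window change: one threshold for all models, with the model's output budget (`maxTokens`) reserved out of the context window and a 20% fallback when the metadata does not specify it. A worked example of the new math as plain expressions (a sketch, not the module itself; the single-argument `Math.max` in the committed code evaluates to its argument, so it is written here without the wrapper):

```ts
// getMaxTokens(modelInfo) reduces to: contextWindow - (maxTokens || contextWindow * 0.2)
const withMaxTokens = { contextWindow: 200_000, maxTokens: 30_000 as number | undefined }
const withoutMaxTokens = { contextWindow: 100_000, maxTokens: undefined as number | undefined }

// maxTokens set: reserve exactly that much room for the response.
console.log(withMaxTokens.contextWindow - (withMaxTokens.maxTokens || withMaxTokens.contextWindow * 0.2)) // 170000

// maxTokens unset: fall back to reserving 20% of the window.
console.log(withoutMaxTokens.contextWindow - (withoutMaxTokens.maxTokens || withoutMaxTokens.contextWindow * 0.2)) // 80000

// Above the threshold, truncateConversationIfNeeded always truncates with a 0.5 fraction;
// below it, the messages pass through untouched.
```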
