
Commit 7f1c6d7

Merge branch 'cte/move-model-fetchers' into cte/openrouter-claude-thinking
2 parents: 392a237 + 50ce955

7 files changed: +183 additions, −133 deletions

src/api/providers/glama.ts

Lines changed: 13 additions & 2 deletions

@@ -71,7 +71,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
     let maxTokens: number | undefined

     if (this.getModel().id.startsWith("anthropic/")) {
-      maxTokens = 8_192
+      maxTokens = this.getModel().info.maxTokens
     }

     const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
@@ -179,7 +179,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
     }

     if (this.getModel().id.startsWith("anthropic/")) {
-      requestOptions.max_tokens = 8192
+      requestOptions.max_tokens = this.getModel().info.maxTokens
     }

     const response = await this.client.chat.completions.create(requestOptions)
@@ -214,6 +214,17 @@ export async function getGlamaModels() {
         cacheReadsPrice: parseApiPrice(rawModel.pricePerToken?.cacheRead),
       }

+      switch (rawModel.id) {
+        case rawModel.id.startsWith("anthropic/claude-3-7-sonnet"):
+          modelInfo.maxTokens = 16384
+          break
+        case rawModel.id.startsWith("anthropic/"):
+          modelInfo.maxTokens = 8192
+          break
+        default:
+          break
+      }
+
       models[rawModel.id] = modelInfo
     }
   } catch (error) {
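
Note on the change above: the per-model output ceiling now lives in the fetched model metadata, and the handler reads it back through this.getModel().info.maxTokens instead of pinning 8_192. One caveat: getGlamaModels (and getRequestyModels below) switch on rawModel.id while each case expression is a boolean from startsWith, so no case can strictly equal the string id and the default branch always runs; the switch (true) form used in openrouter.ts and unbound.ts is the variant that dispatches on the prefix as written. A minimal sketch of that idiom as a standalone hypothetical helper (applyAnthropicMaxTokens and the trimmed ModelInfo shape are illustrations, not code from this commit):

// Hypothetical helper illustrating the `switch (true)` prefix dispatch used by the
// openrouter.ts and unbound.ts fetchers in this commit. ModelInfo is trimmed for the sketch.
interface ModelInfo {
  contextWindow: number
  supportsPromptCache: boolean
  maxTokens?: number
}

function applyAnthropicMaxTokens(modelId: string, info: ModelInfo): ModelInfo {
  switch (true) {
    case modelId.startsWith("anthropic/claude-3-7-sonnet"):
      info.maxTokens = 16_384 // larger output budget for 3.7 Sonnet
      break
    case modelId.startsWith("anthropic/"):
      info.maxTokens = 8_192 // ceiling for the remaining Anthropic models
      break
    default:
      break // non-Anthropic models keep whatever the provider reported
  }
  return info
}

// applyAnthropicMaxTokens("anthropic/claude-3-7-sonnet", { contextWindow: 200_000, supportsPromptCache: true })
// returns the same object with maxTokens === 16_384.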

src/api/providers/openrouter.ts

Lines changed: 20 additions & 57 deletions

@@ -56,22 +56,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {

     // prompt caching: https://openrouter.ai/docs/prompt-caching
     // this is specifically for claude models (some models may 'support prompt caching' automatically without this)
-    switch (modelId) {
-      case "anthropic/claude-3.7-sonnet:thinking":
-      case "anthropic/claude-3.7-sonnet":
-      case "anthropic/claude-3.7-sonnet:beta":
-      case "anthropic/claude-3.5-sonnet":
-      case "anthropic/claude-3.5-sonnet:beta":
-      case "anthropic/claude-3.5-sonnet-20240620":
-      case "anthropic/claude-3.5-sonnet-20240620:beta":
-      case "anthropic/claude-3-5-haiku":
-      case "anthropic/claude-3-5-haiku:beta":
-      case "anthropic/claude-3-5-haiku-20241022":
-      case "anthropic/claude-3-5-haiku-20241022:beta":
-      case "anthropic/claude-3-haiku":
-      case "anthropic/claude-3-haiku:beta":
-      case "anthropic/claude-3-opus":
-      case "anthropic/claude-3-opus:beta":
+    switch (true) {
+      case this.getModel().id.startsWith("anthropic/"):
        openAiMessages[0] = {
          role: "system",
          content: [
@@ -107,20 +93,6 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
        break
    }

-    // Not sure how openrouter defaults max tokens when no value is
-    // provided, but the Anthropic API requires this value and since they
-    // offer both 4096 and 8192 variants, we should ensure 8192.
-    // (Models usually default to max tokens allowed.)
-    let maxTokens: number | undefined = undefined
-
-    if (modelId.startsWith("anthropic/claude-3.5")) {
-      maxTokens = modelInfo.maxTokens ?? 8_192
-    }
-
-    if (modelId.startsWith("anthropic/claude-3.7")) {
-      maxTokens = modelInfo.maxTokens ?? 16_384
-    }
-
    let defaultTemperature = OPENROUTER_DEFAULT_TEMPERATURE
    let topP: number | undefined = undefined

@@ -136,6 +108,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {

    let temperature = this.options.modelTemperature ?? defaultTemperature

+    // Anthropic "Thinking" models require a temperature of 1.0.
    if (modelInfo.thinking) {
      temperature = 1.0
    }
@@ -145,7 +118,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {

    const completionParams: OpenRouterChatCompletionParams = {
      model: modelId,
-      max_tokens: maxTokens,
+      max_tokens: modelInfo.maxTokens,
      temperature,
      top_p: topP,
      messages: openAiMessages,
@@ -290,56 +263,46 @@ export async function getOpenRouterModels() {
      thinking: rawModel.id === "anthropic/claude-3.7-sonnet:thinking",
    }

-    switch (rawModel.id) {
-      case "anthropic/claude-3.7-sonnet:thinking":
-      case "anthropic/claude-3.7-sonnet":
-      case "anthropic/claude-3.7-sonnet:beta":
-        modelInfo.maxTokens = 16_384
+    // NOTE: this needs to be synced with api.ts/openrouter default model info.
+    switch (true) {
+      case rawModel.id.startsWith("anthropic/claude-3.7-sonnet"):
        modelInfo.supportsComputerUse = true
        modelInfo.supportsPromptCache = true
        modelInfo.cacheWritesPrice = 3.75
        modelInfo.cacheReadsPrice = 0.3
+        modelInfo.maxTokens = 16384
        break
-      case "anthropic/claude-3.5-sonnet":
-      case "anthropic/claude-3.5-sonnet:beta":
-        // NOTE: This needs to be synced with api.ts/openrouter default model info.
-        modelInfo.maxTokens = 8_192
-        modelInfo.supportsComputerUse = true
+      case rawModel.id.startsWith("anthropic/claude-3.5-sonnet-20240620"):
        modelInfo.supportsPromptCache = true
        modelInfo.cacheWritesPrice = 3.75
        modelInfo.cacheReadsPrice = 0.3
+        modelInfo.maxTokens = 8192
        break
-      case "anthropic/claude-3.5-sonnet-20240620":
-      case "anthropic/claude-3.5-sonnet-20240620:beta":
-        modelInfo.maxTokens = 8_192
+      case rawModel.id.startsWith("anthropic/claude-3.5-sonnet"):
+        modelInfo.supportsComputerUse = true
        modelInfo.supportsPromptCache = true
        modelInfo.cacheWritesPrice = 3.75
        modelInfo.cacheReadsPrice = 0.3
+        modelInfo.maxTokens = 8192
        break
-      case "anthropic/claude-3-5-haiku":
-      case "anthropic/claude-3-5-haiku:beta":
-      case "anthropic/claude-3-5-haiku-20241022":
-      case "anthropic/claude-3-5-haiku-20241022:beta":
-      case "anthropic/claude-3.5-haiku":
-      case "anthropic/claude-3.5-haiku:beta":
-      case "anthropic/claude-3.5-haiku-20241022":
-      case "anthropic/claude-3.5-haiku-20241022:beta":
-        modelInfo.maxTokens = 8_192
+      case rawModel.id.startsWith("anthropic/claude-3-5-haiku"):
        modelInfo.supportsPromptCache = true
        modelInfo.cacheWritesPrice = 1.25
        modelInfo.cacheReadsPrice = 0.1
+        modelInfo.maxTokens = 8192
        break
-      case "anthropic/claude-3-opus":
-      case "anthropic/claude-3-opus:beta":
+      case rawModel.id.startsWith("anthropic/claude-3-opus"):
        modelInfo.supportsPromptCache = true
        modelInfo.cacheWritesPrice = 18.75
        modelInfo.cacheReadsPrice = 1.5
+        modelInfo.maxTokens = 8192
        break
-      case "anthropic/claude-3-haiku":
-      case "anthropic/claude-3-haiku:beta":
+      case rawModel.id.startsWith("anthropic/claude-3-haiku"):
+      default:
        modelInfo.supportsPromptCache = true
        modelInfo.cacheWritesPrice = 0.3
        modelInfo.cacheReadsPrice = 0.03
+        modelInfo.maxTokens = 8192
        break
    }
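
The net effect in OpenRouterHandler.createMessage is that the request no longer computes its own max_tokens fallbacks: it forwards modelInfo.maxTokens as populated by getOpenRouterModels, and thinking-variant models get their temperature forced to 1.0. A hedged sketch of that flow, with the parameter object trimmed to the fields visible in this diff (buildCompletionParams and FetchedModelInfo are names invented for the illustration):

interface FetchedModelInfo {
  maxTokens?: number // 16_384 for claude-3.7-sonnet*, 8_192 for other Anthropic ids
  thinking?: boolean // true only for anthropic/claude-3.7-sonnet:thinking
}

function buildCompletionParams(
  modelId: string,
  modelInfo: FetchedModelInfo,
  requestedTemperature: number | undefined,
  defaultTemperature: number,
) {
  let temperature = requestedTemperature ?? defaultTemperature

  // Anthropic "Thinking" models require a temperature of 1.0.
  if (modelInfo.thinking) {
    temperature = 1.0
  }

  return {
    model: modelId,
    max_tokens: modelInfo.maxTokens,
    temperature,
  }
}

// buildCompletionParams("anthropic/claude-3.7-sonnet:thinking", { maxTokens: 16_384, thinking: true }, 0.2, 0)
// => { model: "anthropic/claude-3.7-sonnet:thinking", max_tokens: 16384, temperature: 1 }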

src/api/providers/requesty.ts

Lines changed: 11 additions & 0 deletions

@@ -70,6 +70,17 @@ export async function getRequestyModels({ apiKey }: { apiKey?: string }) {
       cacheReadsPrice: parseApiPrice(rawModel.cached_price),
     }

+     switch (rawModel.id) {
+       case rawModel.id.startsWith("anthropic/claude-3-7-sonnet"):
+         modelInfo.maxTokens = 16384
+         break
+       case rawModel.id.startsWith("anthropic/"):
+         modelInfo.maxTokens = 8192
+         break
+       default:
+         break
+     }
+
     models[rawModel.id] = modelInfo
   }
 } catch (error) {

src/api/providers/unbound.ts

Lines changed: 16 additions & 3 deletions

@@ -73,7 +73,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
     let maxTokens: number | undefined

     if (this.getModel().id.startsWith("anthropic/")) {
-      maxTokens = 8_192
+      maxTokens = this.getModel().info.maxTokens
     }

     const { data: completion, response } = await this.client.chat.completions
@@ -152,7 +152,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
     }

     if (this.getModel().id.startsWith("anthropic/")) {
-      requestOptions.max_tokens = 8192
+      requestOptions.max_tokens = this.getModel().info.maxTokens
     }

     const response = await this.client.chat.completions.create(requestOptions)
@@ -176,7 +176,7 @@ export async function getUnboundModels() {
     const rawModels: Record<string, any> = response.data

     for (const [modelId, model] of Object.entries(rawModels)) {
-      models[modelId] = {
+      const modelInfo: ModelInfo = {
        maxTokens: model?.maxTokens ? parseInt(model.maxTokens) : undefined,
        contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0,
        supportsImages: model?.supportsImages ?? false,
@@ -187,6 +187,19 @@ export async function getUnboundModels() {
        cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined,
        cacheReadsPrice: model?.cacheReadPrice ? parseFloat(model.cacheReadPrice) : undefined,
      }
+
+      switch (true) {
+        case modelId.startsWith("anthropic/claude-3-7-sonnet"):
+          modelInfo.maxTokens = 16384
+          break
+        case modelId.startsWith("anthropic/"):
+          modelInfo.maxTokens = 8192
+          break
+        default:
+          break
+      }
+
+      models[modelId] = modelInfo
     }
   }
 } catch (error) {
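
On the handler side, UnboundHandler (like GlamaHandler above) now reads the ceiling back from the fetched record rather than hard-coding 8_192. A small sketch of that lookup, assuming a getModel()-style record of { id, info } as in the diff; the fallback argument is an illustration and not part of the commit:

interface ModelRecord {
  id: string
  info: { maxTokens?: number }
}

// Prefer the per-model ceiling attached by the fetcher; only Anthropic ids receive one here.
function resolveMaxTokens(model: ModelRecord, fallback = 8_192): number | undefined {
  if (!model.id.startsWith("anthropic/")) {
    return undefined // non-Anthropic models let the provider apply its own default
  }
  return model.info.maxTokens ?? fallback
}

// resolveMaxTokens({ id: "anthropic/claude-3-7-sonnet", info: { maxTokens: 16_384 } }) === 16_384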

src/core/sliding-window/__tests__/sliding-window.test.ts

Lines changed: 115 additions & 14 deletions

@@ -5,6 +5,9 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { ModelInfo } from "../../../shared/api"
 import { truncateConversation, truncateConversationIfNeeded } from "../index"

+/**
+ * Tests for the truncateConversation function
+ */
 describe("truncateConversation", () => {
   it("should retain the first message", () => {
     const messages: Anthropic.Messages.MessageParam[] = [
@@ -91,6 +94,86 @@ describe("truncateConversation", () => {
   })
 })

+/**
+ * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded)
+ */
+describe("getMaxTokens", () => {
+  // We'll test this indirectly through truncateConversationIfNeeded
+  const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
+    contextWindow,
+    supportsPromptCache: true, // Not relevant for getMaxTokens
+    maxTokens,
+  })
+
+  // Reuse across tests for consistency
+  const messages: Anthropic.Messages.MessageParam[] = [
+    { role: "user", content: "First message" },
+    { role: "assistant", content: "Second message" },
+    { role: "user", content: "Third message" },
+    { role: "assistant", content: "Fourth message" },
+    { role: "user", content: "Fifth message" },
+  ]
+
+  it("should use maxTokens as buffer when specified", () => {
+    const modelInfo = createModelInfo(100000, 50000)
+    // Max tokens = 100000 - 50000 = 50000
+
+    // Below max tokens - no truncation
+    const result1 = truncateConversationIfNeeded(messages, 49999, modelInfo)
+    expect(result1).toEqual(messages)
+
+    // Above max tokens - truncate
+    const result2 = truncateConversationIfNeeded(messages, 50001, modelInfo)
+    expect(result2).not.toEqual(messages)
+    expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+  })
+
+  it("should use 20% of context window as buffer when maxTokens is undefined", () => {
+    const modelInfo = createModelInfo(100000, undefined)
+    // Max tokens = 100000 - (100000 * 0.2) = 80000
+
+    // Below max tokens - no truncation
+    const result1 = truncateConversationIfNeeded(messages, 79999, modelInfo)
+    expect(result1).toEqual(messages)
+
+    // Above max tokens - truncate
+    const result2 = truncateConversationIfNeeded(messages, 80001, modelInfo)
+    expect(result2).not.toEqual(messages)
+    expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+  })
+
+  it("should handle small context windows appropriately", () => {
+    const modelInfo = createModelInfo(50000, 10000)
+    // Max tokens = 50000 - 10000 = 40000
+
+    // Below max tokens - no truncation
+    const result1 = truncateConversationIfNeeded(messages, 39999, modelInfo)
+    expect(result1).toEqual(messages)
+
+    // Above max tokens - truncate
+    const result2 = truncateConversationIfNeeded(messages, 40001, modelInfo)
+    expect(result2).not.toEqual(messages)
+    expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+  })
+
+  it("should handle large context windows appropriately", () => {
+    const modelInfo = createModelInfo(200000, 30000)
+    // Max tokens = 200000 - 30000 = 170000
+
+    // Below max tokens - no truncation
+    const result1 = truncateConversationIfNeeded(messages, 169999, modelInfo)
+    expect(result1).toEqual(messages)
+
+    // Above max tokens - truncate
+    const result2 = truncateConversationIfNeeded(messages, 170001, modelInfo)
+    expect(result2).not.toEqual(messages)
+    expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+  })
+})
+
+/**
+ * Tests for the truncateConversationIfNeeded function
+ */
 describe("truncateConversationIfNeeded", () => {
   const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
     contextWindow,
@@ -106,25 +189,43 @@ describe("truncateConversationIfNeeded", () => {
     { role: "user", content: "Fifth message" },
   ]

-  it("should not truncate if tokens are below threshold for prompt caching models", () => {
-    const modelInfo = createModelInfo(200000, true, 50000)
-    const totalTokens = 100000 // Below threshold
+  it("should not truncate if tokens are below max tokens threshold", () => {
+    const modelInfo = createModelInfo(100000, true, 30000)
+    const maxTokens = 100000 - 30000 // 70000
+    const totalTokens = 69999 // Below threshold
+
     const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-    expect(result).toEqual(messages)
+    expect(result).toEqual(messages) // No truncation occurs
   })

-  it("should not truncate if tokens are below threshold for non-prompt caching models", () => {
-    const modelInfo = createModelInfo(200000, false)
-    const totalTokens = 100000 // Below threshold
+  it("should truncate if tokens are above max tokens threshold", () => {
+    const modelInfo = createModelInfo(100000, true, 30000)
+    const maxTokens = 100000 - 30000 // 70000
+    const totalTokens = 70001 // Above threshold
+
+    // When truncating, always uses 0.5 fraction
+    // With 4 messages after the first, 0.5 fraction means remove 2 messages
+    const expectedResult = [messages[0], messages[3], messages[4]]
+
     const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-    expect(result).toEqual(messages)
+    expect(result).toEqual(expectedResult)
   })

-  it("should use 80% of context window as threshold if it's greater than (contextWindow - buffer)", () => {
-    const modelInfo = createModelInfo(50000, true) // Small context window
-    const totalTokens = 40001 // Above 80% threshold (40000)
-    const mockResult = [messages[0], messages[3], messages[4]]
-    const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-    expect(result).toEqual(mockResult)
+  it("should work with non-prompt caching models the same as prompt caching models", () => {
+    // The implementation no longer differentiates between prompt caching and non-prompt caching models
+    const modelInfo1 = createModelInfo(100000, true, 30000)
+    const modelInfo2 = createModelInfo(100000, false, 30000)
+
+    // Test below threshold
+    const belowThreshold = 69999
+    expect(truncateConversationIfNeeded(messages, belowThreshold, modelInfo1)).toEqual(
+      truncateConversationIfNeeded(messages, belowThreshold, modelInfo2),
+    )
+
+    // Test above threshold
+    const aboveThreshold = 70001
+    expect(truncateConversationIfNeeded(messages, aboveThreshold, modelInfo1)).toEqual(
+      truncateConversationIfNeeded(messages, aboveThreshold, modelInfo2),
+    )
   })
 })
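
Taken together, the new tests pin down the truncation contract: the allowed prompt size is contextWindow minus maxTokens when the model reports one, otherwise contextWindow minus 20% of the window, and once the running total exceeds that limit the conversation keeps its first message and drops half of the rest. A sketch consistent with these tests, assuming this is roughly what src/core/sliding-window/index.ts implements (details such as the even-count rounding are inferred from the expected [first, fourth, fifth] result, not taken from the source file):

import { Anthropic } from "@anthropic-ai/sdk"

interface ModelInfo {
  contextWindow: number
  supportsPromptCache: boolean
  maxTokens?: number
}

// Token budget: reserve maxTokens when known, otherwise 20% of the context window.
function getMaxTokens(modelInfo: ModelInfo): number {
  return modelInfo.contextWindow - (modelInfo.maxTokens ?? modelInfo.contextWindow * 0.2)
}

// Keep the first message and remove a fraction of the rest, rounded down to an
// even count so user/assistant pairs stay aligned.
export function truncateConversation(
  messages: Anthropic.Messages.MessageParam[],
  fracToRemove: number,
): Anthropic.Messages.MessageParam[] {
  const rawCount = Math.floor((messages.length - 1) * fracToRemove)
  const removeCount = rawCount - (rawCount % 2)
  return [messages[0], ...messages.slice(removeCount + 1)]
}

export function truncateConversationIfNeeded(
  messages: Anthropic.Messages.MessageParam[],
  totalTokens: number,
  modelInfo: ModelInfo,
): Anthropic.Messages.MessageParam[] {
  return totalTokens <= getMaxTokens(modelInfo) ? messages : truncateConversation(messages, 0.5)
}

// With five messages, contextWindow 100_000 and maxTokens 30_000:
//   truncateConversationIfNeeded(messages, 69_999, info) returns all five messages;
//   truncateConversationIfNeeded(messages, 70_001, info) returns [first, fourth, fifth].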
