
Commit 50ce955
Merge branch 'main' into cte/move-model-fetchers
2 parents: 159621c + 46576e0

7 files changed: +186 -128 lines

src/api/providers/glama.ts

Lines changed: 13 additions & 2 deletions

```diff
@@ -71,7 +71,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
 		let maxTokens: number | undefined
 
 		if (this.getModel().id.startsWith("anthropic/")) {
-			maxTokens = 8_192
+			maxTokens = this.getModel().info.maxTokens
 		}
 
 		const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
@@ -179,7 +179,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
 		}
 
 		if (this.getModel().id.startsWith("anthropic/")) {
-			requestOptions.max_tokens = 8192
+			requestOptions.max_tokens = this.getModel().info.maxTokens
 		}
 
 		const response = await this.client.chat.completions.create(requestOptions)
@@ -214,6 +214,17 @@ export async function getGlamaModels() {
 				cacheReadsPrice: parseApiPrice(rawModel.pricePerToken?.cacheRead),
 			}
 
+			switch (rawModel.id) {
+				case rawModel.id.startsWith("anthropic/claude-3-7-sonnet"):
+					modelInfo.maxTokens = 16384
+					break
+				case rawModel.id.startsWith("anthropic/"):
+					modelInfo.maxTokens = 8192
+					break
+				default:
+					break
+			}
+
			models[rawModel.id] = modelInfo
		}
	} catch (error) {
```
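A note on the new `switch` in `getGlamaModels`: `switch (rawModel.id)` compares the string id against the boolean returned by `startsWith(...)` using strict equality, so neither `case` can ever match and only the empty `default` branch runs, leaving `maxTokens` untouched. Below is a minimal sketch of the switch-on-`true` form that the unbound.ts and openrouter.ts hunks in this same commit use; the helper name is hypothetical and the `modelInfo` parameter is trimmed to the single field this logic touches.

```typescript
// Sketch only: `applyAnthropicMaxTokens` is a hypothetical helper, and the
// `modelInfo` parameter is trimmed to the one field this logic touches.
function applyAnthropicMaxTokens(rawModelId: string, modelInfo: { maxTokens?: number }): void {
	switch (true) {
		// Each case is a boolean expression; the first one that evaluates to true wins.
		case rawModelId.startsWith("anthropic/claude-3-7-sonnet"):
			modelInfo.maxTokens = 16384
			break
		case rawModelId.startsWith("anthropic/"):
			modelInfo.maxTokens = 8192
			break
		default:
			break
	}
}
```

With this form, `applyAnthropicMaxTokens("anthropic/claude-3-7-sonnet", modelInfo)` sets `maxTokens` to 16384, any other `anthropic/` id gets 8192, and non-Anthropic ids are left alone.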

src/api/providers/openrouter.ts

Lines changed: 23 additions & 52 deletions

```diff
@@ -54,20 +54,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 
 		// prompt caching: https://openrouter.ai/docs/prompt-caching
 		// this is specifically for claude models (some models may 'support prompt caching' automatically without this)
-		switch (this.getModel().id) {
-			case "anthropic/claude-3.7-sonnet":
-			case "anthropic/claude-3.5-sonnet":
-			case "anthropic/claude-3.5-sonnet:beta":
-			case "anthropic/claude-3.5-sonnet-20240620":
-			case "anthropic/claude-3.5-sonnet-20240620:beta":
-			case "anthropic/claude-3-5-haiku":
-			case "anthropic/claude-3-5-haiku:beta":
-			case "anthropic/claude-3-5-haiku-20241022":
-			case "anthropic/claude-3-5-haiku-20241022:beta":
-			case "anthropic/claude-3-haiku":
-			case "anthropic/claude-3-haiku:beta":
-			case "anthropic/claude-3-opus":
-			case "anthropic/claude-3-opus:beta":
+		switch (true) {
+			case this.getModel().id.startsWith("anthropic/"):
 				openAiMessages[0] = {
 					role: "system",
 					content: [
@@ -103,23 +91,6 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 				break
 		}
 
-		// Not sure how openrouter defaults max tokens when no value is provided, but the anthropic api requires this value and since they offer both 4096 and 8192 variants, we should ensure 8192.
-		// (models usually default to max tokens allowed)
-		let maxTokens: number | undefined
-		switch (this.getModel().id) {
-			case "anthropic/claude-3.7-sonnet":
-			case "anthropic/claude-3.5-sonnet":
-			case "anthropic/claude-3.5-sonnet:beta":
-			case "anthropic/claude-3.5-sonnet-20240620":
-			case "anthropic/claude-3.5-sonnet-20240620:beta":
-			case "anthropic/claude-3-5-haiku":
-			case "anthropic/claude-3-5-haiku:beta":
-			case "anthropic/claude-3-5-haiku-20241022":
-			case "anthropic/claude-3-5-haiku-20241022:beta":
-				maxTokens = 8_192
-				break
-		}
-
 		let defaultTemperature = OPENROUTER_DEFAULT_TEMPERATURE
 		let topP: number | undefined = undefined
 
@@ -140,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 		let fullResponseText = ""
 		const stream = await this.client.chat.completions.create({
 			model: this.getModel().id,
-			max_tokens: maxTokens,
+			max_tokens: this.getModel().info.maxTokens,
 			temperature: this.options.modelTemperature ?? defaultTemperature,
 			top_p: topP,
 			messages: openAiMessages,
@@ -270,46 +241,46 @@ export async function getOpenRouterModels() {
 			description: rawModel.description,
 		}
 
-		switch (rawModel.id) {
-			case "anthropic/claude-3.7-sonnet":
-			case "anthropic/claude-3.7-sonnet:beta":
-			case "anthropic/claude-3.5-sonnet":
-			case "anthropic/claude-3.5-sonnet:beta":
-				// NOTE: This needs to be synced with api.ts/openrouter default model info.
+		// NOTE: this needs to be synced with api.ts/openrouter default model info.
+		switch (true) {
+			case rawModel.id.startsWith("anthropic/claude-3.7-sonnet"):
 				modelInfo.supportsComputerUse = true
 				modelInfo.supportsPromptCache = true
 				modelInfo.cacheWritesPrice = 3.75
 				modelInfo.cacheReadsPrice = 0.3
+				modelInfo.maxTokens = 16384
+				break
+			case rawModel.id.startsWith("anthropic/claude-3.5-sonnet-20240620"):
+				modelInfo.supportsPromptCache = true
+				modelInfo.cacheWritesPrice = 3.75
+				modelInfo.cacheReadsPrice = 0.3
+				modelInfo.maxTokens = 8192
 				break
-			case "anthropic/claude-3.5-sonnet-20240620":
-			case "anthropic/claude-3.5-sonnet-20240620:beta":
+			case rawModel.id.startsWith("anthropic/claude-3.5-sonnet"):
+				modelInfo.supportsComputerUse = true
 				modelInfo.supportsPromptCache = true
 				modelInfo.cacheWritesPrice = 3.75
 				modelInfo.cacheReadsPrice = 0.3
+				modelInfo.maxTokens = 8192
 				break
-			case "anthropic/claude-3-5-haiku":
-			case "anthropic/claude-3-5-haiku:beta":
-			case "anthropic/claude-3-5-haiku-20241022":
-			case "anthropic/claude-3-5-haiku-20241022:beta":
-			case "anthropic/claude-3.5-haiku":
-			case "anthropic/claude-3.5-haiku:beta":
-			case "anthropic/claude-3.5-haiku-20241022":
-			case "anthropic/claude-3.5-haiku-20241022:beta":
+			case rawModel.id.startsWith("anthropic/claude-3-5-haiku"):
 				modelInfo.supportsPromptCache = true
 				modelInfo.cacheWritesPrice = 1.25
 				modelInfo.cacheReadsPrice = 0.1
+				modelInfo.maxTokens = 8192
 				break
-			case "anthropic/claude-3-opus":
-			case "anthropic/claude-3-opus:beta":
+			case rawModel.id.startsWith("anthropic/claude-3-opus"):
 				modelInfo.supportsPromptCache = true
 				modelInfo.cacheWritesPrice = 18.75
 				modelInfo.cacheReadsPrice = 1.5
+				modelInfo.maxTokens = 8192
 				break
-			case "anthropic/claude-3-haiku":
-			case "anthropic/claude-3-haiku:beta":
+			case rawModel.id.startsWith("anthropic/claude-3-haiku"):
+			default:
 				modelInfo.supportsPromptCache = true
 				modelInfo.cacheWritesPrice = 0.3
 				modelInfo.cacheReadsPrice = 0.03
+				modelInfo.maxTokens = 8192
 				break
 		}
 
```
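The rewritten `getOpenRouterModels` switch depends on case ordering: in a `switch (true)` the first case whose expression is truthy wins, so the more specific `"anthropic/claude-3.5-sonnet-20240620"` prefix must appear before the general `"anthropic/claude-3.5-sonnet"` prefix or it would never be reached. A small illustration with a hypothetical id:

```typescript
// Hypothetical id chosen for illustration; the first truthy case wins in switch (true).
const id = "anthropic/claude-3.5-sonnet-20240620:beta"

switch (true) {
	case id.startsWith("anthropic/claude-3.5-sonnet-20240620"):
		console.log("dated 3.5 Sonnet variant") // this branch runs
		break
	case id.startsWith("anthropic/claude-3.5-sonnet"):
		console.log("generic 3.5 Sonnet") // also true, but shadowed by the case above
		break
}
```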

src/api/providers/requesty.ts

Lines changed: 11 additions & 0 deletions

```diff
@@ -70,6 +70,17 @@ export async function getRequestyModels({ apiKey }: { apiKey?: string }) {
 			cacheReadsPrice: parseApiPrice(rawModel.cached_price),
 		}
 
+		switch (rawModel.id) {
+			case rawModel.id.startsWith("anthropic/claude-3-7-sonnet"):
+				modelInfo.maxTokens = 16384
+				break
+			case rawModel.id.startsWith("anthropic/"):
+				modelInfo.maxTokens = 8192
+				break
+			default:
+				break
+		}
+
 		models[rawModel.id] = modelInfo
 	}
 } catch (error) {
```
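The same `switch (rawModel.id)` pattern as in the glama.ts hunk appears here, with the same caveat: the string id is compared against the boolean result of `startsWith(...)`, so neither `case` matches and the `maxTokens` override never applies. The `switch (true)` form sketched after the glama.ts diff above would be the working equivalent.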

src/api/providers/unbound.ts

Lines changed: 16 additions & 3 deletions

```diff
@@ -73,7 +73,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 		let maxTokens: number | undefined
 
 		if (this.getModel().id.startsWith("anthropic/")) {
-			maxTokens = 8_192
+			maxTokens = this.getModel().info.maxTokens
 		}
 
 		const { data: completion, response } = await this.client.chat.completions
@@ -152,7 +152,7 @@
 		}
 
 		if (this.getModel().id.startsWith("anthropic/")) {
-			requestOptions.max_tokens = 8192
+			requestOptions.max_tokens = this.getModel().info.maxTokens
 		}
 
 		const response = await this.client.chat.completions.create(requestOptions)
@@ -176,7 +176,7 @@ export async function getUnboundModels() {
 		const rawModels: Record<string, any> = response.data
 
 		for (const [modelId, model] of Object.entries(rawModels)) {
-			models[modelId] = {
+			const modelInfo: ModelInfo = {
 				maxTokens: model?.maxTokens ? parseInt(model.maxTokens) : undefined,
 				contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0,
 				supportsImages: model?.supportsImages ?? false,
@@ -187,6 +187,19 @@
 				cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined,
 				cacheReadsPrice: model?.cacheReadPrice ? parseFloat(model.cacheReadPrice) : undefined,
 			}
+
+			switch (true) {
+				case modelId.startsWith("anthropic/claude-3-7-sonnet"):
+					modelInfo.maxTokens = 16384
+					break
+				case modelId.startsWith("anthropic/"):
+					modelInfo.maxTokens = 8192
+					break
+				default:
+					break
+			}
+
+			models[modelId] = modelInfo
 		}
 	}
 } catch (error) {
```
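Unlike the glama.ts and requesty.ts hunks, this one uses the working `switch (true)` form, and it also stops assigning an object literal directly into `models[modelId]` in favor of building a named `modelInfo` first, so the prefix-based `maxTokens` overrides can be applied before the entry is stored.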

src/core/sliding-window/__tests__/sliding-window.test.ts

Lines changed: 115 additions & 14 deletions

```diff
@@ -5,6 +5,9 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { ModelInfo } from "../../../shared/api"
 import { truncateConversation, truncateConversationIfNeeded } from "../index"
 
+/**
+ * Tests for the truncateConversation function
+ */
 describe("truncateConversation", () => {
 	it("should retain the first message", () => {
 		const messages: Anthropic.Messages.MessageParam[] = [
@@ -91,6 +94,86 @@ describe("truncateConversation", () => {
 	})
 })
 
+/**
+ * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded)
+ */
+describe("getMaxTokens", () => {
+	// We'll test this indirectly through truncateConversationIfNeeded
+	const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
+		contextWindow,
+		supportsPromptCache: true, // Not relevant for getMaxTokens
+		maxTokens,
+	})
+
+	// Reuse across tests for consistency
+	const messages: Anthropic.Messages.MessageParam[] = [
+		{ role: "user", content: "First message" },
+		{ role: "assistant", content: "Second message" },
+		{ role: "user", content: "Third message" },
+		{ role: "assistant", content: "Fourth message" },
+		{ role: "user", content: "Fifth message" },
+	]
+
+	it("should use maxTokens as buffer when specified", () => {
+		const modelInfo = createModelInfo(100000, 50000)
+		// Max tokens = 100000 - 50000 = 50000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 49999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 50001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should use 20% of context window as buffer when maxTokens is undefined", () => {
+		const modelInfo = createModelInfo(100000, undefined)
+		// Max tokens = 100000 - (100000 * 0.2) = 80000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 79999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 80001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle small context windows appropriately", () => {
+		const modelInfo = createModelInfo(50000, 10000)
+		// Max tokens = 50000 - 10000 = 40000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 39999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 40001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle large context windows appropriately", () => {
+		const modelInfo = createModelInfo(200000, 30000)
+		// Max tokens = 200000 - 30000 = 170000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 169999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 170001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+})
+
+/**
+ * Tests for the truncateConversationIfNeeded function
+ */
 describe("truncateConversationIfNeeded", () => {
 	const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
 		contextWindow,
@@ -106,25 +189,43 @@ describe("truncateConversationIfNeeded", () => {
 		{ role: "user", content: "Fifth message" },
 	]
 
-	it("should not truncate if tokens are below threshold for prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, true, 50000)
-		const totalTokens = 100000 // Below threshold
+	it("should not truncate if tokens are below max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 69999 // Below threshold
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(messages) // No truncation occurs
 	})
 
-	it("should not truncate if tokens are below threshold for non-prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, false)
-		const totalTokens = 100000 // Below threshold
+	it("should truncate if tokens are above max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 70001 // Above threshold
+
+		// When truncating, always uses 0.5 fraction
+		// With 4 messages after the first, 0.5 fraction means remove 2 messages
+		const expectedResult = [messages[0], messages[3], messages[4]]
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(expectedResult)
 	})
 
-	it("should use 80% of context window as threshold if it's greater than (contextWindow - buffer)", () => {
-		const modelInfo = createModelInfo(50000, true) // Small context window
-		const totalTokens = 40001 // Above 80% threshold (40000)
-		const mockResult = [messages[0], messages[3], messages[4]]
-		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(mockResult)
+	it("should work with non-prompt caching models the same as prompt caching models", () => {
+		// The implementation no longer differentiates between prompt caching and non-prompt caching models
+		const modelInfo1 = createModelInfo(100000, true, 30000)
+		const modelInfo2 = createModelInfo(100000, false, 30000)
+
+		// Test below threshold
+		const belowThreshold = 69999
+		expect(truncateConversationIfNeeded(messages, belowThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, belowThreshold, modelInfo2),
+		)
+
+		// Test above threshold
+		const aboveThreshold = 70001
+		expect(truncateConversationIfNeeded(messages, aboveThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, aboveThreshold, modelInfo2),
+		)
 	})
 })
```
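Read together, these tests pin down the behavior the commit expects from the sliding-window module: the token budget is the context window minus either the model's `maxTokens` or, when that is undefined, 20% of the context window, and any overflow triggers a truncation that keeps the first message and drops half of the rest. The sketch below is inferred only from these assertions, since the implementation file itself is not part of this diff; the trimmed types, the constant name, and the rounding in `truncateConversation` are assumptions.

```typescript
// Inferred from the test expectations above, not copied from
// src/core/sliding-window/index.ts (which this diff does not show).
// Types are trimmed to what the sketch needs.
type MessageParam = { role: "user" | "assistant"; content: string }
type ModelInfo = { contextWindow: number; supportsPromptCache: boolean; maxTokens?: number }

const TOKEN_BUFFER_FRACTION = 0.2 // assumed constant for the "20% of context window" case

function getMaxTokens(modelInfo: ModelInfo): number {
	// Reserve either the model's own maxTokens or 20% of the context window.
	return modelInfo.contextWindow - (modelInfo.maxTokens ?? modelInfo.contextWindow * TOKEN_BUFFER_FRACTION)
}

function truncateConversation(messages: MessageParam[], fracToRemove: number): MessageParam[] {
	// Keep the first message and drop an even number of the messages after it,
	// so user/assistant alternation is preserved (the rounding here is an assumption).
	const rawToRemove = Math.floor((messages.length - 1) * fracToRemove)
	const toRemove = rawToRemove - (rawToRemove % 2)
	return [messages[0], ...messages.slice(1 + toRemove)]
}

function truncateConversationIfNeeded(
	messages: MessageParam[],
	totalTokens: number,
	modelInfo: ModelInfo,
): MessageParam[] {
	// Over budget: always truncate with the 0.5 fraction the tests rely on.
	return totalTokens > getMaxTokens(modelInfo) ? truncateConversation(messages, 0.5) : messages
}
```

With the five-message fixture above, a 0.5 fraction removes two messages and leaves `[messages[0], messages[3], messages[4]]`, which matches the `expectedResult` and the `length === 3` assertions in the tests.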
