
Commit b51d5c3

fix: handle VS Code LM API changes for token counting

- Update internalCountTokens to handle both old and new API versions
- Add fallback mechanism when countTokens fails with message objects
- Extract text content from messages when new API format is detected
- Update calculateTotalInputTokens to try batch counting first
- Add comprehensive test coverage for the fallback behavior

Fixes #6290

1 parent 342ee70 commit b51d5c3
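
In short: token counting now tries the message-object form of the VS Code Language Model API first and, when that call throws, flattens the message to plain text and retries with a string. A minimal sketch of that pattern outside the handler (the helper name is invented; `LanguageModelChat.countTokens` accepting either a string or a chat message matches the published VS Code API):

import * as vscode from "vscode"

// Hypothetical helper (not part of this commit) sketching the fallback pattern.
async function countTokensWithFallback(
	client: vscode.LanguageModelChat,
	message: vscode.LanguageModelChatMessage,
	token: vscode.CancellationToken,
): Promise<number> {
	try {
		// Newer API versions accept the message object directly.
		return await client.countTokens(message, token)
	} catch {
		// Changed API versions may reject message objects: flatten the
		// message content to a plain string and retry, as the diffs below do.
		const content: unknown = (message as { content?: unknown }).content
		let text = ""
		if (Array.isArray(content)) {
			for (const part of content) {
				if (part && typeof part === "object" && "value" in part && typeof (part as { value: unknown }).value === "string") {
					text += (part as { value: string }).value
				}
			}
		} else if (typeof content === "string") {
			text = content
		}
		return text ? await client.countTokens(text, token) : 0
	}
}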

File tree

2 files changed: +146 -2 lines changed

src/api/providers/__tests__/vscode-lm.spec.ts

Lines changed: 101 additions & 0 deletions
@@ -300,4 +300,105 @@ describe("VsCodeLmHandler", () => {
 			await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed")
 		})
 	})
+
+	describe("countTokens", () => {
+		beforeEach(() => {
+			const mockModel = { ...mockLanguageModelChat }
+			;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel])
+			handler["client"] = mockLanguageModelChat
+			// Initialize the cancellation token for token counting
+			handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()
+		})
+
+		it("should count tokens for text content", async () => {
+			mockLanguageModelChat.countTokens.mockResolvedValueOnce(42)
+
+			const content = [{ type: "text" as const, text: "Hello, world!" }]
+			const result = await handler.countTokens(content)
+
+			expect(result).toBe(42)
+			expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Hello, world!", expect.any(Object))
+		})
+
+		it("should handle image content with placeholder", async () => {
+			mockLanguageModelChat.countTokens.mockResolvedValueOnce(10)
+
+			const content = [
+				{ type: "text" as const, text: "Check this out: " },
+				{
+					type: "image" as const,
+					source: { type: "base64" as const, media_type: "image/png" as const, data: "base64data" },
+				},
+			]
+			const result = await handler.countTokens(content)
+
+			expect(result).toBe(10)
+			expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith(
+				"Check this out: [IMAGE]",
+				expect.any(Object),
+			)
+		})
+
+		it("should handle empty content", async () => {
+			const result = await handler.countTokens([])
+			expect(result).toBe(0)
+			expect(mockLanguageModelChat.countTokens).not.toHaveBeenCalled()
+		})
+
+		it("should handle API errors gracefully", async () => {
+			mockLanguageModelChat.countTokens.mockRejectedValueOnce(new Error("API Error"))
+
+			const content = [{ type: "text" as const, text: "Test content" }]
+			const result = await handler.countTokens(content)
+
+			expect(result).toBe(0)
+		})
+
+		it("should handle API errors and fallback gracefully", async () => {
+			// Test that when countTokens fails with a message object, it falls back to string extraction
+			const content = [{ type: "text" as const, text: "Test content" }]

+			// Simulate the scenario where the new API expects different format
+			let callCount = 0
+			mockLanguageModelChat.countTokens.mockImplementation(async (input) => {
+				callCount++
+				// First call with string succeeds
+				if (typeof input === "string") {
+					return 42
+				}
+				// Calls with message objects fail (simulating API change)
+				throw new Error("Invalid message format - expected LanguageModelChatMessage2")
+			})
+
+			const result = await handler.countTokens(content)
+
+			// Should successfully count tokens despite API changes
+			expect(result).toBe(42)
+			expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Test content", expect.any(Object))
+		})
+
+		it("should handle batch token counting with fallback", async () => {
+			// Mock the internal countTokens to simulate batch counting failure then individual success
+			const originalInternalCountTokens = handler["internalCountTokens"].bind(handler)
+			let callCount = 0
+
+			handler["internalCountTokens"] = vi.fn().mockImplementation(async (input) => {
+				callCount++
+				if (callCount === 1 && Array.isArray(input)) {
+					// First call with array fails
+					return 0
+				}
+				// Subsequent calls succeed
+				return callCount === 2 ? 15 : 10 // system: 15, message: 10
+			})
+
+			const systemPrompt = "You are a helpful assistant"
+			const messages = [vscode.LanguageModelChatMessage.User("Hello")]
+
+			const result = await handler["calculateTotalInputTokens"](systemPrompt, messages)
+
+			expect(result).toBe(25) // 15 + 10
+			expect(handler["internalCountTokens"]).toHaveBeenCalledTimes(3) // batch + system + message
+		})
+	})
 })
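
For readers hitting this spec diff in isolation: `handler`, `mockLanguageModelChat`, and the vitest helpers (`vi`, `Mock`) are defined earlier in the file. A minimal stand-in consistent with the assertions above could look like the following (the field list is an assumption; the real mock in the spec may differ):

import { vi } from "vitest"

// Illustrative mock only; the spec's actual mockLanguageModelChat lives near
// the top of vscode-lm.spec.ts and may carry additional fields.
const mockLanguageModelChat = {
	id: "mock-model",
	name: "Mock Model",
	vendor: "mock-vendor",
	family: "mock-family",
	version: "1.0",
	maxInputTokens: 4096,
	countTokens: vi.fn(),
	sendRequest: vi.fn(),
}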

src/api/providers/vscode-lm.ts

Lines changed: 45 additions & 2 deletions
@@ -231,7 +231,35 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 				console.debug("Roo Code <Language Model API>: Empty chat message content")
 				return 0
 			}
-			tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
+
+			// VS Code's updated API expects LanguageModelChatMessage2 format
+			// Try the new API first, fall back to old API if it fails
+			try {
+				// Attempt to use the message directly (new API)
+				tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
+			} catch (apiError) {
+				// If the new API fails, try converting to string format
+				console.debug("Roo Code <Language Model API>: Falling back to string-based token counting")
+
+				// Extract text content from the message
+				let textContent = ""
+				if (Array.isArray(text.content)) {
+					for (const part of text.content) {
+						if (part && typeof part === "object" && "value" in part && typeof part.value === "string") {
+							textContent += part.value
+						}
+					}
+				} else if (typeof text.content === "string") {
+					textContent = text.content
+				}
+
+				if (textContent) {
+					tokenCount = await this.client.countTokens(textContent, this.currentRequestCancellation.token)
+				} else {
+					console.warn("Roo Code <Language Model API>: Could not extract text content from message")
+					return 0
+				}
+			}
 		} else {
 			console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
 			return 0
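
One note on the extraction loop in this hunk: the duck-typed `"value" in part` check avoids importing concrete part classes, which keeps the fallback tolerant of API-version drift. Against the stable API, the same filter could name the class, since `LanguageModelTextPart` exposes its text as `value` (the helper below is invented for illustration):

import * as vscode from "vscode"

// Equivalent extraction using the stable API's text-part class; tool-call and
// tool-result parts carry no text value and are skipped either way.
function extractTextParts(parts: ReadonlyArray<unknown>): string {
	return parts
		.filter((part): part is vscode.LanguageModelTextPart => part instanceof vscode.LanguageModelTextPart)
		.map((part) => part.value)
		.join("")
}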
@@ -272,8 +300,23 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 		systemPrompt: string,
 		vsCodeLmMessages: vscode.LanguageModelChatMessage[],
 	): Promise<number> {
-		const systemTokens: number = await this.internalCountTokens(systemPrompt)
+		try {
+			// Try to count all messages together first (new API approach)
+			const allMessages = [vscode.LanguageModelChatMessage.Assistant(systemPrompt), ...vsCodeLmMessages]
 
+			// Attempt to count tokens for all messages at once
+			const totalTokens = await this.internalCountTokens(allMessages as any)
+			if (totalTokens > 0) {
+				return totalTokens
+			}
+		} catch (error) {
+			console.debug(
+				"Roo Code <Language Model API>: Batch token counting failed, falling back to individual counting",
+			)
+		}
+
+		// Fallback: count tokens individually
+		const systemTokens: number = await this.internalCountTokens(systemPrompt)
 		const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.internalCountTokens(msg)))
 
 		return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
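
Read together with the new spec, this hunk gives `calculateTotalInputTokens` a two-tier strategy: one batched `internalCountTokens` call over the system prompt plus all messages, with per-item counting only when the batch attempt throws or returns 0. A walkthrough using the mocked numbers from the test above (the values are the test's, not real model output):

// Batch attempt: internalCountTokens(allMessages)  -> 0  (treated as failure)
// Fallback:      internalCountTokens(systemPrompt) -> 15
//                internalCountTokens(userMessage)  -> 10
const systemTokens = 15
const messageTokens = [10]
const total = systemTokens + messageTokens.reduce((sum, t) => sum + t, 0)
console.assert(total === 25) // matches expect(result).toBe(25); three calls in total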
