diff --git a/src/api/providers/__tests__/vscode-lm.spec.ts b/src/api/providers/__tests__/vscode-lm.spec.ts
index afb349e5e0..9d83395b16 100644
--- a/src/api/providers/__tests__/vscode-lm.spec.ts
+++ b/src/api/providers/__tests__/vscode-lm.spec.ts
@@ -59,6 +59,19 @@ import { VsCodeLmHandler } from "../vscode-lm"
 import type { ApiHandlerOptions } from "../../../shared/api"
 import type { Anthropic } from "@anthropic-ai/sdk"

+// Mock the base provider's countTokens method
+vi.mock("../base-provider", async () => {
+	const actual = await vi.importActual("../base-provider")
+	return {
+		...actual,
+		BaseProvider: class MockBaseProvider {
+			async countTokens() {
+				return 100 // Mock tiktoken to return 100 tokens
+			}
+		},
+	}
+})
+
 const mockLanguageModelChat = {
 	id: "test-model",
 	name: "Test Model",
@@ -300,4 +313,149 @@ describe("VsCodeLmHandler", () => {
 			await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed")
 		})
 	})
+
+	describe("countTokens with tiktoken fallback", () => {
+		it("should fall back to tiktoken when VSCode API returns 0 for non-empty content", async () => {
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: "Hello world",
+				},
+			]
+
+			// Mock VSCode API to return 0
+			mockLanguageModelChat.countTokens.mockResolvedValue(0)
+			handler["client"] = mockLanguageModelChat
+			handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()
+
+			const result = await handler.countTokens(content)
+
+			// Should use tiktoken fallback which returns 100
+			expect(result).toBe(100)
+		})
+
+		it("should fall back to tiktoken when VSCode API throws an error", async () => {
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: "Hello world",
+				},
+			]
+
+			// Mock VSCode API to throw an error
+			mockLanguageModelChat.countTokens.mockRejectedValue(new Error("API Error"))
+			handler["client"] = mockLanguageModelChat
+			handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()
+
+			const result = await handler.countTokens(content)
+
+			// Should use tiktoken fallback which returns 100
+			expect(result).toBe(100)
+		})
+
+		it("should use VSCode API when it returns valid token count", async () => {
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: "Hello world",
+				},
+			]
+
+			// Mock VSCode API to return valid count
+			mockLanguageModelChat.countTokens.mockResolvedValue(50)
+			handler["client"] = mockLanguageModelChat
+			handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()
+
+			const result = await handler.countTokens(content)
+
+			// Should use VSCode API result
+			expect(result).toBe(50)
+		})
+
+		it("should fall back to tiktoken when no client is available", async () => {
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: "Hello world",
+				},
+			]
+
+			// No client available
+			handler["client"] = null
+
+			const result = await handler.countTokens(content)
+
+			// Should use tiktoken fallback which returns 100
+			expect(result).toBe(100)
+		})
+
+		it("should fall back to tiktoken when VSCode API returns negative value", async () => {
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: "Hello world",
+				},
+			]
+
+			// Mock VSCode API to return negative value
+			mockLanguageModelChat.countTokens.mockResolvedValue(-1)
+			handler["client"] = mockLanguageModelChat
+			handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()
+
+			const result = await handler.countTokens(content)
+
+			// Should use tiktoken fallback which returns 100
+			expect(result).toBe(100)
+		})
+	})
+
+	describe("createMessage with frequent token updates", () => {
+		beforeEach(() => {
+			const mockModel = { ...mockLanguageModelChat }
+			;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel])
+			mockLanguageModelChat.countTokens.mockResolvedValue(10)
+
+			// Override the default client with our test client
+			handler["client"] = mockLanguageModelChat
+		})
+
+		it("should provide token updates during streaming", async () => {
+			const systemPrompt = "You are a helpful assistant"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user" as const,
+					content: "Hello",
+				},
+			]
+
+			// Create a long response to trigger intermediate token updates
+			const longResponse = "a".repeat(150) // 150 characters to trigger at least one update
+			mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
+				stream: (async function* () {
+					// Send response in chunks
+					yield new vscode.LanguageModelTextPart(longResponse.slice(0, 50))
+					yield new vscode.LanguageModelTextPart(longResponse.slice(50, 100))
+					yield new vscode.LanguageModelTextPart(longResponse.slice(100))
+					return
+				})(),
+				text: (async function* () {
+					yield longResponse
+					return
+				})(),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Should have text chunks and multiple usage updates
+			const textChunks = chunks.filter((c) => c.type === "text")
+			const usageChunks = chunks.filter((c) => c.type === "usage")
+
+			expect(textChunks).toHaveLength(3) // 3 text chunks
+			expect(usageChunks.length).toBeGreaterThan(1) // At least 2 usage updates (intermediate + final)
+		})
+	})
 })
diff --git a/src/api/providers/vscode-lm.ts b/src/api/providers/vscode-lm.ts
index 6474371bee..c65ced2b24 100644
--- a/src/api/providers/vscode-lm.ts
+++ b/src/api/providers/vscode-lm.ts
@@ -183,19 +183,32 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 	 * @returns A promise resolving to the token count
 	 */
 	override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
-		// Convert Anthropic content blocks to a string for VSCode LM token counting
-		let textContent = ""
-
-		for (const block of content) {
-			if (block.type === "text") {
-				textContent += block.text || ""
-			} else if (block.type === "image") {
-				// VSCode LM doesn't support images directly, so we'll just use a placeholder
-				textContent += "[IMAGE]"
+		try {
+			// Convert Anthropic content blocks to a string for VSCode LM token counting
+			let textContent = ""
+
+			for (const block of content) {
+				if (block.type === "text") {
+					textContent += block.text || ""
+				} else if (block.type === "image") {
+					// VSCode LM doesn't support images directly, so we'll just use a placeholder
+					textContent += "[IMAGE]"
+				}
 			}
-		}

-		return this.internalCountTokens(textContent)
+			const tokenCount = await this.internalCountTokens(textContent)
+
+			// If VSCode API returns 0 or fails, fall back to tiktoken
+			if (tokenCount === 0 && textContent.length > 0) {
+				console.debug("Roo Code : Falling back to tiktoken for token counting")
+				return super.countTokens(content)
+			}
+
+			return tokenCount
+		} catch (error) {
+			console.warn("Roo Code : Error in countTokens, falling back to tiktoken:", error)
+			return super.countTokens(content)
+		}
 	}

 	/**
@@ -204,12 +217,24 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 	private async internalCountTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> {
 		// Check for required dependencies
 		if (!this.client) {
-			console.warn("Roo Code : No client available for token counting")
+			console.warn(
+				"Roo Code : No client available for token counting, using tiktoken fallback",
+			)
+			// Fall back to tiktoken for string inputs
+			if (typeof text === "string") {
+				return this.fallbackToTiktoken(text)
+			}
 			return 0
 		}

 		if (!this.currentRequestCancellation) {
-			console.warn("Roo Code : No cancellation token available for token counting")
+			console.warn(
+				"Roo Code : No cancellation token available for token counting, using tiktoken fallback",
+			)
+			// Fall back to tiktoken for string inputs
+			if (typeof text === "string") {
+				return this.fallbackToTiktoken(text)
+			}
 			return 0
 		}

@@ -240,14 +265,30 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 			// Validate the result
 			if (typeof tokenCount !== "number") {
 				console.warn("Roo Code : Non-numeric token count received:", tokenCount)
+				// Fall back to tiktoken for string inputs
+				if (typeof text === "string") {
+					return this.fallbackToTiktoken(text)
+				}
 				return 0
 			}

 			if (tokenCount < 0) {
 				console.warn("Roo Code : Negative token count received:", tokenCount)
+				// Fall back to tiktoken for string inputs
+				if (typeof text === "string") {
+					return this.fallbackToTiktoken(text)
+				}
 				return 0
 			}

+			// If we get 0 tokens but have content, fall back to tiktoken
+			if (tokenCount === 0 && typeof text === "string" && text.length > 0) {
+				console.debug(
+					"Roo Code : VSCode API returned 0 tokens for non-empty text, using tiktoken fallback",
+				)
+				return this.fallbackToTiktoken(text)
+			}
+
 			return tokenCount
 		} catch (error) {
 			// Handle specific error types
@@ -257,17 +298,42 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan

 			const errorMessage = error instanceof Error ? error.message : "Unknown error"
-			console.warn("Roo Code : Token counting failed:", errorMessage)
+			console.warn("Roo Code : Token counting failed, using tiktoken fallback:", errorMessage)

 			// Log additional error details if available
 			if (error instanceof Error && error.stack) {
 				console.debug("Token counting error stack:", error.stack)
 			}

+			// Fall back to tiktoken for string inputs
+			if (typeof text === "string") {
+				return this.fallbackToTiktoken(text)
+			}
+
 			return 0 // Fallback to prevent stream interruption
 		}
 	}

+	/**
+	 * Fallback to tiktoken for token counting when VSCode API is unavailable or returns invalid results
+	 */
+	private async fallbackToTiktoken(text: string): Promise<number> {
+		try {
+			// Convert text to Anthropic content blocks format for base provider
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: text,
+				},
+			]
+			return super.countTokens(content)
+		} catch (error) {
+			console.error("Roo Code : Tiktoken fallback failed:", error)
+			// Last resort: estimate based on character count (rough approximation)
+			return Math.ceil(text.length / 4)
+		}
+	}
+
 	private async calculateTotalInputTokens(
 		systemPrompt: string,
 		vsCodeLmMessages: vscode.LanguageModelChatMessage[],
@@ -363,6 +429,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan

 		// Accumulate the text and count at the end of the stream to reduce token counting overhead.
 		let accumulatedText: string = ""
+		let lastTokenUpdateLength: number = 0
+		const TOKEN_UPDATE_INTERVAL = 100 // Update tokens every 100 characters for more responsive UI

 		try {
 			// Create the response stream with minimal required options
@@ -393,6 +461,17 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 						type: "text",
 						text: chunk.value,
 					}
+
+					// Provide more frequent token updates during streaming
+					if (accumulatedText.length - lastTokenUpdateLength >= TOKEN_UPDATE_INTERVAL) {
+						const currentOutputTokens = await this.internalCountTokens(accumulatedText)
+						yield {
+							type: "usage",
+							inputTokens: totalInputTokens,
+							outputTokens: currentOutputTokens,
+						}
+						lastTokenUpdateLength = accumulatedText.length
+					}
 				} else if (chunk instanceof vscode.LanguageModelToolCallPart) {
 					try {
 						// Validate tool call parameters