158 changes: 158 additions & 0 deletions src/api/providers/__tests__/vscode-lm.spec.ts
@@ -59,6 +59,19 @@ import { VsCodeLmHandler } from "../vscode-lm"
import type { ApiHandlerOptions } from "../../../shared/api"
import type { Anthropic } from "@anthropic-ai/sdk"

// Mock the base provider's countTokens method
vi.mock("../base-provider", async () => {
const actual = await vi.importActual("../base-provider")
return {
...actual,
BaseProvider: class MockBaseProvider {
async countTokens() {
return 100 // Mock tiktoken to return 100 tokens
}
},
}
})

const mockLanguageModelChat = {
id: "test-model",
name: "Test Model",
@@ -300,4 +313,149 @@ describe("VsCodeLmHandler", () => {
await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed")
})
})

describe("countTokens with tiktoken fallback", () => {
it("should fall back to tiktoken when VSCode API returns 0 for non-empty content", async () => {
const content: Anthropic.Messages.ContentBlockParam[] = [
{
type: "text",
text: "Hello world",
},
]

// Mock VSCode API to return 0
mockLanguageModelChat.countTokens.mockResolvedValue(0)
handler["client"] = mockLanguageModelChat
handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()

const result = await handler.countTokens(content)

// Should use tiktoken fallback which returns 100
expect(result).toBe(100)
})

it("should fall back to tiktoken when VSCode API throws an error", async () => {
const content: Anthropic.Messages.ContentBlockParam[] = [
{
type: "text",
text: "Hello world",
},
]

// Mock VSCode API to throw an error
mockLanguageModelChat.countTokens.mockRejectedValue(new Error("API Error"))
handler["client"] = mockLanguageModelChat
handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()

const result = await handler.countTokens(content)

// Should use tiktoken fallback which returns 100
expect(result).toBe(100)
})

it("should use VSCode API when it returns valid token count", async () => {
const content: Anthropic.Messages.ContentBlockParam[] = [
{
type: "text",
text: "Hello world",
},
]

// Mock VSCode API to return valid count
mockLanguageModelChat.countTokens.mockResolvedValue(50)
handler["client"] = mockLanguageModelChat
handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()

const result = await handler.countTokens(content)

// Should use VSCode API result
expect(result).toBe(50)
})

it("should fall back to tiktoken when no client is available", async () => {
const content: Anthropic.Messages.ContentBlockParam[] = [
{
type: "text",
text: "Hello world",
},
]

// No client available
handler["client"] = null

const result = await handler.countTokens(content)

// Should use tiktoken fallback which returns 100
expect(result).toBe(100)
})

it("should fall back to tiktoken when VSCode API returns negative value", async () => {
const content: Anthropic.Messages.ContentBlockParam[] = [
{
type: "text",
text: "Hello world",
},
]

// Mock VSCode API to return negative value
mockLanguageModelChat.countTokens.mockResolvedValue(-1)
handler["client"] = mockLanguageModelChat
handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()

const result = await handler.countTokens(content)

// Should use tiktoken fallback which returns 100
expect(result).toBe(100)
})
})

describe("createMessage with frequent token updates", () => {
beforeEach(() => {
const mockModel = { ...mockLanguageModelChat }
;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel])
mockLanguageModelChat.countTokens.mockResolvedValue(10)

// Override the default client with our test client
handler["client"] = mockLanguageModelChat
})

it("should provide token updates during streaming", async () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user" as const,
content: "Hello",
},
]

// Create a long response to trigger intermediate token updates
const longResponse = "a".repeat(150) // 150 characters to trigger at least one update
mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
stream: (async function* () {
// Send response in chunks
yield new vscode.LanguageModelTextPart(longResponse.slice(0, 50))
yield new vscode.LanguageModelTextPart(longResponse.slice(50, 100))
yield new vscode.LanguageModelTextPart(longResponse.slice(100))
return
})(),
text: (async function* () {
yield longResponse
return
})(),
})

const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// Should have text chunks and multiple usage updates
const textChunks = chunks.filter((c) => c.type === "text")
const usageChunks = chunks.filter((c) => c.type === "usage")

expect(textChunks).toHaveLength(3) // 3 text chunks
expect(usageChunks.length).toBeGreaterThan(1) // At least 2 usage updates (intermediate + final)
})
})
})
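
For reviewers, the rule these tests pin down is small: trust the VSCode LM token count only when it is a non-negative number (and non-zero for non-empty text); otherwise defer to the tiktoken counter inherited from the base provider. A minimal sketch of that decision, assuming hypothetical `vscodeCount` and `tiktokenCount` callbacks standing in for the VSCode client call and the base provider's counter (these names are illustrative, not handler methods):

// Illustrative sketch only; mirrors the fallback rule asserted by the tests above.
// `vscodeCount` and `tiktokenCount` are hypothetical stand-ins, not actual handler methods.
async function countWithFallback(
	text: string,
	vscodeCount: (t: string) => Promise<number>,
	tiktokenCount: (t: string) => Promise<number>,
): Promise<number> {
	try {
		const n = await vscodeCount(text)
		// Non-numeric, negative, or zero-for-non-empty results are treated as unusable.
		if (typeof n !== "number" || n < 0 || (n === 0 && text.length > 0)) {
			return tiktokenCount(text)
		}
		return n
	} catch {
		// Any API error also routes to the tiktoken fallback.
		return tiktokenCount(text)
	}
}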
107 changes: 93 additions & 14 deletions src/api/providers/vscode-lm.ts
@@ -183,19 +183,32 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler
* @returns A promise resolving to the token count
*/
override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
// Convert Anthropic content blocks to a string for VSCode LM token counting
let textContent = ""

for (const block of content) {
if (block.type === "text") {
textContent += block.text || ""
} else if (block.type === "image") {
// VSCode LM doesn't support images directly, so we'll just use a placeholder
textContent += "[IMAGE]"
try {
// Convert Anthropic content blocks to a string for VSCode LM token counting
let textContent = ""

for (const block of content) {
if (block.type === "text") {
textContent += block.text || ""
} else if (block.type === "image") {
// VSCode LM doesn't support images directly, so we'll just use a placeholder
textContent += "[IMAGE]"
}
}
}

return this.internalCountTokens(textContent)
const tokenCount = await this.internalCountTokens(textContent)

// If VSCode API returns 0 or fails, fall back to tiktoken
if (tokenCount === 0 && textContent.length > 0) {
console.debug("Roo Code <Language Model API>: Falling back to tiktoken for token counting")
return super.countTokens(content)
}

return tokenCount
} catch (error) {
console.warn("Roo Code <Language Model API>: Error in countTokens, falling back to tiktoken:", error)
return super.countTokens(content)
}
}

/**
@@ -204,12 +217,24 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler
private async internalCountTokens(text: string | vscode.LanguageModelChatMessage): Promise<number> {
// Check for required dependencies
if (!this.client) {
console.warn("Roo Code <Language Model API>: No client available for token counting")
console.warn(
"Roo Code <Language Model API>: No client available for token counting, using tiktoken fallback",
)
// Fall back to tiktoken for string inputs
if (typeof text === "string") {
return this.fallbackToTiktoken(text)
}
return 0
}

if (!this.currentRequestCancellation) {
console.warn("Roo Code <Language Model API>: No cancellation token available for token counting")
console.warn(
"Roo Code <Language Model API>: No cancellation token available for token counting, using tiktoken fallback",
)
// Fall back to tiktoken for string inputs
if (typeof text === "string") {
return this.fallbackToTiktoken(text)
}
return 0
}

@@ -240,14 +265,30 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler
// Validate the result
if (typeof tokenCount !== "number") {
console.warn("Roo Code <Language Model API>: Non-numeric token count received:", tokenCount)
// Fall back to tiktoken for string inputs
if (typeof text === "string") {
return this.fallbackToTiktoken(text)
}
return 0
}

if (tokenCount < 0) {
console.warn("Roo Code <Language Model API>: Negative token count received:", tokenCount)
// Fall back to tiktoken for string inputs
if (typeof text === "string") {
return this.fallbackToTiktoken(text)
}
return 0
}

// If we get 0 tokens but have content, fall back to tiktoken
if (tokenCount === 0 && typeof text === "string" && text.length > 0) {
console.debug(
"Roo Code <Language Model API>: VSCode API returned 0 tokens for non-empty text, using tiktoken fallback",
)
return this.fallbackToTiktoken(text)
}

return tokenCount
} catch (error) {
// Handle specific error types
@@ -257,17 +298,42 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler
}

const errorMessage = error instanceof Error ? error.message : "Unknown error"
console.warn("Roo Code <Language Model API>: Token counting failed:", errorMessage)
console.warn("Roo Code <Language Model API>: Token counting failed, using tiktoken fallback:", errorMessage)

// Log additional error details if available
if (error instanceof Error && error.stack) {
console.debug("Token counting error stack:", error.stack)
}

// Fall back to tiktoken for string inputs
if (typeof text === "string") {
return this.fallbackToTiktoken(text)
}

return 0 // Fallback to prevent stream interruption
}
}

/**
* Fallback to tiktoken for token counting when VSCode API is unavailable or returns invalid results
*/
private async fallbackToTiktoken(text: string): Promise<number> {
try {
// Convert text to Anthropic content blocks format for base provider
const content: Anthropic.Messages.ContentBlockParam[] = [
{
type: "text",
text: text,
},
]
return super.countTokens(content)
} catch (error) {
console.error("Roo Code <Language Model API>: Tiktoken fallback failed:", error)
// Last resort: estimate based on character count (rough approximation)
return Math.ceil(text.length / 4)
}
}

private async calculateTotalInputTokens(
systemPrompt: string,
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
@@ -363,6 +429,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler

// Accumulate the text and count at the end of the stream to reduce token counting overhead.
let accumulatedText: string = ""
let lastTokenUpdateLength: number = 0
const TOKEN_UPDATE_INTERVAL = 100 // Update tokens every 100 characters for more responsive UI

try {
// Create the response stream with minimal required options
@@ -393,6 +461,17 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler
type: "text",
text: chunk.value,
}

// Provide more frequent token updates during streaming
if (accumulatedText.length - lastTokenUpdateLength >= TOKEN_UPDATE_INTERVAL) {
const currentOutputTokens = await this.internalCountTokens(accumulatedText)
yield {
type: "usage",
inputTokens: totalInputTokens,
outputTokens: currentOutputTokens,
}
lastTokenUpdateLength = accumulatedText.length
}
} else if (chunk instanceof vscode.LanguageModelToolCallPart) {
try {
// Validate tool call parameters
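
As a closing note on the streaming change: the intermediate usage reports are throttled by accumulated character count rather than emitted once per chunk. A rough standalone sketch of that pattern, assuming an async iterable of text chunks and a `countTokens` callback (the function name and parameters here are illustrative, not the handler's actual API):

// Illustrative sketch only, not the handler's implementation.
async function* streamWithUsageUpdates(
	chunks: AsyncIterable<string>,
	countTokens: (text: string) => Promise<number>,
	inputTokens: number,
) {
	const TOKEN_UPDATE_INTERVAL = 100 // characters between intermediate usage updates
	let accumulated = ""
	let lastUpdateLength = 0

	for await (const text of chunks) {
		accumulated += text
		yield { type: "text" as const, text }

		// Emit a usage update roughly every 100 characters of streamed output.
		if (accumulated.length - lastUpdateLength >= TOKEN_UPDATE_INTERVAL) {
			yield {
				type: "usage" as const,
				inputTokens,
				outputTokens: await countTokens(accumulated),
			}
			lastUpdateLength = accumulated.length
		}
	}

	// Final usage update once the stream completes.
	yield {
		type: "usage" as const,
		inputTokens,
		outputTokens: await countTokens(accumulated),
	}
}

This matches what the streaming test asserts: three text chunks and at least two usage chunks (intermediate updates plus the final one).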