Skip to content

Commit 181993f

Browse files
authored
feat: enhance token counting by extracting text from messages using VSCode LM API (#6424)
1 parent 83280a0 commit 181993f

File tree

3 files changed

+204
-15
lines changed

3 files changed

+204
-15
lines changed

src/api/providers/vscode-lm.ts

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import type { ApiHandlerOptions } from "../../shared/api"
77
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
88

99
import { ApiStream } from "../transform/stream"
10-
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
10+
import { convertToVsCodeLmMessages, extractTextCountFromMessage } from "../transform/vscode-lm-format"
1111

1212
import { BaseProvider } from "./base-provider"
1313
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -231,7 +231,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
231231
console.debug("Roo Code <Language Model API>: Empty chat message content")
232232
return 0
233233
}
234-
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
234+
const countMessage = extractTextCountFromMessage(text)
235+
tokenCount = await this.client.countTokens(countMessage, this.currentRequestCancellation.token)
235236
} else {
236237
console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
237238
return 0
@@ -268,15 +269,10 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
268269
}
269270
}
270271

271-
private async calculateTotalInputTokens(
272-
systemPrompt: string,
273-
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
274-
): Promise<number> {
275-
const systemTokens: number = await this.internalCountTokens(systemPrompt)
276-
272+
private async calculateTotalInputTokens(vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
277273
const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.internalCountTokens(msg)))
278274

279-
return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
275+
return messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
280276
}
281277

282278
private ensureCleanState(): void {
@@ -359,7 +355,7 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
359355
this.currentRequestCancellation = new vscode.CancellationTokenSource()
360356

361357
// Calculate input tokens before starting the stream
362-
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
358+
const totalInputTokens: number = await this.calculateTotalInputTokens(vsCodeLmMessages)
363359

364360
// Accumulate the text and count at the end of the stream to reduce token counting overhead.
365361
let accumulatedText: string = ""

src/api/transform/__tests__/vscode-lm-format.spec.ts

Lines changed: 160 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
// npx vitest run src/api/transform/__tests__/vscode-lm-format.spec.ts
22

33
import { Anthropic } from "@anthropic-ai/sdk"
4+
import * as vscode from "vscode"
45

5-
import { convertToVsCodeLmMessages, convertToAnthropicRole } from "../vscode-lm-format"
6+
import { convertToVsCodeLmMessages, convertToAnthropicRole, extractTextCountFromMessage } from "../vscode-lm-format"
67

78
// Mock crypto using Vitest
89
vitest.stubGlobal("crypto", {
@@ -24,8 +25,8 @@ interface MockLanguageModelToolCallPart {
2425

2526
interface MockLanguageModelToolResultPart {
2627
type: "tool_result"
27-
toolUseId: string
28-
parts: MockLanguageModelTextPart[]
28+
callId: string
29+
content: MockLanguageModelTextPart[]
2930
}
3031

3132
// Mock vscode namespace
@@ -52,8 +53,8 @@ vitest.mock("vscode", () => {
5253
class MockLanguageModelToolResultPart {
5354
type = "tool_result"
5455
constructor(
55-
public toolUseId: string,
56-
public parts: MockLanguageModelTextPart[],
56+
public callId: string,
57+
public content: MockLanguageModelTextPart[],
5758
) {}
5859
}
5960

@@ -189,3 +190,157 @@ describe("convertToAnthropicRole", () => {
189190
expect(result).toBeNull()
190191
})
191192
})
193+
194+
describe("extractTextCountFromMessage", () => {
195+
it("should extract text from simple string content", () => {
196+
const message = {
197+
role: "user",
198+
content: "Hello world",
199+
} as any
200+
201+
const result = extractTextCountFromMessage(message)
202+
expect(result).toBe("Hello world")
203+
})
204+
205+
it("should extract text from LanguageModelTextPart", () => {
206+
const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Text content")
207+
const message = {
208+
role: "user",
209+
content: [mockTextPart],
210+
} as any
211+
212+
const result = extractTextCountFromMessage(message)
213+
expect(result).toBe("Text content")
214+
})
215+
216+
it("should extract text from multiple LanguageModelTextParts", () => {
217+
const mockTextPart1 = new (vitest.mocked(vscode).LanguageModelTextPart)("First part")
218+
const mockTextPart2 = new (vitest.mocked(vscode).LanguageModelTextPart)("Second part")
219+
const message = {
220+
role: "user",
221+
content: [mockTextPart1, mockTextPart2],
222+
} as any
223+
224+
const result = extractTextCountFromMessage(message)
225+
expect(result).toBe("First partSecond part")
226+
})
227+
228+
it("should extract text from LanguageModelToolResultPart", () => {
229+
const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Tool result content")
230+
const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("tool-result-id", [
231+
mockTextPart,
232+
])
233+
const message = {
234+
role: "user",
235+
content: [mockToolResultPart],
236+
} as any
237+
238+
const result = extractTextCountFromMessage(message)
239+
expect(result).toBe("tool-result-idTool result content")
240+
})
241+
242+
it("should extract text from LanguageModelToolCallPart without input", () => {
243+
const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool-name", {})
244+
const message = {
245+
role: "assistant",
246+
content: [mockToolCallPart],
247+
} as any
248+
249+
const result = extractTextCountFromMessage(message)
250+
expect(result).toBe("tool-namecall-id")
251+
})
252+
253+
it("should extract text from LanguageModelToolCallPart with input", () => {
254+
const mockInput = { operation: "add", numbers: [1, 2, 3] }
255+
const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)(
256+
"call-id",
257+
"calculator",
258+
mockInput,
259+
)
260+
const message = {
261+
role: "assistant",
262+
content: [mockToolCallPart],
263+
} as any
264+
265+
const result = extractTextCountFromMessage(message)
266+
expect(result).toBe(`calculatorcall-id${JSON.stringify(mockInput)}`)
267+
})
268+
269+
it("should extract text from LanguageModelToolCallPart with empty input", () => {
270+
const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool-name", {})
271+
const message = {
272+
role: "assistant",
273+
content: [mockToolCallPart],
274+
} as any
275+
276+
const result = extractTextCountFromMessage(message)
277+
expect(result).toBe("tool-namecall-id")
278+
})
279+
280+
it("should extract text from mixed content types", () => {
281+
const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Text content")
282+
const mockToolResultTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Tool result")
283+
const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [
284+
mockToolResultTextPart,
285+
])
286+
const mockInput = { param: "value" }
287+
const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool", mockInput)
288+
289+
const message = {
290+
role: "assistant",
291+
content: [mockTextPart, mockToolResultPart, mockToolCallPart],
292+
} as any
293+
294+
const result = extractTextCountFromMessage(message)
295+
expect(result).toBe(`Text contentresult-idTool resulttoolcall-id${JSON.stringify(mockInput)}`)
296+
})
297+
298+
it("should handle empty array content", () => {
299+
const message = {
300+
role: "user",
301+
content: [],
302+
} as any
303+
304+
const result = extractTextCountFromMessage(message)
305+
expect(result).toBe("")
306+
})
307+
308+
it("should handle undefined content", () => {
309+
const message = {
310+
role: "user",
311+
content: undefined,
312+
} as any
313+
314+
const result = extractTextCountFromMessage(message)
315+
expect(result).toBe("")
316+
})
317+
318+
it("should handle ToolResultPart with multiple text parts", () => {
319+
const mockTextPart1 = new (vitest.mocked(vscode).LanguageModelTextPart)("Part 1")
320+
const mockTextPart2 = new (vitest.mocked(vscode).LanguageModelTextPart)("Part 2")
321+
const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [
322+
mockTextPart1,
323+
mockTextPart2,
324+
])
325+
326+
const message = {
327+
role: "user",
328+
content: [mockToolResultPart],
329+
} as any
330+
331+
const result = extractTextCountFromMessage(message)
332+
expect(result).toBe("result-idPart 1Part 2")
333+
})
334+
335+
it("should handle ToolResultPart with empty parts array", () => {
336+
const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [])
337+
338+
const message = {
339+
role: "user",
340+
content: [mockToolResultPart],
341+
} as any
342+
343+
const result = extractTextCountFromMessage(message)
344+
expect(result).toBe("result-id")
345+
})
346+
})

src/api/transform/vscode-lm-format.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,41 @@ export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModel
155155
return null
156156
}
157157
}
158+
159+
/**
160+
* Extracts the text content from a VS Code Language Model chat message.
161+
* @param message A VS Code Language Model chat message.
162+
* @returns The extracted text content.
163+
*/
164+
export function extractTextCountFromMessage(message: vscode.LanguageModelChatMessage): string {
165+
let text = ""
166+
if (Array.isArray(message.content)) {
167+
for (const item of message.content) {
168+
if (item instanceof vscode.LanguageModelTextPart) {
169+
text += item.value
170+
}
171+
if (item instanceof vscode.LanguageModelToolResultPart) {
172+
text += item.callId
173+
for (const part of item.content) {
174+
if (part instanceof vscode.LanguageModelTextPart) {
175+
text += part.value
176+
}
177+
}
178+
}
179+
if (item instanceof vscode.LanguageModelToolCallPart) {
180+
text += item.name
181+
text += item.callId
182+
if (item.input && Object.keys(item.input).length > 0) {
183+
try {
184+
text += JSON.stringify(item.input)
185+
} catch (error) {
186+
console.error("Roo Code <Language Model API>: Failed to stringify tool call input:", error)
187+
}
188+
}
189+
}
190+
}
191+
} else if (typeof message.content === "string") {
192+
text += message.content
193+
}
194+
return text
195+
}

0 commit comments

Comments
 (0)