Skip to content

Commit 61263df

Browse files
committed
feat: enhance token counting by extracting text from messages using VSCode LM API
1 parent 6331944 commit 61263df

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

src/api/providers/vscode-lm.ts

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import type { ApiHandlerOptions } from "../../shared/api"
77
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
88

99
import { ApiStream } from "../transform/stream"
10-
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
10+
import { convertToVsCodeLmMessages, extractTextCountFromMessage } from "../transform/vscode-lm-format"
1111

1212
import { BaseProvider } from "./base-provider"
1313
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -231,7 +231,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
231231
console.debug("Roo Code <Language Model API>: Empty chat message content")
232232
return 0
233233
}
234-
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
234+
const countMessage = extractTextCountFromMessage(text)
235+
tokenCount = await this.client.countTokens(countMessage, this.currentRequestCancellation.token)
235236
} else {
236237
console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
237238
return 0
@@ -268,15 +269,10 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
268269
}
269270
}
270271

271-
private async calculateTotalInputTokens(
272-
systemPrompt: string,
273-
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
274-
): Promise<number> {
275-
const systemTokens: number = await this.internalCountTokens(systemPrompt)
276-
272+
private async calculateTotalInputTokens(vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
277273
const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.internalCountTokens(msg)))
278274

279-
return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
275+
return messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
280276
}
281277

282278
private ensureCleanState(): void {
@@ -359,7 +355,7 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
359355
this.currentRequestCancellation = new vscode.CancellationTokenSource()
360356

361357
// Calculate input tokens before starting the stream
362-
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
358+
const totalInputTokens: number = await this.calculateTotalInputTokens(vsCodeLmMessages)
363359

364360
// Accumulate the text and count at the end of the stream to reduce token counting overhead.
365361
let accumulatedText: string = ""

src/api/transform/vscode-lm-format.ts

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,32 @@ export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModel
155155
return null
156156
}
157157
}
158+
159+
export function extractTextCountFromMessage(message: vscode.LanguageModelChatMessage): string {
160+
let text = ""
161+
if (Array.isArray(message.content)) {
162+
for (const item of message.content) {
163+
if (item instanceof vscode.LanguageModelTextPart) {
164+
text += item.value
165+
}
166+
if (item instanceof vscode.LanguageModelToolResultPart) {
167+
text += item.callId
168+
for (const part of item.content) {
169+
if (part instanceof vscode.LanguageModelTextPart) {
170+
text += part.value
171+
}
172+
}
173+
}
174+
if (item instanceof vscode.LanguageModelToolCallPart) {
175+
text += item.name
176+
text += item.callId
177+
if (item.input && Object.keys(item.input).length > 0) {
178+
text += JSON.stringify(item.input)
179+
}
180+
}
181+
}
182+
} else if (typeof message.content === "string") {
183+
text += message.content
184+
}
185+
return text
186+
}

0 commit comments

Comments
 (0)