
Commit 58f6915

feat: introducing token-counting strategy using both local and API-based counting
1 parent 5bc144a commit 58f6915

File tree: 9 files changed, +316 −98 lines

src/api/index.ts

Lines changed: 9 additions & 1 deletion
@@ -54,9 +54,17 @@ export interface ApiHandler
      * but they can override this to use their native token counting endpoints
      *
      * @param content The content to count tokens for
+     * @param options Additional options for token counting
      * @returns A promise resolving to the token count
      */
-    countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number>
+    countTokens(
+        content: Array<Anthropic.Messages.ContentBlockParam>,
+        options: {
+            maxTokens?: number | null
+            effectiveThreshold?: number
+            totalTokens: number
+        },
+    ): Promise<number>
 }
 
 export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
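
For orientation, a minimal sketch of a call against the new signature. The helper name, the literal numbers, and the import path are illustrative assumptions, not part of the commit:

```ts
import { Anthropic } from "@anthropic-ai/sdk"
import type { ApiHandler } from "../index" // path assumed; adjust to the caller's location

// Hypothetical helper: `totalTokens` (tokens already in the conversation) is the
// only required option; `maxTokens` and `effectiveThreshold` tune the safety net.
async function estimatePromptTokens(handler: ApiHandler): Promise<number> {
	const blocks: Anthropic.Messages.ContentBlockParam[] = [
		{ type: "text", text: "Summarize the design doc." },
	]
	return handler.countTokens(blocks, {
		totalTokens: 42_000, // illustrative running conversation total
		maxTokens: 8_192, // illustrative response reservation
		effectiveThreshold: 80, // illustrative condense threshold (percent)
	})
}
```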

src/api/providers/anthropic.ts

Lines changed: 12 additions & 25 deletions
@@ -271,31 +271,18 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler
         return content?.type === "text" ? content.text : ""
     }
 
-    /**
-     * Counts tokens for the given content using Anthropic's API
-     *
-     * @param content The content blocks to count tokens for
-     * @returns A promise resolving to the token count
-     */
-    override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
-        try {
-            return await super.countTokens(content)
-        } catch (error) {
-            console.warn("Anthropic local token counting failed, falling back to remote API", error)
-            try {
-                const { id: model } = this.getModel()
-                const response = await this.client.messages.countTokens({
-                    model,
-                    messages: [{ role: "user", content }],
-                })
-                if (response.input_tokens !== undefined) {
-                    return response.input_tokens
-                }
-                console.warn("Anthropic remote token counting returned undefined, falling back to 0")
-            } catch (remoteError) {
-                console.warn("Anthropic remote token counting failed, falling back to 0", remoteError)
-            }
-            return 0
+    protected override async apiBasedTokenCount(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
+        const { id: model } = this.getModel()
+        console.log(`API-BASED COUNTINNNNG`)
+        const response = await this.client.messages.countTokens({
+            model,
+            messages: [{ role: "user", content }],
+        })
+
+        if (response.input_tokens === undefined) {
+            throw new Error("Anthropic remote token counting returned undefined.")
         }
+
+        return response.input_tokens
     }
 }
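
The provider contract is now a template method: return a precise count, or throw so the base class falls back to the local estimate. A hedged sketch of what a new provider would supply; the endpoint URL and `{ tokens }` response shape are invented for illustration:

```ts
import { Anthropic } from "@anthropic-ai/sdk"
import { BaseProvider } from "./base-provider"

// Hypothetical provider: only apiBasedTokenCount is shown; createMessage and
// getModel are inherited as abstract. The contract is: return a precise count,
// or throw so BaseProvider.countTokens falls back to the local tiktoken estimate.
abstract class ExampleHandler extends BaseProvider {
	protected override async apiBasedTokenCount(
		content: Array<Anthropic.Messages.ContentBlockParam>,
	): Promise<number> {
		const res = await fetch("https://api.example.com/count_tokens", {
			method: "POST",
			headers: { "content-type": "application/json" },
			body: JSON.stringify({ content }),
		})
		const body = (await res.json()) as { tokens?: number }
		if (body.tokens === undefined) {
			throw new Error("Example API returned no token count")
		}
		return body.tokens
	}
}
```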

src/api/providers/base-provider.ts

Lines changed: 108 additions & 11 deletions
@@ -2,14 +2,66 @@ import { Anthropic } from "@anthropic-ai/sdk"
 
 import type { ModelInfo } from "@roo-code/types"
 
+import { getAllowedTokens, isSafetyNetTriggered } from "../utils/context-safety"
 import type { ApiHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { ApiStream } from "../transform/stream"
-import { countTokens } from "../../utils/countTokens"
+import { countTokens as localCountTokens } from "../../utils/countTokens"
 
 /**
- * Base class for API providers that implements common functionality.
+ * A utility class to compare local token estimates with precise API counts
+ * and calculate a factor to improve estimation accuracy.
+ */
+class TokenCountComparator {
+    private static readonly MAX_SAMPLES = 20
+    private static readonly DEFAULT_SAFETY_FACTOR = 1.2
+    private static readonly ADDITIONAL_SAFETY_FACTOR = 1.0
+
+    private samples: Array<{ local: number; api: number }> = []
+    private safetyFactor = TokenCountComparator.DEFAULT_SAFETY_FACTOR
+
+    public addSample(local: number, api: number): void {
+        if (local > 0 && api > 0) {
+            this.samples.push({ local, api })
+            if (this.samples.length > TokenCountComparator.MAX_SAMPLES) {
+                this.samples.shift()
+            }
+            this.recalculateSafetyFactor()
+        }
+    }
+
+    public getSafetyFactor(): number {
+        return this.safetyFactor
+    }
+
+    private recalculateSafetyFactor(): void {
+        if (this.samples.length === 0) {
+            this.safetyFactor = TokenCountComparator.DEFAULT_SAFETY_FACTOR
+            return
+        }
+
+        const totalRatio = this.samples.reduce((sum, sample) => sum + sample.api / sample.local, 0)
+        const averageRatio = totalRatio / this.samples.length
+        this.safetyFactor = Math.max(1, averageRatio) * TokenCountComparator.ADDITIONAL_SAFETY_FACTOR
+    }
+
+    public getSampleCount(): number {
+        return this.samples.length
+    }
+
+    public getAverageRatio(): number {
+        if (this.samples.length === 0) return 1
+        const totalRatio = this.samples.reduce((sum, sample) => sum + sample.api / sample.local, 0)
+        return totalRatio / this.samples.length
+    }
+}
+
+/**
+ * Base class for API providers that implements common functionality
 */
 export abstract class BaseProvider implements ApiHandler {
+    protected isFirstRequest = true
+    protected tokenComparator = new TokenCountComparator()
+
     abstract createMessage(
         systemPrompt: string,
         messages: Anthropic.Messages.MessageParam[],
@@ -18,18 +70,63 @@ export abstract class BaseProvider implements ApiHandler
 
     abstract getModel(): { id: string; info: ModelInfo }
 
-    /**
-     * Default token counting implementation using tiktoken.
-     * Providers can override this to use their native token counting endpoints.
-     *
-     * @param content The content to count tokens for
-     * @returns A promise resolving to the token count
-     */
-    async countTokens(content: Anthropic.Messages.ContentBlockParam[]): Promise<number> {
+    // Override this function for each API provider
+    protected async apiBasedTokenCount(content: Anthropic.Messages.ContentBlockParam[]) {
+        return await localCountTokens(content, { useWorker: true })
+    }
+
+    async countTokens(
+        content: Anthropic.Messages.ContentBlockParam[],
+        options: {
+            maxTokens?: number | null
+            effectiveThreshold?: number
+            totalTokens: number
+        },
+    ): Promise<number> {
         if (content.length === 0) {
             return 0
         }
 
-        return countTokens(content, { useWorker: true })
+        const providerName = this.constructor.name
+
+        if (this.isFirstRequest) {
+            this.isFirstRequest = false
+            try {
+                const apiCount = await this.apiBasedTokenCount(content)
+                const localEstimate = await localCountTokens(content, { useWorker: true })
+                this.tokenComparator.addSample(localEstimate, apiCount)
+
+                return apiCount
+            } catch (error) {
+                const localEstimate = await localCountTokens(content, { useWorker: true })
+                return localEstimate
+            }
+        }
+
+        const localEstimate = await localCountTokens(content, { useWorker: true })
+
+        const { info } = this.getModel()
+        const contextWindow = info.contextWindow
+        const allowedTokens = getAllowedTokens(contextWindow, options.maxTokens)
+        const projectedTokens = options.totalTokens + localEstimate * this.tokenComparator.getSafetyFactor()
+
+        if (
+            isSafetyNetTriggered({
+                projectedTokens,
+                contextWindow,
+                effectiveThreshold: options.effectiveThreshold,
+                allowedTokens,
+            })
+        ) {
+            try {
+                const apiCount = await this.apiBasedTokenCount(content)
+                this.tokenComparator.addSample(localEstimate, apiCount)
+                return apiCount
+            } catch (error) {
+                return Math.ceil(localEstimate * this.tokenComparator.getSafetyFactor())
+            }
+        }
+
+        return localEstimate
     }
 }
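
For intuition, the safety factor is just the mean api/local ratio across recent samples, clamped below at 1 (ADDITIONAL_SAFETY_FACTOR is currently a no-op at 1.0). A standalone re-derivation of that arithmetic; the names below are ours, since TokenCountComparator itself is module-private:

```ts
// Illustrative re-computation of TokenCountComparator's math, not the class itself.
type Sample = { local: number; api: number }

function safetyFactor(samples: Sample[]): number {
	if (samples.length === 0) return 1.2 // DEFAULT_SAFETY_FACTOR
	const avgRatio =
		samples.reduce((sum, { local, api }) => sum + api / local, 0) / samples.length
	return Math.max(1, avgRatio) * 1.0 // ADDITIONAL_SAFETY_FACTOR
}

// If local tiktoken estimates run ~10% low against the API:
console.log(safetyFactor([{ local: 1000, api: 1100 }, { local: 500, api: 545 }])) // ≈ 1.095
// If local estimates overshoot, the factor clamps at 1 rather than shrinking counts:
console.log(safetyFactor([{ local: 1000, api: 900 }])) // 1
```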

src/api/providers/gemini.ts

Lines changed: 11 additions & 20 deletions
@@ -25,7 +25,6 @@ type GeminiHandlerOptions = ApiHandlerOptions & {
 
 export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
     protected options: ApiHandlerOptions
-
     private client: GoogleGenAI
 
     constructor({ isVertex, ...options }: GeminiHandlerOptions) {
@@ -167,26 +166,18 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
         }
     }
 
-    override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
-        try {
-            return await super.countTokens(content)
-        } catch (error) {
-            console.warn("Gemini local token counting failed, falling back to remote API", error)
-            try {
-                const { id: model } = this.getModel()
-                const response = await this.client.models.countTokens({
-                    model,
-                    contents: convertAnthropicContentToGemini(content),
-                })
-                if (response.totalTokens !== undefined) {
-                    return response.totalTokens
-                }
-                console.warn("Gemini remote token counting returned undefined, falling back to 0")
-            } catch (remoteError) {
-                console.warn("Gemini remote token counting failed, falling back to 0", remoteError)
-            }
-            return 0
+    protected override async apiBasedTokenCount(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
+        const { id: model } = this.getModel()
+        const response = await this.client.models.countTokens({
+            model,
+            contents: convertAnthropicContentToGemini(content),
+        })
+
+        if (response.totalTokens === undefined) {
+            throw new Error("Gemini API returned undefined token count")
         }
+
+        return response.totalTokens
     }
 
     public calculateCost({

src/api/providers/lm-studio.ts

Lines changed: 10 additions & 2 deletions
@@ -64,7 +64,11 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler
 
         let inputTokens = 0
         try {
-            inputTokens = await this.countTokens([{ type: "text", text: systemPrompt }, ...toContentBlocks(messages)])
+            inputTokens = await this.countTokens([{ type: "text", text: systemPrompt }, ...toContentBlocks(messages)], {
+                totalTokens: 0,
+                maxTokens: null,
+                effectiveThreshold: undefined,
+            })
         } catch (err) {
             console.error("[LmStudio] Failed to count input tokens:", err)
             inputTokens = 0
@@ -112,7 +116,11 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler
 
         let outputTokens = 0
         try {
-            outputTokens = await this.countTokens([{ type: "text", text: assistantText }])
+            outputTokens = await this.countTokens([{ type: "text", text: assistantText }], {
+                totalTokens: 0,
+                maxTokens: null,
+                effectiveThreshold: undefined,
+            })
         } catch (err) {
             console.error("[LmStudio] Failed to count output tokens:", err)
             outputTokens = 0
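
These call sites count a single payload in isolation, so they pass neutral options (with `totalTokens: 0` the projection covers only this one payload). A hypothetical shared constant could make that intent explicit; the name is ours, not the commit's:

```ts
// Hypothetical convenience constant (not in the commit) for call sites that
// count one payload with no running conversation total.
const NEUTRAL_COUNT_OPTIONS = {
	totalTokens: 0,
	maxTokens: null,
	effectiveThreshold: undefined,
} as const

// e.g. inside LmStudioHandler:
// outputTokens = await this.countTokens([{ type: "text", text: assistantText }], NEUTRAL_COUNT_OPTIONS)
```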

src/api/utils/context-safety.ts

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import { TOKEN_BUFFER_PERCENTAGE } from "../../core/sliding-window"
+
+type SafetyNetOptions = {
+    projectedTokens: number
+    contextWindow: number
+    effectiveThreshold?: number
+    allowedTokens: number
+}
+
+/**
+ * Calculates the allowed token limit for a given context window, reserving
+ * space for the response and a safety buffer.
+ *
+ * @param contextWindow The total context window size of the model.
+ * @param maxTokens The maximum number of tokens reserved for the response.
+ * @returns The number of tokens allowed for the prompt context.
+ */
+export function getAllowedTokens(contextWindow: number, maxTokens?: number | null) {
+    // Calculate the maximum tokens reserved for response
+    const reservedTokens = maxTokens ?? contextWindow * 0.2
+
+    // Calculate available tokens for conversation history
+    // Truncate if we're within TOKEN_BUFFER_PERCENTAGE of the context window
+    return contextWindow * (1 - TOKEN_BUFFER_PERCENTAGE) - reservedTokens
+}
+
+/**
+ * Determines if the token counting safety net should be triggered.
+ *
+ * The safety net is triggered if the projected token count exceeds either:
+ * 1. The effective condensation threshold (as a percentage of the context window).
+ * 2. The absolute allowed token limit.
+ *
+ * @param options The options for the safety net check.
+ * @returns True if the safety net should be triggered, false otherwise.
+ */
+export function isSafetyNetTriggered({
+    projectedTokens,
+    contextWindow,
+    effectiveThreshold,
+    allowedTokens,
+}: SafetyNetOptions): boolean {
+    // Ensure a valid threshold, defaulting to a high value if not provided,
+    // which effectively relies on the allowedTokens check.
+    const threshold = effectiveThreshold ?? 100
+    const contextPercent = (100 * projectedTokens) / contextWindow
+
+    return contextPercent >= threshold || projectedTokens > allowedTokens
+}
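
A worked numeric example of the two helpers. The value TOKEN_BUFFER_PERCENTAGE = 0.1 is an assumption for illustration, as is the import path:

```ts
import { getAllowedTokens, isSafetyNetTriggered } from "./context-safety" // path assumed

// A 200k-token model reserving 8,192 tokens for the response:
const contextWindow = 200_000
const allowedTokens = getAllowedTokens(contextWindow, 8_192)
// = 200_000 * (1 - 0.1) - 8_192 = 171_808

// 150k projected tokens is 75% of the window: under an 80% threshold and under
// allowedTokens, so no precise recount is triggered.
isSafetyNetTriggered({ projectedTokens: 150_000, contextWindow, effectiveThreshold: 80, allowedTokens }) // false

// 175k is 87.5% of the window and above allowedTokens: trigger the API recount.
isSafetyNetTriggered({ projectedTokens: 175_000, contextWindow, effectiveThreshold: 80, allowedTokens }) // true
```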

src/core/condense/index.ts

Lines changed: 7 additions & 1 deletion
@@ -198,7 +198,13 @@ export async function summarizeConversation(
         typeof message.content === "string" ? [{ text: message.content, type: "text" as const }] : message.content,
     )
 
-    const newContextTokens = outputTokens + (await apiHandler.countTokens(contextBlocks))
+    const newContextTokens =
+        outputTokens +
+        (await apiHandler.countTokens(contextBlocks, {
+            totalTokens: 0,
+            maxTokens: null,
+            effectiveThreshold: undefined,
+        }))
     if (newContextTokens >= prevContextTokens) {
         const error = t("common:errors.condense_context_grew")
         return { ...response, cost, error }
