@@ -2,14 +2,66 @@ import { Anthropic } from "@anthropic-ai/sdk"
22
33import type { ModelInfo } from "@roo-code/types"
44
5+ import { getAllowedTokens , isSafetyNetTriggered } from "../utils/context-safety"
56import type { ApiHandler , ApiHandlerCreateMessageMetadata } from "../index"
67import { ApiStream } from "../transform/stream"
7- import { countTokens } from "../../utils/countTokens"
8+ import { countTokens as localCountTokens } from "../../utils/countTokens"
89
910/**
10- * Base class for API providers that implements common functionality.
11+ * A utility class to compare local token estimates with precise API counts
12+ * and calculate a factor to improve estimation accuracy.
13+ */
14+ class TokenCountComparator {
15+ private static readonly MAX_SAMPLES = 20
16+ private static readonly DEFAULT_SAFETY_FACTOR = 1.2
17+ private static readonly ADDITIONAL_SAFETY_FACTOR = 1.0
18+
19+ private samples : Array < { local : number ; api : number } > = [ ]
20+ private safetyFactor = TokenCountComparator . DEFAULT_SAFETY_FACTOR
21+
22+ public addSample ( local : number , api : number ) : void {
23+ if ( local > 0 && api > 0 ) {
24+ this . samples . push ( { local, api } )
25+ if ( this . samples . length > TokenCountComparator . MAX_SAMPLES ) {
26+ this . samples . shift ( )
27+ }
28+ this . recalculateSafetyFactor ( )
29+ }
30+ }
31+
32+ public getSafetyFactor ( ) : number {
33+ return this . safetyFactor
34+ }
35+
36+ private recalculateSafetyFactor ( ) : void {
37+ if ( this . samples . length === 0 ) {
38+ this . safetyFactor = TokenCountComparator . DEFAULT_SAFETY_FACTOR
39+ return
40+ }
41+
42+ const totalRatio = this . samples . reduce ( ( sum , sample ) => sum + sample . api / sample . local , 0 )
43+ const averageRatio = totalRatio / this . samples . length
44+ this . safetyFactor = Math . max ( 1 , averageRatio ) * TokenCountComparator . ADDITIONAL_SAFETY_FACTOR
45+ }
46+
47+ public getSampleCount ( ) : number {
48+ return this . samples . length
49+ }
50+
51+ public getAverageRatio ( ) : number {
52+ if ( this . samples . length === 0 ) return 1
53+ const totalRatio = this . samples . reduce ( ( sum , sample ) => sum + sample . api / sample . local , 0 )
54+ return totalRatio / this . samples . length
55+ }
56+ }
57+
58+ /**
59+ * Base class for API providers that implements common functionality
1160 */
1261export abstract class BaseProvider implements ApiHandler {
62+ protected isFirstRequest = true
63+ protected tokenComparator = new TokenCountComparator ( )
64+
1365 abstract createMessage (
1466 systemPrompt : string ,
1567 messages : Anthropic . Messages . MessageParam [ ] ,
@@ -18,18 +70,63 @@ export abstract class BaseProvider implements ApiHandler {
1870
1971 abstract getModel ( ) : { id : string ; info : ModelInfo }
2072
21- /**
22- * Default token counting implementation using tiktoken.
23- * Providers can override this to use their native token counting endpoints.
24- *
25- * @param content The content to count tokens for
26- * @returns A promise resolving to the token count
27- */
28- async countTokens ( content : Anthropic . Messages . ContentBlockParam [ ] ) : Promise < number > {
73+ // Override this function for each API provider
74+ protected async apiBasedTokenCount ( content : Anthropic . Messages . ContentBlockParam [ ] ) {
75+ return await localCountTokens ( content , { useWorker : true } )
76+ }
77+
78+ async countTokens (
79+ content : Anthropic . Messages . ContentBlockParam [ ] ,
80+ options : {
81+ maxTokens ?: number | null
82+ effectiveThreshold ?: number
83+ totalTokens : number
84+ } ,
85+ ) : Promise < number > {
2986 if ( content . length === 0 ) {
3087 return 0
3188 }
3289
33- return countTokens ( content , { useWorker : true } )
90+ const providerName = this . constructor . name
91+
92+ if ( this . isFirstRequest ) {
93+ this . isFirstRequest = false
94+ try {
95+ const apiCount = await this . apiBasedTokenCount ( content )
96+ const localEstimate = await localCountTokens ( content , { useWorker : true } )
97+ this . tokenComparator . addSample ( localEstimate , apiCount )
98+
99+ return apiCount
100+ } catch ( error ) {
101+ const localEstimate = await localCountTokens ( content , { useWorker : true } )
102+ return localEstimate
103+ }
104+ }
105+
106+ const localEstimate = await localCountTokens ( content , { useWorker : true } )
107+
108+ const { info } = this . getModel ( )
109+ const contextWindow = info . contextWindow
110+ const allowedTokens = getAllowedTokens ( contextWindow , options . maxTokens )
111+ const projectedTokens = options . totalTokens + localEstimate * this . tokenComparator . getSafetyFactor ( )
112+
113+ if (
114+ isSafetyNetTriggered ( {
115+ projectedTokens,
116+ contextWindow,
117+ effectiveThreshold : options . effectiveThreshold ,
118+ allowedTokens,
119+ } )
120+ ) {
121+ try {
122+ const apiCount = await this . apiBasedTokenCount ( content )
123+ this . tokenComparator . addSample ( localEstimate , apiCount )
124+ return apiCount
125+ } catch ( error ) {
126+ return Math . ceil ( localEstimate * this . tokenComparator . getSafetyFactor ( ) )
127+ }
128+ }
129+
130+ return localEstimate
34131 }
35132}
0 commit comments