@@ -6,6 +6,7 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"
66
77import { t } from "i18next"
88import { GeminiHandler } from "../gemini"
9+ import { BaseProvider } from "../base-provider"
910
1011const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
1112
@@ -248,4 +249,102 @@ describe("GeminiHandler", () => {
248249 expect ( cost ) . toBeUndefined ( )
249250 } )
250251 } )
252+
253+ describe ( "countTokens" , ( ) => {
254+ const mockContent : Anthropic . Messages . ContentBlockParam [ ] = [
255+ {
256+ type : "text" ,
257+ text : "Hello world" ,
258+ } ,
259+ ]
260+
261+ beforeEach ( ( ) => {
262+ // Add countTokens mock to the client
263+ handler [ "client" ] . models . countTokens = vitest . fn ( )
264+ } )
265+
266+ it ( "should return token count from Gemini API when valid" , async ( ) => {
267+ // Mock successful response with valid totalTokens
268+ ; ( handler [ "client" ] . models . countTokens as any ) . mockResolvedValue ( {
269+ totalTokens : 42 ,
270+ } )
271+
272+ const result = await handler . countTokens ( mockContent )
273+ expect ( result ) . toBe ( 42 )
274+
275+ // Verify the API was called correctly
276+ expect ( handler [ "client" ] . models . countTokens ) . toHaveBeenCalledWith ( {
277+ model : GEMINI_20_FLASH_THINKING_NAME ,
278+ contents : expect . any ( Object ) ,
279+ } )
280+ } )
281+
282+ it ( "should fall back to base provider when totalTokens is undefined" , async ( ) => {
283+ // Mock response with undefined totalTokens
284+ ; ( handler [ "client" ] . models . countTokens as any ) . mockResolvedValue ( {
285+ totalTokens : undefined ,
286+ } )
287+
288+ // Spy on the base provider's countTokens method
289+ const baseCountTokensSpy = vitest . spyOn ( BaseProvider . prototype , "countTokens" )
290+ baseCountTokensSpy . mockResolvedValue ( 100 )
291+
292+ const result = await handler . countTokens ( mockContent )
293+ expect ( result ) . toBe ( 100 )
294+ expect ( baseCountTokensSpy ) . toHaveBeenCalledWith ( mockContent )
295+ } )
296+
297+ it ( "should fall back to base provider when totalTokens is null" , async ( ) => {
298+ // Mock response with null totalTokens
299+ ; ( handler [ "client" ] . models . countTokens as any ) . mockResolvedValue ( {
300+ totalTokens : null ,
301+ } )
302+
303+ // Spy on the base provider's countTokens method
304+ const baseCountTokensSpy = vitest . spyOn ( BaseProvider . prototype , "countTokens" )
305+ baseCountTokensSpy . mockResolvedValue ( 100 )
306+
307+ const result = await handler . countTokens ( mockContent )
308+ expect ( result ) . toBe ( 100 )
309+ expect ( baseCountTokensSpy ) . toHaveBeenCalledWith ( mockContent )
310+ } )
311+
312+ it ( "should fall back to base provider when totalTokens is NaN" , async ( ) => {
313+ // Mock response with NaN totalTokens
314+ ; ( handler [ "client" ] . models . countTokens as any ) . mockResolvedValue ( {
315+ totalTokens : NaN ,
316+ } )
317+
318+ // Spy on the base provider's countTokens method
319+ const baseCountTokensSpy = vitest . spyOn ( BaseProvider . prototype , "countTokens" )
320+ baseCountTokensSpy . mockResolvedValue ( 100 )
321+
322+ const result = await handler . countTokens ( mockContent )
323+ expect ( result ) . toBe ( 100 )
324+ expect ( baseCountTokensSpy ) . toHaveBeenCalledWith ( mockContent )
325+ } )
326+
327+ it ( "should return 0 when totalTokens is 0" , async ( ) => {
328+ // Mock response with 0 totalTokens - this should be valid
329+ ; ( handler [ "client" ] . models . countTokens as any ) . mockResolvedValue ( {
330+ totalTokens : 0 ,
331+ } )
332+
333+ const result = await handler . countTokens ( mockContent )
334+ expect ( result ) . toBe ( 0 )
335+ } )
336+
337+ it ( "should fall back to base provider on API error" , async ( ) => {
338+ // Mock API error
339+ ; ( handler [ "client" ] . models . countTokens as any ) . mockRejectedValue ( new Error ( "API Error" ) )
340+
341+ // Spy on the base provider's countTokens method
342+ const baseCountTokensSpy = vitest . spyOn ( BaseProvider . prototype , "countTokens" )
343+ baseCountTokensSpy . mockResolvedValue ( 100 )
344+
345+ const result = await handler . countTokens ( mockContent )
346+ expect ( result ) . toBe ( 100 )
347+ expect ( baseCountTokensSpy ) . toHaveBeenCalledWith ( mockContent )
348+ } )
349+ } )
251350} )
0 commit comments