Skip to content

Commit 9eab89d

Browse files
committed
fix: correct Gemini token counting and context window for gemini-2.5-pro
- Fix token counting in GeminiHandler to properly validate response.totalTokens - Update context window for gemini-2.5-pro from 1,048,576 to 249,500 tokens - Add comprehensive tests for countTokens method Fixes #6891
1 parent 3ee6072 commit 9eab89d

File tree

3 files changed

+103
-3
lines changed

3 files changed

+103
-3
lines changed

packages/types/src/providers/gemini.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ export const geminiModels = {
144144
},
145145
"gemini-2.5-pro": {
146146
maxTokens: 64_000,
147-
contextWindow: 1_048_576,
147+
contextWindow: 249_500,
148148
supportsImages: true,
149149
supportsPromptCache: true,
150150
inputPrice: 2.5, // This is the pricing for prompts above 200k tokens.

src/api/providers/__tests__/gemini.spec.ts

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"
66

77
import { t } from "i18next"
88
import { GeminiHandler } from "../gemini"
9+
import { BaseProvider } from "../base-provider"
910

1011
const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
1112

@@ -248,4 +249,102 @@ describe("GeminiHandler", () => {
248249
expect(cost).toBeUndefined()
249250
})
250251
})
252+
253+
describe("countTokens", () => {
254+
const mockContent: Anthropic.Messages.ContentBlockParam[] = [
255+
{
256+
type: "text",
257+
text: "Hello world",
258+
},
259+
]
260+
261+
beforeEach(() => {
262+
// Add countTokens mock to the client
263+
handler["client"].models.countTokens = vitest.fn()
264+
})
265+
266+
it("should return token count from Gemini API when valid", async () => {
267+
// Mock successful response with valid totalTokens
268+
;(handler["client"].models.countTokens as any).mockResolvedValue({
269+
totalTokens: 42,
270+
})
271+
272+
const result = await handler.countTokens(mockContent)
273+
expect(result).toBe(42)
274+
275+
// Verify the API was called correctly
276+
expect(handler["client"].models.countTokens).toHaveBeenCalledWith({
277+
model: GEMINI_20_FLASH_THINKING_NAME,
278+
contents: expect.any(Object),
279+
})
280+
})
281+
282+
it("should fall back to base provider when totalTokens is undefined", async () => {
283+
// Mock response with undefined totalTokens
284+
;(handler["client"].models.countTokens as any).mockResolvedValue({
285+
totalTokens: undefined,
286+
})
287+
288+
// Spy on the base provider's countTokens method
289+
const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
290+
baseCountTokensSpy.mockResolvedValue(100)
291+
292+
const result = await handler.countTokens(mockContent)
293+
expect(result).toBe(100)
294+
expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent)
295+
})
296+
297+
it("should fall back to base provider when totalTokens is null", async () => {
298+
// Mock response with null totalTokens
299+
;(handler["client"].models.countTokens as any).mockResolvedValue({
300+
totalTokens: null,
301+
})
302+
303+
// Spy on the base provider's countTokens method
304+
const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
305+
baseCountTokensSpy.mockResolvedValue(100)
306+
307+
const result = await handler.countTokens(mockContent)
308+
expect(result).toBe(100)
309+
expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent)
310+
})
311+
312+
it("should fall back to base provider when totalTokens is NaN", async () => {
313+
// Mock response with NaN totalTokens
314+
;(handler["client"].models.countTokens as any).mockResolvedValue({
315+
totalTokens: NaN,
316+
})
317+
318+
// Spy on the base provider's countTokens method
319+
const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
320+
baseCountTokensSpy.mockResolvedValue(100)
321+
322+
const result = await handler.countTokens(mockContent)
323+
expect(result).toBe(100)
324+
expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent)
325+
})
326+
327+
it("should return 0 when totalTokens is 0", async () => {
328+
// Mock response with 0 totalTokens - this should be valid
329+
;(handler["client"].models.countTokens as any).mockResolvedValue({
330+
totalTokens: 0,
331+
})
332+
333+
const result = await handler.countTokens(mockContent)
334+
expect(result).toBe(0)
335+
})
336+
337+
it("should fall back to base provider on API error", async () => {
338+
// Mock API error
339+
;(handler["client"].models.countTokens as any).mockRejectedValue(new Error("API Error"))
340+
341+
// Spy on the base provider's countTokens method
342+
const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
343+
baseCountTokensSpy.mockResolvedValue(100)
344+
345+
const result = await handler.countTokens(mockContent)
346+
expect(result).toBe(100)
347+
expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent)
348+
})
349+
})
251350
})

src/api/providers/gemini.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
253253
contents: convertAnthropicContentToGemini(content),
254254
})
255255

256-
if (response.totalTokens === undefined) {
257-
console.warn("Gemini token counting returned undefined, using fallback")
256+
// Check if totalTokens is a valid number (not undefined, null, or NaN)
257+
if (typeof response.totalTokens !== "number" || isNaN(response.totalTokens)) {
258+
console.warn("Gemini token counting returned invalid value, using fallback", response.totalTokens)
258259
return super.countTokens(content)
259260
}
260261

0 commit comments

Comments
 (0)