diff --git a/packages/types/src/providers/gemini.ts b/packages/types/src/providers/gemini.ts index a7225c7330..f7e6fd2c00 100644 --- a/packages/types/src/providers/gemini.ts +++ b/packages/types/src/providers/gemini.ts @@ -144,7 +144,7 @@ export const geminiModels = { }, "gemini-2.5-pro": { maxTokens: 64_000, - contextWindow: 1_048_576, + contextWindow: 249_500, supportsImages: true, supportsPromptCache: true, inputPrice: 2.5, // This is the pricing for prompts above 200k tokens. diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts index 812c1ae1a6..48e385ead3 100644 --- a/src/api/providers/__tests__/gemini.spec.ts +++ b/src/api/providers/__tests__/gemini.spec.ts @@ -6,6 +6,7 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types" import { t } from "i18next" import { GeminiHandler } from "../gemini" +import { BaseProvider } from "../base-provider" const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219" @@ -248,4 +249,102 @@ describe("GeminiHandler", () => { expect(cost).toBeUndefined() }) }) + + describe("countTokens", () => { + const mockContent: Anthropic.Messages.ContentBlockParam[] = [ + { + type: "text", + text: "Hello world", + }, + ] + + beforeEach(() => { + // Add countTokens mock to the client + handler["client"].models.countTokens = vitest.fn() + }) + + it("should return token count from Gemini API when valid", async () => { + // Mock successful response with valid totalTokens + ;(handler["client"].models.countTokens as any).mockResolvedValue({ + totalTokens: 42, + }) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(42) + + // Verify the API was called correctly + expect(handler["client"].models.countTokens).toHaveBeenCalledWith({ + model: GEMINI_20_FLASH_THINKING_NAME, + contents: expect.any(Object), + }) + }) + + it("should fall back to base provider when totalTokens is undefined", async () => { + // Mock response with undefined totalTokens + ;(handler["client"].models.countTokens as any).mockResolvedValue({ + totalTokens: undefined, + }) + + // Spy on the base provider's countTokens method + const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens") + baseCountTokensSpy.mockResolvedValue(100) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(100) + expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent) + }) + + it("should fall back to base provider when totalTokens is null", async () => { + // Mock response with null totalTokens + ;(handler["client"].models.countTokens as any).mockResolvedValue({ + totalTokens: null, + }) + + // Spy on the base provider's countTokens method + const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens") + baseCountTokensSpy.mockResolvedValue(100) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(100) + expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent) + }) + + it("should fall back to base provider when totalTokens is NaN", async () => { + // Mock response with NaN totalTokens + ;(handler["client"].models.countTokens as any).mockResolvedValue({ + totalTokens: NaN, + }) + + // Spy on the base provider's countTokens method + const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens") + baseCountTokensSpy.mockResolvedValue(100) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(100) + expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent) + }) + + it("should return 0 when totalTokens is 0", async () => { + // Mock response with 0 totalTokens - this should be valid + ;(handler["client"].models.countTokens as any).mockResolvedValue({ + totalTokens: 0, + }) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(0) + }) + + it("should fall back to base provider on API error", async () => { + // Mock API error + ;(handler["client"].models.countTokens as any).mockRejectedValue(new Error("API Error")) + + // Spy on the base provider's countTokens method + const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens") + baseCountTokensSpy.mockResolvedValue(100) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(100) + expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent) + }) + }) }) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 5e547edbdc..701e963e8e 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -253,8 +253,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl contents: convertAnthropicContentToGemini(content), }) - if (response.totalTokens === undefined) { - console.warn("Gemini token counting returned undefined, using fallback") + // Check if totalTokens is a valid number (not undefined, null, or NaN) + if (typeof response.totalTokens !== "number" || isNaN(response.totalTokens)) { + console.warn("Gemini token counting returned invalid value, using fallback", response.totalTokens) return super.countTokens(content) }