From 18995aac46c76259018a89b45a757e2e6cd98ef9 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Thu, 18 Sep 2025 01:10:28 +0000 Subject: [PATCH] fix: correct Gemini countTokens payload format to use Content[] instead of Part[] - Fixed countTokens method to wrap parts in proper Content structure with user role - Added comprehensive tests for countTokens including multimodal content and fallback scenarios - This resolves premature context truncation due to incorrect token counting Fixes #8113 --- src/api/providers/__tests__/gemini.spec.ts | 135 +++++++++++++++++++++ src/api/providers/gemini.ts | 4 +- 2 files changed, 138 insertions(+), 1 deletion(-) diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts index 812c1ae1a6..2299747afb 100644 --- a/src/api/providers/__tests__/gemini.spec.ts +++ b/src/api/providers/__tests__/gemini.spec.ts @@ -6,6 +6,7 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types" import { t } from "i18next" import { GeminiHandler } from "../gemini" +import { BaseProvider } from "../base-provider" const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219" @@ -248,4 +249,138 @@ describe("GeminiHandler", () => { expect(cost).toBeUndefined() }) }) + + describe("countTokens", () => { + it("should count tokens successfully with correct Content[] format", async () => { + // Mock the countTokens response + const mockCountTokens = vitest.fn().mockResolvedValue({ + totalTokens: 42, + }) + + handler["client"].models.countTokens = mockCountTokens + + const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello world" }] + + const result = await handler.countTokens(content) + expect(result).toBe(42) + + // Verify the call was made with correct Content[] format + expect(mockCountTokens).toHaveBeenCalledWith({ + model: GEMINI_20_FLASH_THINKING_NAME, + contents: [ + { + role: "user", + parts: [{ text: "Hello world" }], + }, + ], + }) + }) + + it("should handle multimodal content correctly", async () => { + const mockCountTokens = vitest.fn().mockResolvedValue({ + totalTokens: 100, + }) + + handler["client"].models.countTokens = mockCountTokens + + const content: Anthropic.Messages.ContentBlockParam[] = [ + { type: "text", text: "Describe this image:" }, + { + type: "image", + source: { + type: "base64", + media_type: "image/jpeg", + data: "base64data", + }, + }, + ] + + const result = await handler.countTokens(content) + expect(result).toBe(100) + + // Verify the Content[] structure with mixed content + expect(mockCountTokens).toHaveBeenCalledWith({ + model: GEMINI_20_FLASH_THINKING_NAME, + contents: [ + { + role: "user", + parts: [ + { text: "Describe this image:" }, + { inlineData: { data: "base64data", mimeType: "image/jpeg" } }, + ], + }, + ], + }) + }) + + it("should fall back to base provider when SDK returns undefined", async () => { + // Mock countTokens to return undefined totalTokens + const mockCountTokens = vitest.fn().mockResolvedValue({ + totalTokens: undefined, + }) + + handler["client"].models.countTokens = mockCountTokens + + // Spy on the parent class method + const superCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens") + + const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }] + + await handler.countTokens(content) + + // Verify fallback was called + expect(superCountTokensSpy).toHaveBeenCalledWith(content) + }) + + it("should fall back to base provider when SDK throws error", async () => { + // Mock countTokens to throw an error + const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API error")) + + handler["client"].models.countTokens = mockCountTokens + + // Spy on console.warn + const consoleWarnSpy = vitest.spyOn(console, "warn").mockImplementation(() => {}) + + // Spy on the parent class method + const superCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens") + + const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }] + + await handler.countTokens(content) + + // Verify warning was logged + expect(consoleWarnSpy).toHaveBeenCalledWith( + "Gemini token counting failed, using fallback", + expect.any(Error), + ) + + // Verify fallback was called + expect(superCountTokensSpy).toHaveBeenCalledWith(content) + + // Clean up + consoleWarnSpy.mockRestore() + }) + + it("should handle empty content array", async () => { + const mockCountTokens = vitest.fn().mockResolvedValue({ + totalTokens: 0, + }) + + handler["client"].models.countTokens = mockCountTokens + + const result = await handler.countTokens([]) + expect(result).toBe(0) + + // Verify the call with empty parts + expect(mockCountTokens).toHaveBeenCalledWith({ + model: GEMINI_20_FLASH_THINKING_NAME, + contents: [ + { + role: "user", + parts: [], + }, + ], + }) + }) + }) }) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 775d763a05..e7d7235e7b 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -258,9 +258,11 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl try { const { id: model } = this.getModel() + // Wrap the parts in a proper Content structure with user role + // The SDK expects Content[] format: [{ role: "user", parts: Part[] }] const response = await this.client.models.countTokens({ model, - contents: convertAnthropicContentToGemini(content), + contents: [{ role: "user", parts: convertAnthropicContentToGemini(content) }], }) if (response.totalTokens === undefined) {