Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions src/api/providers/__tests__/gemini.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"

import { t } from "i18next"
import { GeminiHandler } from "../gemini"
import { BaseProvider } from "../base-provider"

const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"

Expand Down Expand Up @@ -248,4 +249,138 @@ describe("GeminiHandler", () => {
expect(cost).toBeUndefined()
})
})

describe("countTokens", () => {
it("should count tokens successfully with correct Content[] format", async () => {
// Mock the countTokens response
const mockCountTokens = vitest.fn().mockResolvedValue({
totalTokens: 42,
})

handler["client"].models.countTokens = mockCountTokens

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello world" }]

const result = await handler.countTokens(content)
expect(result).toBe(42)

// Verify the call was made with correct Content[] format
expect(mockCountTokens).toHaveBeenCalledWith({
model: GEMINI_20_FLASH_THINKING_NAME,
contents: [
{
role: "user",
parts: [{ text: "Hello world" }],
},
],
})
})

it("should handle multimodal content correctly", async () => {
const mockCountTokens = vitest.fn().mockResolvedValue({
totalTokens: 100,
})

handler["client"].models.countTokens = mockCountTokens

const content: Anthropic.Messages.ContentBlockParam[] = [
{ type: "text", text: "Describe this image:" },
{
type: "image",
source: {
type: "base64",
media_type: "image/jpeg",
data: "base64data",
},
},
]

const result = await handler.countTokens(content)
expect(result).toBe(100)

// Verify the Content[] structure with mixed content
expect(mockCountTokens).toHaveBeenCalledWith({
model: GEMINI_20_FLASH_THINKING_NAME,
contents: [
{
role: "user",
parts: [
{ text: "Describe this image:" },
{ inlineData: { data: "base64data", mimeType: "image/jpeg" } },
],
},
],
})
})

it("should fall back to base provider when SDK returns undefined", async () => {
// Mock countTokens to return undefined totalTokens
const mockCountTokens = vitest.fn().mockResolvedValue({
totalTokens: undefined,
})

handler["client"].models.countTokens = mockCountTokens

// Spy on the parent class method
const superCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]

await handler.countTokens(content)

// Verify fallback was called
expect(superCountTokensSpy).toHaveBeenCalledWith(content)
})

it("should fall back to base provider when SDK throws error", async () => {
// Mock countTokens to throw an error
const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API error"))

handler["client"].models.countTokens = mockCountTokens

// Spy on console.warn
const consoleWarnSpy = vitest.spyOn(console, "warn").mockImplementation(() => {})

// Spy on the parent class method
const superCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]

await handler.countTokens(content)

// Verify warning was logged
expect(consoleWarnSpy).toHaveBeenCalledWith(
"Gemini token counting failed, using fallback",
expect.any(Error),
)

// Verify fallback was called
expect(superCountTokensSpy).toHaveBeenCalledWith(content)

// Clean up
consoleWarnSpy.mockRestore()
})

it("should handle empty content array", async () => {
const mockCountTokens = vitest.fn().mockResolvedValue({
totalTokens: 0,
})

handler["client"].models.countTokens = mockCountTokens

const result = await handler.countTokens([])
expect(result).toBe(0)

// Verify the call with empty parts
expect(mockCountTokens).toHaveBeenCalledWith({
model: GEMINI_20_FLASH_THINKING_NAME,
contents: [
{
role: "user",
parts: [],
},
],
})
})
})
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a test case with tool_use/tool_result content blocks? The issue specifically mentions functionCall/functionResponse parts can cause problems with the wrong format. Would be good to ensure our fix handles those cases too.

})
4 changes: 3 additions & 1 deletion src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,11 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
try {
const { id: model } = this.getModel()

// Wrap the parts in a proper Content structure with user role
// The SDK expects Content[] format: [{ role: "user", parts: Part[] }]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this comment slightly more detailed? Something like:

Suggested change
// The SDK expects Content[] format: [{ role: "user", parts: Part[] }]
// Wrap the parts in a proper Content structure with user role
// The SDK expects Content[] format: [{ role: "user", parts: Part[] }]
// Previously we were sending Part[] directly, causing the SDK to return undefined

This would help future maintainers understand why this specific format is critical.

const response = await this.client.models.countTokens({
model,
contents: convertAnthropicContentToGemini(content),
contents: [{ role: "user", parts: convertAnthropicContentToGemini(content) }],
})

if (response.totalTokens === undefined) {
Expand Down
Loading