Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/types/src/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ export const geminiModels = {
},
"gemini-2.5-pro": {
maxTokens: 64_000,
contextWindow: 1_048_576,
contextWindow: 249_500,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the context window intentionally set to 249,500 instead of the round 250,000 mentioned in the issue? This 500-token difference might be deliberate for a safety buffer, but wanted to confirm.

supportsImages: true,
supportsPromptCache: true,
inputPrice: 2.5, // This is the pricing for prompts above 200k tokens.
Expand Down
99 changes: 99 additions & 0 deletions src/api/providers/__tests__/gemini.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"

import { t } from "i18next"
import { GeminiHandler } from "../gemini"
import { BaseProvider } from "../base-provider"

const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"

Expand Down Expand Up @@ -248,4 +249,102 @@ describe("GeminiHandler", () => {
expect(cost).toBeUndefined()
})
})

describe("countTokens", () => {
const mockContent: Anthropic.Messages.ContentBlockParam[] = [
{
type: "text",
text: "Hello world",
},
]

beforeEach(() => {
// Add countTokens mock to the client
handler["client"].models.countTokens = vitest.fn()
})

it("should return token count from Gemini API when valid", async () => {
// Mock successful response with valid totalTokens
;(handler["client"].models.countTokens as any).mockResolvedValue({
totalTokens: 42,
})

const result = await handler.countTokens(mockContent)
expect(result).toBe(42)

// Verify the API was called correctly
expect(handler["client"].models.countTokens).toHaveBeenCalledWith({
model: GEMINI_20_FLASH_THINKING_NAME,
contents: expect.any(Object),
})
})

it("should fall back to base provider when totalTokens is undefined", async () => {
// Mock response with undefined totalTokens
;(handler["client"].models.countTokens as any).mockResolvedValue({
totalTokens: undefined,
})

// Spy on the base provider's countTokens method
const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
baseCountTokensSpy.mockResolvedValue(100)

const result = await handler.countTokens(mockContent)
expect(result).toBe(100)
expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent)
})

it("should fall back to base provider when totalTokens is null", async () => {
// Mock response with null totalTokens
;(handler["client"].models.countTokens as any).mockResolvedValue({
totalTokens: null,
})

// Spy on the base provider's countTokens method
const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
baseCountTokensSpy.mockResolvedValue(100)

const result = await handler.countTokens(mockContent)
expect(result).toBe(100)
expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent)
})

it("should fall back to base provider when totalTokens is NaN", async () => {
// Mock response with NaN totalTokens
;(handler["client"].models.countTokens as any).mockResolvedValue({
totalTokens: NaN,
})

// Spy on the base provider's countTokens method
const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
baseCountTokensSpy.mockResolvedValue(100)

const result = await handler.countTokens(mockContent)
expect(result).toBe(100)
expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent)
})

it("should return 0 when totalTokens is 0", async () => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent test coverage! These tests thoroughly validate all the edge cases (undefined, null, NaN, 0, and API errors). The fallback mechanism is well-tested.

// Mock response with 0 totalTokens - this should be valid
;(handler["client"].models.countTokens as any).mockResolvedValue({
totalTokens: 0,
})

const result = await handler.countTokens(mockContent)
expect(result).toBe(0)
})

it("should fall back to base provider on API error", async () => {
// Mock API error
;(handler["client"].models.countTokens as any).mockRejectedValue(new Error("API Error"))

// Spy on the base provider's countTokens method
const baseCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
baseCountTokensSpy.mockResolvedValue(100)

const result = await handler.countTokens(mockContent)
expect(result).toBe(100)
expect(baseCountTokensSpy).toHaveBeenCalledWith(mockContent)
})
})
})
5 changes: 3 additions & 2 deletions src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
contents: convertAnthropicContentToGemini(content),
})

if (response.totalTokens === undefined) {
console.warn("Gemini token counting returned undefined, using fallback")
// Check if totalTokens is a valid number (not undefined, null, or NaN)
if (typeof response.totalTokens !== "number" || isNaN(response.totalTokens)) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good improvement on the validation! Though I wonder if we should also check for negative numbers explicitly? The current isNaN() check should catch most edge cases, but typeof response.totalTokens === 'number' && response.totalTokens >= 0 would be even more defensive.

console.warn("Gemini token counting returned invalid value, using fallback", response.totalTokens)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider including the model ID in the warning for easier debugging:

Suggested change
console.warn("Gemini token counting returned invalid value, using fallback", response.totalTokens)
console.warn(`Gemini token counting for ${model} returned invalid value, using fallback`, response.totalTokens)

return super.countTokens(content)
}

Expand Down
Loading