186 changes: 171 additions & 15 deletions src/api/providers/__tests__/gemini.test.ts
@@ -5,37 +5,63 @@ import { Anthropic } from "@anthropic-ai/sdk"
import { GeminiHandler } from "../gemini"
import { geminiDefaultModelId } from "../../../shared/api"

const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
const GEMINI_THINKING_MODEL = "gemini-2.5-flash-preview-04-17:thinking"

describe("GeminiHandler", () => {
let handler: GeminiHandler
let thinkingHandler: GeminiHandler
let proHandler: GeminiHandler

beforeEach(() => {
// Create mock functions
const mockGenerateContentStream = jest.fn()
const mockGenerateContent = jest.fn()
const mockGetGenerativeModel = jest.fn()

// Regular handler without thinking capabilities
handler = new GeminiHandler({
apiKey: "test-key",
apiModelId: GEMINI_20_FLASH_THINKING_NAME,
apiModelId: "gemini-2.5-flash-preview-04-17", // Non-thinking model
geminiApiKey: "test-key",
})

// Replace the client with our mock
handler["client"] = {
// Handler with thinking capabilities
thinkingHandler = new GeminiHandler({
apiKey: "test-key",
apiModelId: GEMINI_THINKING_MODEL,
geminiApiKey: "test-key",
})

// Pro handler with different capabilities and pricing
proHandler = new GeminiHandler({
apiKey: "test-key",
apiModelId: "gemini-2.5-pro-preview-03-25",
geminiApiKey: "test-key",
})

// Replace the clients with our mocks
const mockClient = {
models: {
generateContentStream: mockGenerateContentStream,
generateContent: mockGenerateContent,
getGenerativeModel: mockGetGenerativeModel,
},
} as any

handler["client"] = mockClient
thinkingHandler["client"] = { ...mockClient }
proHandler["client"] = { ...mockClient }
})

describe("constructor", () => {
it("should initialize with provided config", () => {
expect(handler["options"].geminiApiKey).toBe("test-key")
expect(handler["options"].apiModelId).toBe(GEMINI_20_FLASH_THINKING_NAME)
// Regular handler should have non-thinking model
expect(handler["options"].apiModelId).toBe("gemini-2.5-flash-preview-04-17")
// Thinking handler should have thinking model
expect(thinkingHandler["options"].apiModelId).toBe(GEMINI_THINKING_MODEL)
// Pro handler should have pro model
expect(proHandler["options"].apiModelId).toBe("gemini-2.5-pro-preview-03-25")
})
})

@@ -53,17 +79,18 @@ describe("GeminiHandler", () => {

const systemPrompt = "You are a helpful assistant"

it("should handle text messages correctly", async () => {
it("should handle text messages without thinking capabilities correctly", async () => {
// Setup the mock implementation to return an async generator
;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
const mockGenerateContentStream = handler["client"].models.generateContentStream as jest.Mock
mockGenerateContentStream.mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield { text: "Hello" }
yield { text: " world!" }
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
},
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const stream = handler.createMessage(systemPrompt, mockMessages) // Using standard handler without thinking capabilities
const chunks = []

for await (const chunk of stream) {
@@ -84,12 +111,119 @@
type: "usage",
inputTokens: 10,
outputTokens: 5,
thoughtsTokenCount: undefined, // thoughtsTokenCount should be undefined when not thinking
thinkingBudget: undefined, // Added expected field
})

// Verify the call to generateContentStream
expect(thinkingHandler["client"].models.generateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
model: "gemini-2.5-flash-preview-04-17",
config: expect.objectContaining({
temperature: 0,
systemInstruction: systemPrompt,
}),
}),
)
})

it("should handle text messages with thinking capabilities correctly", async () => {
// Setup the mock implementation with thinking tokens for the thinking handler
const mockGenerateContentStream = thinkingHandler["client"].models.generateContentStream as jest.Mock
mockGenerateContentStream.mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield { text: "Thinking..." }
yield {
usageMetadata: {
promptTokenCount: 10,
candidatesTokenCount: 5,
thoughtsTokenCount: 25,
},
}
},
})

const stream = thinkingHandler.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Should have 2 chunks: 'Thinking...' and usage info with thinking tokens
expect(chunks.length).toBe(2)
expect(chunks[0]).toEqual({
type: "text",
text: "Thinking...",
})
expect(chunks[1]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 5,
thoughtsTokenCount: 25,
thinkingBudget: 24_576, // From gemini-2.5-flash-preview-04-17:thinking model info
})

// Verify the call includes thinkingConfig
expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
model: GEMINI_20_FLASH_THINKING_NAME,
model: "gemini-2.5-flash-preview-04-17",
config: expect.objectContaining({
temperature: 0,
systemInstruction: systemPrompt,
thinkingConfig: {
thinkingBudget: 24_576,
},
}),
}),
)
})

it("should handle text messages with pro model correctly", async () => {
// Setup the mock implementation for pro model
const mockGenerateContentStream = proHandler["client"].models.generateContentStream as jest.Mock
mockGenerateContentStream.mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield { text: "Pro model" }
yield { text: " response" }
yield {
usageMetadata: {
promptTokenCount: 15,
candidatesTokenCount: 8,
},
}
},
})

const stream = proHandler.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Should have 3 chunks: 'Pro model', ' response', and usage info
expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({
type: "text",
text: "Pro model",
})
expect(chunks[1]).toEqual({
type: "text",
text: " response",
})
expect(chunks[2]).toEqual({
type: "usage",
inputTokens: 15,
outputTokens: 8,
thoughtsTokenCount: undefined,
thinkingBudget: undefined,
})

// Verify the call to generateContentStream
expect(proHandler["client"].models.generateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
model: "gemini-2.5-pro-preview-03-25",
config: expect.objectContaining({
temperature: 0,
systemInstruction: systemPrompt,
@@ -113,7 +247,7 @@ describe("GeminiHandler", () => {
})

describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
it("should complete prompt successfully with non-thinking model", async () => {
// Mock the response with text property
;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
text: "Test response",
@@ -124,7 +258,7 @@

// Verify the call to generateContent
expect(handler["client"].models.generateContent).toHaveBeenCalledWith({
model: GEMINI_20_FLASH_THINKING_NAME,
model: "gemini-2.5-flash-preview-04-17", // Use the non-thinking model ID
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
config: {
httpOptions: undefined,
@@ -154,12 +288,34 @@
})

describe("getModel", () => {
it("should return correct model info", () => {
it("should return correct model info for non-thinking model", () => {
const modelInfo = handler.getModel()
expect(modelInfo.id).toBe(GEMINI_20_FLASH_THINKING_NAME)
expect(modelInfo.id).toBe("gemini-2.5-flash-preview-04-17")
expect(modelInfo.info).toBeDefined()
expect(modelInfo.thinkingConfig).toBeUndefined()
expect(modelInfo.info.maxTokens).toBe(65_535)
expect(modelInfo.info.contextWindow).toBe(1_048_576)
})

it("should return correct model info for thinking model", () => {
const modelInfo = thinkingHandler.getModel()
expect(modelInfo.id).toBe("gemini-2.5-flash-preview-04-17")
expect(modelInfo.info).toBeDefined()
expect(modelInfo.thinkingConfig).toBeDefined()
expect(modelInfo.thinkingConfig?.thinkingBudget).toBe(24_576)
expect(modelInfo.info.maxTokens).toBe(65_535)
expect(modelInfo.info.contextWindow).toBe(1_048_576)
})

it("should return correct model info for pro model", () => {
const modelInfo = proHandler.getModel()
expect(modelInfo.id).toBe("gemini-2.5-pro-preview-03-25")
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(8192)
expect(modelInfo.info.contextWindow).toBe(32_767)
expect(modelInfo.thinkingConfig).toBeUndefined()
expect(modelInfo.info.maxTokens).toBe(65_535)
expect(modelInfo.info.contextWindow).toBe(1_048_576)
expect(modelInfo.info.inputPrice).toBe(2.5)
expect(modelInfo.info.outputPrice).toBe(15)
})

it("should return default model if invalid model specified", () => {
29 changes: 19 additions & 10 deletions src/api/providers/gemini.ts
@@ -10,7 +10,7 @@ import { SingleCompletionHandler } from "../"
import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
import { geminiDefaultModelId, geminiModels } from "../../shared/api"
import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
import type { ApiStream } from "../transform/stream"
import type { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { BaseProvider } from "./base-provider"

export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
@@ -59,7 +59,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
type: "usage",
inputTokens: lastUsageMetadata.promptTokenCount ?? 0,
outputTokens: lastUsageMetadata.candidatesTokenCount ?? 0,
}
thoughtsTokenCount: lastUsageMetadata.thoughtsTokenCount ?? undefined,
thinkingBudget: thinkingConfig?.thinkingBudget,
} satisfies ApiStreamUsageChunk
}
}

@@ -70,20 +72,27 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
maxOutputTokens?: number
} {
let id = this.options.apiModelId ? (this.options.apiModelId as GeminiModelId) : geminiDefaultModelId
let info: ModelInfo = geminiModels[id]
const thinkingSuffix = ":thinking"
let thinkingConfig: ThinkingConfig | undefined = undefined
let maxOutputTokens: number | undefined = undefined

const thinkingSuffix = ":thinking"
// If this is a thinking model, get the info before modifying the ID
// so we can access the maxThinkingTokens value
let info: ModelInfo = geminiModels[id]
const originalInfo = id?.endsWith(thinkingSuffix) ? info : undefined

if (id?.endsWith(thinkingSuffix)) {
if (originalInfo) {
console.log("modelMaxThinkingTokens debug:", {
value: this.options.modelMaxThinkingTokens,
type: typeof this.options.modelMaxThinkingTokens,
infoMaxThinkingTokens: originalInfo.maxThinkingTokens,
})
const maxThinkingTokens = this.options.modelMaxThinkingTokens ?? originalInfo.maxThinkingTokens ?? 4096
thinkingConfig = { thinkingBudget: maxThinkingTokens }

// Remove thinking suffix and get base model info
id = id.slice(0, -thinkingSuffix.length) as GeminiModelId
info = geminiModels[id]

thinkingConfig = this.options.modelMaxThinkingTokens
? { thinkingBudget: this.options.modelMaxThinkingTokens }
: undefined

maxOutputTokens = this.options.modelMaxTokens ?? info.maxTokens ?? undefined
}

2 changes: 2 additions & 0 deletions src/api/transform/stream.ts
@@ -18,4 +18,6 @@ export interface ApiStreamUsageChunk {
cacheWriteTokens?: number
cacheReadTokens?: number
totalCost?: number // openrouter
thoughtsTokenCount?: number
thinkingBudget?: number
}
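
For reviewers, a minimal sketch of how a consumer might read the two new optional usage fields. It assumes the `ApiStream` discriminated-union type exported alongside `ApiStreamUsageChunk` (import path illustrative), and it simply mirrors the accumulation logic added to `Cline.ts` further down; it is not part of this diff.

```typescript
// Illustrative sketch only (not part of this PR): folding the new optional
// usage fields out of an ApiStream, mirroring the Cline.ts accumulation below.
import type { ApiStream } from "./stream" // path assumed relative to src/api/transform

async function collectThinkingUsage(stream: ApiStream) {
	let thoughtsTokenCount: number | undefined
	let thinkingBudget: number | undefined

	for await (const chunk of stream) {
		if (chunk.type !== "usage") continue
		// Thought tokens may arrive across several usage chunks, so they are summed;
		// the thinking budget is a per-request setting, so the last reported value wins.
		if (typeof chunk.thoughtsTokenCount === "number") {
			thoughtsTokenCount = (thoughtsTokenCount ?? 0) + chunk.thoughtsTokenCount
		}
		thinkingBudget = chunk.thinkingBudget
	}

	return { thoughtsTokenCount, thinkingBudget }
}
```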
8 changes: 8 additions & 0 deletions src/core/Cline.ts
@@ -1677,6 +1677,8 @@ export class Cline extends EventEmitter<ClineEvents> {
let inputTokens = 0
let outputTokens = 0
let totalCost: number | undefined
let thoughtsTokenCount: number | undefined
let thinkingBudget: number | undefined

// update api_req_started. we can't use api_req_finished anymore since it's a unique case where it could come after a streaming message (ie in the middle of being updated or executed)
// fortunately api_req_finished was always parsed out for the gui anyways, so it remains solely for legacy purposes to keep track of prices in tasks from history
@@ -1688,6 +1690,8 @@
tokensOut: outputTokens,
cacheWrites: cacheWriteTokens,
cacheReads: cacheReadTokens,
thoughtsTokenCount,
thinkingBudget: thinkingBudget,
cost:
totalCost ??
calculateApiCostAnthropic(
@@ -1781,6 +1785,10 @@
cacheWriteTokens += chunk.cacheWriteTokens ?? 0
cacheReadTokens += chunk.cacheReadTokens ?? 0
totalCost = chunk.totalCost
if (typeof chunk.thoughtsTokenCount === "number") {
thoughtsTokenCount = (thoughtsTokenCount ?? 0) + chunk.thoughtsTokenCount
}
thinkingBudget = chunk.thinkingBudget
break
case "text":
assistantMessage += chunk.text
6 changes: 6 additions & 0 deletions src/exports/roo-code.d.ts
@@ -432,6 +432,8 @@ type TokenUsage = {
totalCacheReads?: number | undefined
totalCost: number
contextTokens: number
thoughtsTokenCount?: number | undefined
thinkingBudget?: number | undefined
}

type RooCodeEvents = {
@@ -524,6 +526,8 @@ type RooCodeEvents = {
totalCacheReads?: number | undefined
totalCost: number
contextTokens: number
thoughtsTokenCount?: number | undefined
thinkingBudget?: number | undefined
},
{
[x: string]: {
@@ -541,6 +545,8 @@
totalCacheReads?: number | undefined
totalCost: number
contextTokens: number
thoughtsTokenCount?: number | undefined
thinkingBudget?: number | undefined
},
]
}