From 57f2a4edda10b538c9935194cb9ea9727a8ba80b Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:03:27 -0700 Subject: [PATCH 01/20] Support Gemini 2.5 Flash thinking --- package-lock.json | 35 +++++- package.json | 2 +- src/api/providers/__tests__/gemini.test.ts | 139 +++++++-------------- src/api/providers/anthropic.ts | 16 +-- src/api/providers/gemini.ts | 126 +++++++++++-------- src/api/transform/gemini-format.ts | 69 ++-------- src/shared/api.ts | 20 +++ 7 files changed, 186 insertions(+), 221 deletions(-) diff --git a/package-lock.json b/package-lock.json index 91f5affa09..56a5b9fd54 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13,7 +13,7 @@ "@anthropic-ai/vertex-sdk": "^0.7.0", "@aws-sdk/client-bedrock-runtime": "^3.779.0", "@google-cloud/vertexai": "^1.9.3", - "@google/generative-ai": "^0.18.0", + "@google/genai": "^0.9.0", "@mistralai/mistralai": "^1.3.6", "@modelcontextprotocol/sdk": "^1.7.0", "@types/clone-deep": "^4.0.4", @@ -5781,14 +5781,39 @@ "node": ">=18.0.0" } }, - "node_modules/@google/generative-ai": { - "version": "0.18.0", - "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.18.0.tgz", - "integrity": "sha512-AhaIWSpk2tuhYHrBhUqC0xrWWznmYEja1/TRDIb+5kruBU5kUzMlFsXCQNO9PzyTZ4clUJ3CX/Rvy+Xm9x+w3g==", + "node_modules/@google/genai": { + "version": "0.9.0", + "resolved": "https://registry.npmjs.org/@google/genai/-/genai-0.9.0.tgz", + "integrity": "sha512-FD2RizYGInsvfjeaN6O+wQGpRnGVglS1XWrGQr8K7D04AfMmvPodDSw94U9KyFtsVLzWH9kmlPyFM+G4jbmkqg==", + "license": "Apache-2.0", + "dependencies": { + "google-auth-library": "^9.14.2", + "ws": "^8.18.0", + "zod": "^3.22.4", + "zod-to-json-schema": "^3.22.4" + }, "engines": { "node": ">=18.0.0" } }, + "node_modules/@google/genai/node_modules/zod": { + "version": "3.24.3", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.3.tgz", + "integrity": "sha512-HhY1oqzWCQWuUqvBFnsyrtZRhyPeR7SUGv+C4+MsisMuVfSPx8HpwWqH8tRahSlt6M3PiFAcoeFhZAqIXTxoSg==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/@google/genai/node_modules/zod-to-json-schema": { + "version": "3.24.5", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.5.tgz", + "integrity": "sha512-/AuWwMP+YqiPbsJx5D6TfgRTc4kTLjsh5SOcd4bLsfUg2RcEXrFMJl1DGgdHy2aCfsIA/cr/1JM0xcB2GZji8g==", + "license": "ISC", + "peerDependencies": { + "zod": "^3.24.1" + } + }, "node_modules/@humanwhocodes/config-array": { "version": "0.13.0", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.13.0.tgz", diff --git a/package.json b/package.json index 4804150ffd..4db450c42a 100644 --- a/package.json +++ b/package.json @@ -405,7 +405,7 @@ "@anthropic-ai/vertex-sdk": "^0.7.0", "@aws-sdk/client-bedrock-runtime": "^3.779.0", "@google-cloud/vertexai": "^1.9.3", - "@google/generative-ai": "^0.18.0", + "@google/genai": "^0.9.0", "@mistralai/mistralai": "^1.3.6", "@modelcontextprotocol/sdk": "^1.7.0", "@types/clone-deep": "^4.0.4", diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts index d12c261b79..c3a14f7960 100644 --- a/src/api/providers/__tests__/gemini.test.ts +++ b/src/api/providers/__tests__/gemini.test.ts @@ -1,30 +1,32 @@ -import { GeminiHandler } from "../gemini" +// npx jest src/api/providers/__tests__/gemini.test.ts + import { Anthropic } from "@anthropic-ai/sdk" -import { GoogleGenerativeAI } from "@google/generative-ai" - -// Mock the Google Generative AI SDK 
-jest.mock("@google/generative-ai", () => ({ - GoogleGenerativeAI: jest.fn().mockImplementation(() => ({ - getGenerativeModel: jest.fn().mockReturnValue({ - generateContentStream: jest.fn(), - generateContent: jest.fn().mockResolvedValue({ - response: { - text: () => "Test response", - }, - }), - }), - })), -})) + +import { GeminiHandler } from "../gemini" describe("GeminiHandler", () => { let handler: GeminiHandler beforeEach(() => { + // Create mock functions + const mockGenerateContentStream = jest.fn() + const mockGenerateContent = jest.fn() + const mockGetGenerativeModel = jest.fn() + handler = new GeminiHandler({ apiKey: "test-key", apiModelId: "gemini-2.0-flash-thinking-exp-1219", geminiApiKey: "test-key", }) + + // Replace the client with our mock + handler.client = { + models: { + generateContentStream: mockGenerateContentStream, + generateContent: mockGenerateContent, + getGenerativeModel: mockGetGenerativeModel, + }, + } as any }) describe("constructor", () => { @@ -32,15 +34,6 @@ describe("GeminiHandler", () => { expect(handler["options"].geminiApiKey).toBe("test-key") expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219") }) - - it.skip("should throw if API key is missing", () => { - expect(() => { - new GeminiHandler({ - apiModelId: "gemini-2.0-flash-thinking-exp-1219", - geminiApiKey: "", - }) - }).toThrow("API key is required for Google Gemini") - }) }) describe("createMessage", () => { @@ -58,25 +51,15 @@ describe("GeminiHandler", () => { const systemPrompt = "You are a helpful assistant" it("should handle text messages correctly", async () => { - // Mock the stream response - const mockStream = { - stream: [{ text: () => "Hello" }, { text: () => " world!" }], - response: { - usageMetadata: { - promptTokenCount: 10, - candidatesTokenCount: 5, - }, + // Setup the mock implementation to return an async generator + ;(handler.client.models.generateContentStream as jest.Mock).mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { text: "Hello" } + yield { text: " world!" 
} + yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } }, - } - - // Setup the mock implementation - const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContentStream: mockGenerateContentStream, }) - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel - const stream = handler.createMessage(systemPrompt, mockMessages) const chunks = [] @@ -100,35 +83,21 @@ describe("GeminiHandler", () => { outputTokens: 5, }) - // Verify the model configuration - expect(mockGetGenerativeModel).toHaveBeenCalledWith( - { - model: "gemini-2.0-flash-thinking-exp-1219", - systemInstruction: systemPrompt, - }, - { - baseUrl: undefined, - }, - ) - - // Verify generation config - expect(mockGenerateContentStream).toHaveBeenCalledWith( + // Verify the call to generateContentStream + expect(handler.client.models.generateContentStream).toHaveBeenCalledWith( expect.objectContaining({ - generationConfig: { + model: "gemini-2.0-flash-thinking-exp-1219", + config: expect.objectContaining({ temperature: 0, - }, + systemInstruction: systemPrompt, + }), }), ) }) it("should handle API errors", async () => { const mockError = new Error("Gemini API error") - const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContentStream: mockGenerateContentStream, - }) - - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel + ;(handler.client.models.generateContentStream as jest.Mock).mockRejectedValue(mockError) const stream = handler.createMessage(systemPrompt, mockMessages) @@ -136,35 +105,26 @@ describe("GeminiHandler", () => { for await (const chunk of stream) { // Should throw before yielding any chunks } - }).rejects.toThrow("Gemini API error") + }).rejects.toThrow() }) }) describe("completePrompt", () => { it("should complete prompt successfully", async () => { - const mockGenerateContent = jest.fn().mockResolvedValue({ - response: { - text: () => "Test response", - }, + // Mock the response with text property + ;(handler.client.models.generateContent as jest.Mock).mockResolvedValue({ + text: "Test response", }) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent, - }) - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(mockGetGenerativeModel).toHaveBeenCalledWith( - { - model: "gemini-2.0-flash-thinking-exp-1219", - }, - { - baseUrl: undefined, - }, - ) - expect(mockGenerateContent).toHaveBeenCalledWith({ + + // Verify the call to generateContent + expect(handler.client.models.generateContent).toHaveBeenCalledWith({ + model: "gemini-2.0-flash-thinking-exp-1219", contents: [{ role: "user", parts: [{ text: "Test prompt" }] }], - generationConfig: { + config: { + httpOptions: undefined, temperature: 0, }, }) @@ -172,11 +132,7 @@ describe("GeminiHandler", () => { it("should handle API errors", async () => { const mockError = new Error("Gemini API error") - const mockGenerateContent = jest.fn().mockRejectedValue(mockError) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent, - }) - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel + ;(handler.client.models.generateContent as jest.Mock).mockRejectedValue(mockError) await 
expect(handler.completePrompt("Test prompt")).rejects.toThrow( "Gemini completion error: Gemini API error", @@ -184,15 +140,10 @@ describe("GeminiHandler", () => { }) it("should handle empty response", async () => { - const mockGenerateContent = jest.fn().mockResolvedValue({ - response: { - text: () => "", - }, - }) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent, + // Mock the response with empty text + ;(handler.client.models.generateContent as jest.Mock).mockResolvedValue({ + text: "", }) - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel const result = await handler.completePrompt("Test prompt") expect(result).toBe("") diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index a906ad6e7e..9032754ac6 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -23,6 +23,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa const apiKeyFieldName = this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken ? "authToken" : "apiKey" + this.client = new Anthropic({ baseURL: this.options.anthropicBaseUrl || undefined, [apiKeyFieldName]: this.options.apiKey, @@ -217,10 +218,10 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa } async completePrompt(prompt: string) { - let { id: modelId, temperature } = this.getModel() + let { id: model, temperature } = this.getModel() const message = await this.client.messages.create({ - model: modelId, + model, max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS, thinking: undefined, temperature, @@ -241,16 +242,11 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa override async countTokens(content: Array): Promise { try { // Use the current model - const actualModelId = this.getModel().id + const { id: model } = this.getModel() const response = await this.client.messages.countTokens({ - model: actualModelId, - messages: [ - { - role: "user", - content: content, - }, - ], + model, + messages: [{ role: "user", content: content }], }) return response.input_tokens diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 98117e99a9..a84e7df5af 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -1,88 +1,114 @@ -import { Anthropic } from "@anthropic-ai/sdk" -import { GoogleGenerativeAI } from "@google/generative-ai" +import type { Anthropic } from "@anthropic-ai/sdk" +import { + GoogleGenAI, + ThinkingConfig, + type GenerateContentResponseUsageMetadata, + type GenerateContentParameters, +} from "@google/genai" + import { SingleCompletionHandler } from "../" -import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api" +import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api" +import { geminiDefaultModelId, geminiModels } from "../../shared/api" import { convertAnthropicMessageToGemini } from "../transform/gemini-format" -import { ApiStream } from "../transform/stream" +import type { ApiStream } from "../transform/stream" import { BaseProvider } from "./base-provider" -const GEMINI_DEFAULT_TEMPERATURE = 0 - export class GeminiHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions - private client: GoogleGenerativeAI + public client: GoogleGenAI constructor(options: ApiHandlerOptions) { super() this.options = options - this.client = new GoogleGenerativeAI(options.geminiApiKey ?? 
"not-provided") + this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" }) } - override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const model = this.client.getGenerativeModel( - { - model: this.getModel().id, - systemInstruction: systemPrompt, - }, - { - baseUrl: this.options.googleGeminiBaseUrl || undefined, - }, - ) - const result = await model.generateContentStream({ + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + const { id: model, thinkingConfig, maxOutputTokens } = this.getModel() + + const params: GenerateContentParameters = { + model, contents: messages.map(convertAnthropicMessageToGemini), - generationConfig: { - // maxOutputTokens: this.getModel().info.maxTokens, - temperature: this.options.modelTemperature ?? GEMINI_DEFAULT_TEMPERATURE, + config: { + thinkingConfig, + maxOutputTokens, + temperature: this.options.modelTemperature ?? 0, + systemInstruction: systemPrompt, }, - }) + } - for await (const chunk of result.stream) { - yield { - type: "text", - text: chunk.text(), + const result = await this.client.models.generateContentStream(params) + + let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined + + for await (const chunk of result) { + if (chunk.text) { + yield { type: "text", text: chunk.text } + } + + if (chunk.usageMetadata) { + lastUsageMetadata = chunk.usageMetadata } } - const response = await result.response - yield { - type: "usage", - inputTokens: response.usageMetadata?.promptTokenCount ?? 0, - outputTokens: response.usageMetadata?.candidatesTokenCount ?? 0, + if (lastUsageMetadata) { + yield { + type: "usage", + inputTokens: lastUsageMetadata.promptTokenCount ?? 0, + outputTokens: lastUsageMetadata.candidatesTokenCount ?? 0, + } } } - override getModel(): { id: GeminiModelId; info: ModelInfo } { - const modelId = this.options.apiModelId - if (modelId && modelId in geminiModels) { - const id = modelId as GeminiModelId - return { id, info: geminiModels[id] } + override getModel(): { + id: GeminiModelId + info: ModelInfo + thinkingConfig?: ThinkingConfig + maxOutputTokens?: number + } { + let id = this.options.apiModelId ? (this.options.apiModelId as GeminiModelId) : geminiDefaultModelId + let info: ModelInfo = geminiModels[id] + let thinkingConfig: ThinkingConfig | undefined = undefined + let maxOutputTokens: number | undefined = undefined + + if (id?.endsWith(":thinking")) { + id = id.slice(0, -9) as GeminiModelId + info = geminiModels[id] + thinkingConfig = { includeThoughts: true, thinkingBudget: 8192 } + maxOutputTokens = info.maxTokens ?? undefined } - return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] } + + if (!info) { + id = geminiDefaultModelId + info = geminiModels[geminiDefaultModelId] + thinkingConfig = undefined + maxOutputTokens = undefined + } + + return { id, info, thinkingConfig } } async completePrompt(prompt: string): Promise { try { - const model = this.client.getGenerativeModel( - { - model: this.getModel().id, - }, - { - baseUrl: this.options.googleGeminiBaseUrl || undefined, - }, - ) + const { id: model } = this.getModel() - const result = await model.generateContent({ + const result = await this.client.models.generateContent({ + model, contents: [{ role: "user", parts: [{ text: prompt }] }], - generationConfig: { - temperature: this.options.modelTemperature ?? GEMINI_DEFAULT_TEMPERATURE, + config: { + httpOptions: this.options.googleGeminiBaseUrl + ? 
{ baseUrl: this.options.googleGeminiBaseUrl } + : undefined, + temperature: this.options.modelTemperature ?? 0, }, }) - return result.response.text() + return result.text ?? "" } catch (error) { if (error instanceof Error) { throw new Error(`Gemini completion error: ${error.message}`) } + throw error } } diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index c8fc80d769..10c9ae4d31 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -1,76 +1,23 @@ import { Anthropic } from "@anthropic-ai/sdk" -import { Content, FunctionCallPart, FunctionResponsePart, InlineDataPart, Part, TextPart } from "@google/generative-ai" +import { Content, Part } from "@google/genai" -function convertAnthropicContentToGemini(content: Anthropic.Messages.MessageParam["content"]): Part[] { +export function convertAnthropicContentToGemini(content: string | Anthropic.ContentBlockParam[]): Part[] { if (typeof content === "string") { - return [{ text: content } as TextPart] + return [{ text: content }] } - return content.flatMap((block) => { + return content.flatMap((block): Part => { switch (block.type) { case "text": - return { text: block.text } as TextPart + return { text: block.text } case "image": if (block.source.type !== "base64") { throw new Error("Unsupported image source type") } - return { - inlineData: { - data: block.source.data, - mimeType: block.source.media_type, - }, - } as InlineDataPart - case "tool_use": - return { - functionCall: { - name: block.name, - args: block.input, - }, - } as FunctionCallPart - case "tool_result": - const name = block.tool_use_id.split("-")[0] - if (!block.content) { - return [] - } - if (typeof block.content === "string") { - return { - functionResponse: { - name, - response: { - name, - content: block.content, - }, - }, - } as FunctionResponsePart - } else { - // The only case when tool_result could be array is when the tool failed and we're providing ie user feedback potentially with images - const textParts = block.content.filter((part) => part.type === "text") - const imageParts = block.content.filter((part) => part.type === "image") - const text = textParts.length > 0 ? textParts.map((part) => part.text).join("\n\n") : "" - const imageText = imageParts.length > 0 ? 
"\n\n(See next part for image)" : "" - return [ - { - functionResponse: { - name, - response: { - name, - content: text + imageText, - }, - }, - } as FunctionResponsePart, - ...imageParts.map( - (part) => - ({ - inlineData: { - data: part.source.data, - mimeType: part.source.media_type, - }, - }) as InlineDataPart, - ), - ] - } + + return { inlineData: { data: block.source.data, mimeType: block.source.media_type } } default: - throw new Error(`Unsupported content block type: ${(block as any).type}`) + throw new Error(`Unsupported content block type: ${block.type}`) } }) } diff --git a/src/shared/api.ts b/src/shared/api.ts index 4d71d947ba..cc9331bf2c 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -485,6 +485,15 @@ export const vertexModels = { inputPrice: 0.15, outputPrice: 0.6, }, + "gemini-2.5-flash-preview-04-17:thinking": { + maxTokens: 65_535, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0.15, + outputPrice: 0.6, + thinking: true, + }, "gemini-2.5-flash-preview-04-17": { maxTokens: 65_535, contextWindow: 1_048_576, @@ -492,6 +501,7 @@ export const vertexModels = { supportsPromptCache: false, inputPrice: 0.15, outputPrice: 0.6, + thinking: false, }, "gemini-2.5-pro-preview-03-25": { maxTokens: 65_535, @@ -640,6 +650,15 @@ export const openAiModelInfoSaneDefaults: ModelInfo = { export type GeminiModelId = keyof typeof geminiModels export const geminiDefaultModelId: GeminiModelId = "gemini-2.0-flash-001" export const geminiModels = { + "gemini-2.5-flash-preview-04-17:thinking": { + maxTokens: 65_535, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0.15, + outputPrice: 0.6, + thinking: true, + }, "gemini-2.5-flash-preview-04-17": { maxTokens: 65_535, contextWindow: 1_048_576, @@ -647,6 +666,7 @@ export const geminiModels = { supportsPromptCache: false, inputPrice: 0.15, outputPrice: 0.6, + thinking: false, }, "gemini-2.5-pro-exp-03-25": { maxTokens: 65_535, From 95752097ffd670edbe8838ae8d326c41523c9280 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:06:04 -0700 Subject: [PATCH 02/20] Fix tests --- src/api/transform/gemini-format.ts | 75 +++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index 10c9ae4d31..5d9fdde496 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -6,7 +6,7 @@ export function convertAnthropicContentToGemini(content: string | Anthropic.Cont return [{ text: content }] } - return content.flatMap((block): Part => { + return content.flatMap((block): Part | Part[] => { switch (block.type) { case "text": return { text: block.text } @@ -16,6 +16,79 @@ export function convertAnthropicContentToGemini(content: string | Anthropic.Cont } return { inlineData: { data: block.source.data, mimeType: block.source.media_type } } + case "tool_use": + return { + functionCall: { + name: block.name, + args: block.input as Record, + }, + } + case "tool_result": { + // Skip empty tool results + if (!block.content) { + return [] + } + + // Extract tool name from tool_use_id (e.g., "calculator-123" -> "calculator") + const toolName = block.tool_use_id.split("-")[0] + + // Handle string content + if (typeof block.content === "string") { + return { + functionResponse: { + name: toolName, + response: { + name: toolName, + content: block.content, + }, + }, + } + } + + // Handle array content + if (Array.isArray(block.content)) { + 
const textParts: string[] = [] + const imageParts: Part[] = [] + + block.content.forEach((item) => { + if (item.type === "text") { + textParts.push(item.text) + } else if (item.type === "image") { + if (item.source.type === "base64") { + imageParts.push({ + inlineData: { + data: item.source.data, + mimeType: item.source.media_type, + }, + }) + } + } + }) + + // Create content text with a note about images if present + const contentText = + textParts.join("\n\n") + + (imageParts.length > 0 + ? (textParts.length > 0 ? "\n\n" : "\n\n") + "(See next part for image)" + : "") + + // Return function response followed by any images + return [ + { + functionResponse: { + name: toolName, + response: { + name: toolName, + content: contentText, + }, + }, + }, + ...imageParts, + ] + } + + return [] + } default: throw new Error(`Unsupported content block type: ${block.type}`) } From f716cf8692bdf4479f0732b1f05aca32a5713ba6 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:07:56 -0700 Subject: [PATCH 03/20] Add a TODO --- src/shared/api.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shared/api.ts b/src/shared/api.ts index cc9331bf2c..ede1c83db8 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -492,7 +492,7 @@ export const vertexModels = { supportsPromptCache: false, inputPrice: 0.15, outputPrice: 0.6, - thinking: true, + thinking: true, // TODO: Max thinking budget is 24_576, so we need a new `ModelInfo` property for this. }, "gemini-2.5-flash-preview-04-17": { maxTokens: 65_535, @@ -657,7 +657,7 @@ export const geminiModels = { supportsPromptCache: false, inputPrice: 0.15, outputPrice: 0.6, - thinking: true, + thinking: true, // TODO: Max thinking budget is 24_576, so we need a new `ModelInfo` property for this. }, "gemini-2.5-flash-preview-04-17": { maxTokens: 65_535, From f51ec130f5e8a45477c637c69ab955bdcb56c027 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:13:02 -0700 Subject: [PATCH 04/20] Use the options --- src/api/providers/gemini.ts | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index a84e7df5af..0fff859486 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -74,8 +74,12 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl if (id?.endsWith(":thinking")) { id = id.slice(0, -9) as GeminiModelId info = geminiModels[id] - thinkingConfig = { includeThoughts: true, thinkingBudget: 8192 } - maxOutputTokens = info.maxTokens ?? undefined + + thinkingConfig = this.options.modelMaxThinkingTokens + ? { thinkingBudget: this.options.modelMaxThinkingTokens } + : undefined + + maxOutputTokens = this.options.modelMaxTokens ?? info.maxTokens ?? 
undefined } if (!info) { From 64d01ab8d32a73285c49a33581b8fff39e45271c Mon Sep 17 00:00:00 2001 From: Chris Estreich Date: Fri, 18 Apr 2025 01:13:27 -0700 Subject: [PATCH 05/20] Update src/api/providers/gemini.ts Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- src/api/providers/gemini.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 0fff859486..626d6ecd70 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -89,7 +89,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl maxOutputTokens = undefined } - return { id, info, thinkingConfig } + return { id, info, thinkingConfig, maxOutputTokens } } async completePrompt(prompt: string): Promise { From 40b3f04b264e5523a35d3f87ec5ba418830590b6 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:16:09 -0700 Subject: [PATCH 06/20] PR feedback --- src/api/transform/gemini-format.ts | 31 ++++-------------------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index 5d9fdde496..a8cdc2bf43 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -24,7 +24,6 @@ export function convertAnthropicContentToGemini(content: string | Anthropic.Cont }, } case "tool_result": { - // Skip empty tool results if (!block.content) { return [] } @@ -32,20 +31,12 @@ export function convertAnthropicContentToGemini(content: string | Anthropic.Cont // Extract tool name from tool_use_id (e.g., "calculator-123" -> "calculator") const toolName = block.tool_use_id.split("-")[0] - // Handle string content if (typeof block.content === "string") { return { - functionResponse: { - name: toolName, - response: { - name: toolName, - content: block.content, - }, - }, + functionResponse: { name: toolName, response: { name: toolName, content: block.content } }, } } - // Handle array content if (Array.isArray(block.content)) { const textParts: string[] = [] const imageParts: Part[] = [] @@ -56,10 +47,7 @@ export function convertAnthropicContentToGemini(content: string | Anthropic.Cont } else if (item.type === "image") { if (item.source.type === "base64") { imageParts.push({ - inlineData: { - data: item.source.data, - mimeType: item.source.media_type, - }, + inlineData: { data: item.source.data, mimeType: item.source.media_type }, }) } } @@ -67,22 +55,11 @@ export function convertAnthropicContentToGemini(content: string | Anthropic.Cont // Create content text with a note about images if present const contentText = - textParts.join("\n\n") + - (imageParts.length > 0 - ? (textParts.length > 0 ? "\n\n" : "\n\n") + "(See next part for image)" - : "") + textParts.join("\n\n") + (imageParts.length > 0 ? 
"\n\n(See next part for image)" : "") // Return function response followed by any images return [ - { - functionResponse: { - name: toolName, - response: { - name: toolName, - content: contentText, - }, - }, - }, + { functionResponse: { name: toolName, response: { name: toolName, content: contentText } } }, ...imageParts, ] } From 3315e583086a5367537a471c34fa2aed2e25fb8f Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:18:20 -0700 Subject: [PATCH 07/20] Don't export this --- src/api/transform/gemini-format.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index a8cdc2bf43..57a21da10f 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -1,7 +1,7 @@ import { Anthropic } from "@anthropic-ai/sdk" import { Content, Part } from "@google/genai" -export function convertAnthropicContentToGemini(content: string | Anthropic.ContentBlockParam[]): Part[] { +function convertAnthropicContentToGemini(content: string | Anthropic.ContentBlockParam[]): Part[] { if (typeof content === "string") { return [{ text: content }] } From 6ab88337a14caaec123c8a7c196a14106c214384 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:23:22 -0700 Subject: [PATCH 08/20] Small tweak --- src/api/transform/gemini-format.ts | 51 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index 57a21da10f..464c1337fa 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -24,7 +24,7 @@ function convertAnthropicContentToGemini(content: string | Anthropic.ContentBloc }, } case "tool_result": { - if (!block.content) { + if (!block.content || !Array.isArray(block.content)) { return [] } @@ -37,36 +37,37 @@ function convertAnthropicContentToGemini(content: string | Anthropic.ContentBloc } } - if (Array.isArray(block.content)) { - const textParts: string[] = [] - const imageParts: Part[] = [] + if (!Array.isArray(block.content)) { + return [] + } - block.content.forEach((item) => { - if (item.type === "text") { - textParts.push(item.text) - } else if (item.type === "image") { - if (item.source.type === "base64") { - imageParts.push({ - inlineData: { data: item.source.data, mimeType: item.source.media_type }, - }) - } - } - }) + const textParts: string[] = [] + const imageParts: Part[] = [] - // Create content text with a note about images if present - const contentText = - textParts.join("\n\n") + (imageParts.length > 0 ? "\n\n(See next part for image)" : "") + block.content.forEach((item) => { + if (item.type === "text") { + textParts.push(item.text) + } else if (item.type === "image") { + if (item.source.type === "base64") { + imageParts.push({ + inlineData: { data: item.source.data, mimeType: item.source.media_type }, + }) + } + } + }) - // Return function response followed by any images - return [ - { functionResponse: { name: toolName, response: { name: toolName, content: contentText } } }, - ...imageParts, - ] - } + // Create content text with a note about images if present + const contentText = + textParts.join("\n\n") + (imageParts.length > 0 ? 
"\n\n(See next part for image)" : "") - return [] + // Return function response followed by any images + return [ + { functionResponse: { name: toolName, response: { name: toolName, content: contentText } } }, + ...imageParts, + ] } default: + // Currently unsupported: "thinking" | "redacted_thinking" | "document" throw new Error(`Unsupported content block type: ${block.type}`) } }) From 45a61a83bef0fc91bac55ff4d3c6586c633971ee Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:52:19 -0700 Subject: [PATCH 09/20] Fix tests --- src/api/transform/gemini-format.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index 464c1337fa..43716527ca 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -24,7 +24,7 @@ function convertAnthropicContentToGemini(content: string | Anthropic.ContentBloc }, } case "tool_result": { - if (!block.content || !Array.isArray(block.content)) { + if (!block.content) { return [] } From 7a93a586df04e7d1600b6c46b32d81e0dc66e074 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 01:52:45 -0700 Subject: [PATCH 10/20] Add changeset --- .changeset/shiny-poems-search.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/shiny-poems-search.md diff --git a/.changeset/shiny-poems-search.md b/.changeset/shiny-poems-search.md new file mode 100644 index 0000000000..699843060e --- /dev/null +++ b/.changeset/shiny-poems-search.md @@ -0,0 +1,5 @@ +--- +"roo-cline": patch +--- + +Support Gemini 2.5 Flash thinking mode From ad8ad769fc64a04d0520eca2959638f4aece643f Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 09:08:42 -0700 Subject: [PATCH 11/20] PR feedback --- src/api/providers/__tests__/gemini.test.ts | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts index c3a14f7960..7647ec4d99 100644 --- a/src/api/providers/__tests__/gemini.test.ts +++ b/src/api/providers/__tests__/gemini.test.ts @@ -3,6 +3,9 @@ import { Anthropic } from "@anthropic-ai/sdk" import { GeminiHandler } from "../gemini" +import { geminiDefaultModelId } from "../../../shared/api" + +const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219" describe("GeminiHandler", () => { let handler: GeminiHandler @@ -15,7 +18,7 @@ describe("GeminiHandler", () => { handler = new GeminiHandler({ apiKey: "test-key", - apiModelId: "gemini-2.0-flash-thinking-exp-1219", + apiModelId: GEMINI_20_FLASH_THINKING_NAME, geminiApiKey: "test-key", }) @@ -32,7 +35,7 @@ describe("GeminiHandler", () => { describe("constructor", () => { it("should initialize with provided config", () => { expect(handler["options"].geminiApiKey).toBe("test-key") - expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219") + expect(handler["options"].apiModelId).toBe(GEMINI_20_FLASH_THINKING_NAME) }) }) @@ -86,7 +89,7 @@ describe("GeminiHandler", () => { // Verify the call to generateContentStream expect(handler.client.models.generateContentStream).toHaveBeenCalledWith( expect.objectContaining({ - model: "gemini-2.0-flash-thinking-exp-1219", + model: GEMINI_20_FLASH_THINKING_NAME, config: expect.objectContaining({ temperature: 0, systemInstruction: systemPrompt, @@ -121,7 +124,7 @@ describe("GeminiHandler", () => { // Verify the call to generateContent expect(handler.client.models.generateContent).toHaveBeenCalledWith({ - model: 
"gemini-2.0-flash-thinking-exp-1219", + model: GEMINI_20_FLASH_THINKING_NAME, contents: [{ role: "user", parts: [{ text: "Test prompt" }] }], config: { httpOptions: undefined, @@ -153,7 +156,7 @@ describe("GeminiHandler", () => { describe("getModel", () => { it("should return correct model info", () => { const modelInfo = handler.getModel() - expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") + expect(modelInfo.id).toBe(GEMINI_20_FLASH_THINKING_NAME) expect(modelInfo.info).toBeDefined() expect(modelInfo.info.maxTokens).toBe(8192) expect(modelInfo.info.contextWindow).toBe(32_767) @@ -165,7 +168,7 @@ describe("GeminiHandler", () => { geminiApiKey: "test-key", }) const modelInfo = invalidHandler.getModel() - expect(modelInfo.id).toBe("gemini-2.0-flash-001") // Default model + expect(modelInfo.id).toBe(geminiDefaultModelId) // Default model }) }) }) From 1e739d34f28c0835a5c7dd34123e7bfc0ad1f9fd Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 09:10:01 -0700 Subject: [PATCH 12/20] Make this private --- src/api/providers/gemini.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 626d6ecd70..dd7be38136 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -15,7 +15,7 @@ import { BaseProvider } from "./base-provider" export class GeminiHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions - public client: GoogleGenAI + private client: GoogleGenAI constructor(options: ApiHandlerOptions) { super() From d7e74c7dc3b2ca688cedc58118cdb65b8ef47085 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 09:11:53 -0700 Subject: [PATCH 13/20] Fix tsc errors --- src/api/providers/__tests__/gemini.test.ts | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts index 7647ec4d99..897ece3ed3 100644 --- a/src/api/providers/__tests__/gemini.test.ts +++ b/src/api/providers/__tests__/gemini.test.ts @@ -23,7 +23,7 @@ describe("GeminiHandler", () => { }) // Replace the client with our mock - handler.client = { + handler["client"] = { models: { generateContentStream: mockGenerateContentStream, generateContent: mockGenerateContent, @@ -55,7 +55,7 @@ describe("GeminiHandler", () => { it("should handle text messages correctly", async () => { // Setup the mock implementation to return an async generator - ;(handler.client.models.generateContentStream as jest.Mock).mockResolvedValue({ + ;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({ [Symbol.asyncIterator]: async function* () { yield { text: "Hello" } yield { text: " world!" 
} @@ -87,7 +87,7 @@ describe("GeminiHandler", () => { }) // Verify the call to generateContentStream - expect(handler.client.models.generateContentStream).toHaveBeenCalledWith( + expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith( expect.objectContaining({ model: GEMINI_20_FLASH_THINKING_NAME, config: expect.objectContaining({ @@ -100,7 +100,7 @@ describe("GeminiHandler", () => { it("should handle API errors", async () => { const mockError = new Error("Gemini API error") - ;(handler.client.models.generateContentStream as jest.Mock).mockRejectedValue(mockError) + ;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(mockError) const stream = handler.createMessage(systemPrompt, mockMessages) @@ -115,7 +115,7 @@ describe("GeminiHandler", () => { describe("completePrompt", () => { it("should complete prompt successfully", async () => { // Mock the response with text property - ;(handler.client.models.generateContent as jest.Mock).mockResolvedValue({ + ;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({ text: "Test response", }) @@ -123,7 +123,7 @@ describe("GeminiHandler", () => { expect(result).toBe("Test response") // Verify the call to generateContent - expect(handler.client.models.generateContent).toHaveBeenCalledWith({ + expect(handler["client"].models.generateContent).toHaveBeenCalledWith({ model: GEMINI_20_FLASH_THINKING_NAME, contents: [{ role: "user", parts: [{ text: "Test prompt" }] }], config: { @@ -135,7 +135,7 @@ describe("GeminiHandler", () => { it("should handle API errors", async () => { const mockError = new Error("Gemini API error") - ;(handler.client.models.generateContent as jest.Mock).mockRejectedValue(mockError) + ;(handler["client"].models.generateContent as jest.Mock).mockRejectedValue(mockError) await expect(handler.completePrompt("Test prompt")).rejects.toThrow( "Gemini completion error: Gemini API error", @@ -144,7 +144,7 @@ describe("GeminiHandler", () => { it("should handle empty response", async () => { // Mock the response with empty text - ;(handler.client.models.generateContent as jest.Mock).mockResolvedValue({ + ;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({ text: "", }) From c650fc676cecd9e4646824e0f1eac95c0e92780e Mon Sep 17 00:00:00 2001 From: Chris Estreich Date: Fri, 18 Apr 2025 09:20:28 -0700 Subject: [PATCH 14/20] Update src/api/transform/gemini-format.ts Co-authored-by: Felix Anhalt <40368420+felixAnhalt@users.noreply.github.com> --- src/api/transform/gemini-format.ts | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index 43716527ca..457d70eacf 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -45,15 +45,16 @@ function convertAnthropicContentToGemini(content: string | Anthropic.ContentBloc const imageParts: Part[] = [] block.content.forEach((item) => { - if (item.type === "text") { - textParts.push(item.text) - } else if (item.type === "image") { - if (item.source.type === "base64") { - imageParts.push({ - inlineData: { data: item.source.data, mimeType: item.source.media_type }, - }) - } - } + if (item.type === "text") { + textParts.push(item.text); + continue; + } + if (item.type === "image" && item.source.type === "base64") { + const { data, media_type } = item.source; + imageParts.push({ + inlineData: { data, mimeType: media_type }, + }); + } }) // Create content text with a note about images if 
present From 88dc098f6c76caf4e7d0c2365d8f7e50030cb4ff Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 09:21:14 -0700 Subject: [PATCH 15/20] DRY up suffix reference --- src/api/providers/gemini.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index dd7be38136..f6933e00d8 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -71,8 +71,10 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl let thinkingConfig: ThinkingConfig | undefined = undefined let maxOutputTokens: number | undefined = undefined - if (id?.endsWith(":thinking")) { - id = id.slice(0, -9) as GeminiModelId + const thinkingSuffix = ":thinking" + + if (id?.endsWith(thinkingSuffix)) { + id = id.slice(0, -thinkingSuffix.length) as GeminiModelId info = geminiModels[id] thinkingConfig = this.options.modelMaxThinkingTokens From 679e2c7008e4595abd377ffa13225ecade4b849a Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 11:19:04 -0700 Subject: [PATCH 16/20] Add Gemini token counting --- src/api/providers/gemini.ts | 23 ++++++++++++++++++++++- src/api/transform/gemini-format.ts | 22 +++++++++------------- src/exports/roo-code.d.ts | 5 +++++ src/exports/types.ts | 5 +++++ src/schemas/index.ts | 1 + 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index f6933e00d8..9942502a85 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -9,7 +9,7 @@ import { import { SingleCompletionHandler } from "../" import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api" import { geminiDefaultModelId, geminiModels } from "../../shared/api" -import { convertAnthropicMessageToGemini } from "../transform/gemini-format" +import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format" import type { ApiStream } from "../transform/stream" import { BaseProvider } from "./base-provider" @@ -118,4 +118,25 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl throw error } } + + override async countTokens(content: Array): Promise { + try { + const { id: model } = this.getModel() + + const response = await this.client.models.countTokens({ + model, + contents: convertAnthropicContentToGemini(content), + }) + + if (!response.totalTokens) { + console.warn("Gemini token counting returned undefined, using fallback") + return super.countTokens(content) + } + + return response.totalTokens + } catch (error) { + console.warn("Gemini token counting failed, using fallback", error) + return super.countTokens(content) + } + } } diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index 457d70eacf..ee22cff32a 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -1,7 +1,7 @@ import { Anthropic } from "@anthropic-ai/sdk" import { Content, Part } from "@google/genai" -function convertAnthropicContentToGemini(content: string | Anthropic.ContentBlockParam[]): Part[] { +export function convertAnthropicContentToGemini(content: string | Anthropic.ContentBlockParam[]): Part[] { if (typeof content === "string") { return [{ text: content }] } @@ -44,18 +44,14 @@ function convertAnthropicContentToGemini(content: string | Anthropic.ContentBloc const textParts: string[] = [] const imageParts: Part[] = [] - block.content.forEach((item) => { - if (item.type === "text") { - 
textParts.push(item.text); - continue; - } - if (item.type === "image" && item.source.type === "base64") { - const { data, media_type } = item.source; - imageParts.push({ - inlineData: { data, mimeType: media_type }, - }); - } - }) + for (const item of block.content) { + if (item.type === "text") { + textParts.push(item.text) + } else if (item.type === "image" && item.source.type === "base64") { + const { data, media_type } = item.source + imageParts.push({ inlineData: { data, mimeType: media_type } }) + } + } // Create content text with a note about images if present const contentText = diff --git a/src/exports/roo-code.d.ts b/src/exports/roo-code.d.ts index 422297ddd3..d2a28a8273 100644 --- a/src/exports/roo-code.d.ts +++ b/src/exports/roo-code.d.ts @@ -31,6 +31,7 @@ type ProviderSettings = { glamaModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -53,6 +54,7 @@ type ProviderSettings = { openRouterModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -95,6 +97,7 @@ type ProviderSettings = { openAiCustomModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -140,6 +143,7 @@ type ProviderSettings = { unboundModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -161,6 +165,7 @@ type ProviderSettings = { requestyModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined diff --git a/src/exports/types.ts b/src/exports/types.ts index df38b929ed..4280511a8d 100644 --- a/src/exports/types.ts +++ b/src/exports/types.ts @@ -32,6 +32,7 @@ type ProviderSettings = { glamaModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -54,6 +55,7 @@ type ProviderSettings = { openRouterModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -96,6 +98,7 @@ type ProviderSettings = { openAiCustomModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -141,6 +144,7 @@ type ProviderSettings = { unboundModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -162,6 +166,7 @@ type ProviderSettings = { requestyModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined diff --git a/src/schemas/index.ts 
b/src/schemas/index.ts index db9feb54b4..24b224c08b 100644 --- a/src/schemas/index.ts +++ b/src/schemas/index.ts @@ -99,6 +99,7 @@ export type ReasoningEffort = z.infer export const modelInfoSchema = z.object({ maxTokens: z.number().nullish(), + maxThinkingTokens: z.number().nullish(), contextWindow: z.number(), supportsImages: z.boolean().optional(), supportsComputerUse: z.boolean().optional(), From f3b449e77bbeb65bba43598da30c0c017d4689e4 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 11:46:36 -0700 Subject: [PATCH 17/20] Enforce Gemini's max thinking tokens limit --- src/shared/api.ts | 6 ++- .../components/settings/ThinkingBudget.tsx | 50 +++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/shared/api.ts b/src/shared/api.ts index ede1c83db8..ebc0b85c93 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -492,7 +492,8 @@ export const vertexModels = { supportsPromptCache: false, inputPrice: 0.15, outputPrice: 0.6, - thinking: true, // TODO: Max thinking budget is 24_576, so we need a new `ModelInfo` property for this. + thinking: true, + maxThinkingTokens: 24_576, }, "gemini-2.5-flash-preview-04-17": { maxTokens: 65_535, @@ -657,7 +658,8 @@ export const geminiModels = { supportsPromptCache: false, inputPrice: 0.15, outputPrice: 0.6, - thinking: true, // TODO: Max thinking budget is 24_576, so we need a new `ModelInfo` property for this. + thinking: true, + maxThinkingTokens: 24_576, }, "gemini-2.5-flash-preview-04-17": { maxTokens: 65_535, diff --git a/webview-ui/src/components/settings/ThinkingBudget.tsx b/webview-ui/src/components/settings/ThinkingBudget.tsx index e4cb4f0b9c..54c3ef108a 100644 --- a/webview-ui/src/components/settings/ThinkingBudget.tsx +++ b/webview-ui/src/components/settings/ThinkingBudget.tsx @@ -1,10 +1,12 @@ -import { useEffect, useMemo } from "react" import { useAppTranslation } from "@/i18n/TranslationContext" import { Slider } from "@/components/ui" import { ApiConfiguration, ModelInfo } from "../../../../src/shared/api" +const DEFAULT_MAX_OUTPUT_TOKENS = 16_384 +const DEFAULT_MAX_THINKING_TOKENS = 8_192 + interface ThinkingBudgetProps { apiConfiguration: ApiConfiguration setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void @@ -13,27 +15,23 @@ interface ThinkingBudgetProps { export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, modelInfo }: ThinkingBudgetProps) => { const { t } = useAppTranslation() - const tokens = apiConfiguration?.modelMaxTokens || 16_384 - const tokensMin = 8192 - const tokensMax = modelInfo?.maxTokens || 64_000 - // Get the appropriate thinking tokens based on provider - const thinkingTokens = useMemo(() => { - const value = apiConfiguration?.modelMaxThinkingTokens - return value || Math.min(Math.floor(0.8 * tokens), 8192) - }, [apiConfiguration, tokens]) + if (!modelInfo || !modelInfo.thinking || !modelInfo.maxTokens) { + return null + } - const thinkingTokensMin = 1024 - const thinkingTokensMax = Math.floor(0.8 * tokens) + const customMaxOutputTokens = apiConfiguration.modelMaxTokens || DEFAULT_MAX_OUTPUT_TOKENS - useEffect(() => { - if (thinkingTokens > thinkingTokensMax) { - setApiConfigurationField("modelMaxThinkingTokens", thinkingTokensMax) - } - }, [thinkingTokens, thinkingTokensMax, setApiConfigurationField]) + // Dynamically expand or shrink the max thinking budget based on the custom + // max output tokens so that there's always a 20% buffer. + const modelMaxThinkingTokens = modelInfo.maxThinkingTokens + ? 
Math.min(modelInfo.maxThinkingTokens, Math.floor(0.8 * customMaxOutputTokens)) + : Math.floor(0.8 * customMaxOutputTokens) - if (!modelInfo?.thinking) { - return null + let customMaxThinkingTokens = apiConfiguration.modelMaxThinkingTokens || DEFAULT_MAX_THINKING_TOKENS + + if (customMaxThinkingTokens > modelMaxThinkingTokens) { + customMaxThinkingTokens = modelMaxThinkingTokens } return ( @@ -42,26 +40,26 @@ export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, mod
{t("settings:thinkingBudget.maxTokens")}
setApiConfigurationField("modelMaxTokens", value)} /> -
{tokens}
+
{customMaxOutputTokens}
{t("settings:thinkingBudget.maxThinkingTokens")}
setApiConfigurationField("modelMaxThinkingTokens", value)} /> -
{thinkingTokens}
+
{customMaxThinkingTokens}
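
Note: an aside on the clamping rule this patch introduces. The effective thinking budget is bounded by both the model's advertised maxThinkingTokens (24_576 for the Gemini 2.5 Flash preview, per the TODO resolved above) and 80% of the configured max output tokens, and any previously saved budget is shrunk to fit. A minimal standalone restatement of that arithmetic, outside the React component; the function name clampThinkingBudget is illustrative and does not exist in the codebase.

	// Restates the budget math from ThinkingBudget.tsx: the slider's upper
	// bound is min(model cap, 80% of the configured max output tokens), and
	// a previously saved budget is clamped down to that bound.
	function clampThinkingBudget(
		customMaxOutputTokens: number,
		savedThinkingTokens: number,
		modelMaxThinkingTokens?: number, // e.g. 24_576 for gemini-2.5-flash-preview-04-17
	): number {
		const buffered = Math.floor(0.8 * customMaxOutputTokens)

		const upperBound = modelMaxThinkingTokens
			? Math.min(modelMaxThinkingTokens, buffered)
			: buffered

		return Math.min(savedThinkingTokens, upperBound)
	}

	// With a 16_384 output budget the bound is min(24_576, 13_107) = 13_107,
	// so a saved budget of 16_000 is clamped to 13_107.
	const clamped = clampThinkingBudget(16_384, 16_000, 24_576)
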
From e942c63739f75386683c6fa85ac03dbc677f0a30 Mon Sep 17 00:00:00 2001 From: Chris Estreich Date: Fri, 18 Apr 2025 11:47:09 -0700 Subject: [PATCH 18/20] Update src/api/providers/gemini.ts Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- src/api/providers/gemini.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 9942502a85..7389611300 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -128,7 +128,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl contents: convertAnthropicContentToGemini(content), }) - if (!response.totalTokens) { + if (response.totalTokens === undefined) { console.warn("Gemini token counting returned undefined, using fallback") return super.countTokens(content) } From 4fd61bbc638a690801be5e95d55e49d7956b7460 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 18 Apr 2025 11:58:49 -0700 Subject: [PATCH 19/20] PR feedback --- .../components/settings/ThinkingBudget.tsx | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/webview-ui/src/components/settings/ThinkingBudget.tsx b/webview-ui/src/components/settings/ThinkingBudget.tsx index 54c3ef108a..34c16daefe 100644 --- a/webview-ui/src/components/settings/ThinkingBudget.tsx +++ b/webview-ui/src/components/settings/ThinkingBudget.tsx @@ -1,3 +1,4 @@ +import { useEffect } from "react" import { useAppTranslation } from "@/i18n/TranslationContext" import { Slider } from "@/components/ui" @@ -16,25 +17,28 @@ interface ThinkingBudgetProps { export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, modelInfo }: ThinkingBudgetProps) => { const { t } = useAppTranslation() - if (!modelInfo || !modelInfo.thinking || !modelInfo.maxTokens) { - return null - } + const isThinkingModel = modelInfo && modelInfo.thinking && modelInfo.maxTokens const customMaxOutputTokens = apiConfiguration.modelMaxTokens || DEFAULT_MAX_OUTPUT_TOKENS + const customMaxThinkingTokens = apiConfiguration.modelMaxThinkingTokens || DEFAULT_MAX_THINKING_TOKENS // Dynamically expand or shrink the max thinking budget based on the custom // max output tokens so that there's always a 20% buffer. - const modelMaxThinkingTokens = modelInfo.maxThinkingTokens + const modelMaxThinkingTokens = modelInfo?.maxThinkingTokens ? Math.min(modelInfo.maxThinkingTokens, Math.floor(0.8 * customMaxOutputTokens)) : Math.floor(0.8 * customMaxOutputTokens) - let customMaxThinkingTokens = apiConfiguration.modelMaxThinkingTokens || DEFAULT_MAX_THINKING_TOKENS + // If the custom max thinking tokens are going to exceed it's limit due + // to the custom max output tokens being reduced then we need to shrink it + // appropriately. + useEffect(() => { + if (isThinkingModel && customMaxThinkingTokens > modelMaxThinkingTokens) { + console.log(`setApiConfigurationField("modelMaxThinkingTokens", ${modelMaxThinkingTokens})`) + setApiConfigurationField("modelMaxThinkingTokens", modelMaxThinkingTokens) + } + }, [isThinkingModel, customMaxThinkingTokens, modelMaxThinkingTokens, setApiConfigurationField]) - if (customMaxThinkingTokens > modelMaxThinkingTokens) { - customMaxThinkingTokens = modelMaxThinkingTokens - } - - return ( + return isThinkingModel ? ( <>
{t("settings:thinkingBudget.maxTokens")}
@@ -63,5 +67,5 @@ export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, mod
- ) + ) : null } From 31f225c9e42162b962227ea2533131a2a012c039 Mon Sep 17 00:00:00 2001 From: Chris Estreich Date: Fri, 18 Apr 2025 12:02:19 -0700 Subject: [PATCH 20/20] Update webview-ui/src/components/settings/ThinkingBudget.tsx Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- webview-ui/src/components/settings/ThinkingBudget.tsx | 1 - 1 file changed, 1 deletion(-) diff --git a/webview-ui/src/components/settings/ThinkingBudget.tsx b/webview-ui/src/components/settings/ThinkingBudget.tsx index 34c16daefe..5123d571b3 100644 --- a/webview-ui/src/components/settings/ThinkingBudget.tsx +++ b/webview-ui/src/components/settings/ThinkingBudget.tsx @@ -33,7 +33,6 @@ export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, mod // appropriately. useEffect(() => { if (isThinkingModel && customMaxThinkingTokens > modelMaxThinkingTokens) { - console.log(`setApiConfigurationField("modelMaxThinkingTokens", ${modelMaxThinkingTokens})`) setApiConfigurationField("modelMaxThinkingTokens", modelMaxThinkingTokens) } }, [isThinkingModel, customMaxThinkingTokens, modelMaxThinkingTokens, setApiConfigurationField])