diff --git a/.changeset/shiny-poems-search.md b/.changeset/shiny-poems-search.md new file mode 100644 index 0000000000..699843060e --- /dev/null +++ b/.changeset/shiny-poems-search.md @@ -0,0 +1,5 @@ +--- +"roo-cline": patch +--- + +Support Gemini 2.5 Flash thinking mode diff --git a/package-lock.json b/package-lock.json index 91f5affa09..56a5b9fd54 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13,7 +13,7 @@ "@anthropic-ai/vertex-sdk": "^0.7.0", "@aws-sdk/client-bedrock-runtime": "^3.779.0", "@google-cloud/vertexai": "^1.9.3", - "@google/generative-ai": "^0.18.0", + "@google/genai": "^0.9.0", "@mistralai/mistralai": "^1.3.6", "@modelcontextprotocol/sdk": "^1.7.0", "@types/clone-deep": "^4.0.4", @@ -5781,14 +5781,39 @@ "node": ">=18.0.0" } }, - "node_modules/@google/generative-ai": { - "version": "0.18.0", - "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.18.0.tgz", - "integrity": "sha512-AhaIWSpk2tuhYHrBhUqC0xrWWznmYEja1/TRDIb+5kruBU5kUzMlFsXCQNO9PzyTZ4clUJ3CX/Rvy+Xm9x+w3g==", + "node_modules/@google/genai": { + "version": "0.9.0", + "resolved": "https://registry.npmjs.org/@google/genai/-/genai-0.9.0.tgz", + "integrity": "sha512-FD2RizYGInsvfjeaN6O+wQGpRnGVglS1XWrGQr8K7D04AfMmvPodDSw94U9KyFtsVLzWH9kmlPyFM+G4jbmkqg==", + "license": "Apache-2.0", + "dependencies": { + "google-auth-library": "^9.14.2", + "ws": "^8.18.0", + "zod": "^3.22.4", + "zod-to-json-schema": "^3.22.4" + }, "engines": { "node": ">=18.0.0" } }, + "node_modules/@google/genai/node_modules/zod": { + "version": "3.24.3", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.3.tgz", + "integrity": "sha512-HhY1oqzWCQWuUqvBFnsyrtZRhyPeR7SUGv+C4+MsisMuVfSPx8HpwWqH8tRahSlt6M3PiFAcoeFhZAqIXTxoSg==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/@google/genai/node_modules/zod-to-json-schema": { + "version": "3.24.5", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.5.tgz", + "integrity": "sha512-/AuWwMP+YqiPbsJx5D6TfgRTc4kTLjsh5SOcd4bLsfUg2RcEXrFMJl1DGgdHy2aCfsIA/cr/1JM0xcB2GZji8g==", + "license": "ISC", + "peerDependencies": { + "zod": "^3.24.1" + } + }, "node_modules/@humanwhocodes/config-array": { "version": "0.13.0", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.13.0.tgz", diff --git a/package.json b/package.json index 4804150ffd..4db450c42a 100644 --- a/package.json +++ b/package.json @@ -405,7 +405,7 @@ "@anthropic-ai/vertex-sdk": "^0.7.0", "@aws-sdk/client-bedrock-runtime": "^3.779.0", "@google-cloud/vertexai": "^1.9.3", - "@google/generative-ai": "^0.18.0", + "@google/genai": "^0.9.0", "@mistralai/mistralai": "^1.3.6", "@modelcontextprotocol/sdk": "^1.7.0", "@types/clone-deep": "^4.0.4", diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts index d12c261b79..897ece3ed3 100644 --- a/src/api/providers/__tests__/gemini.test.ts +++ b/src/api/providers/__tests__/gemini.test.ts @@ -1,45 +1,41 @@ -import { GeminiHandler } from "../gemini" +// npx jest src/api/providers/__tests__/gemini.test.ts + import { Anthropic } from "@anthropic-ai/sdk" -import { GoogleGenerativeAI } from "@google/generative-ai" - -// Mock the Google Generative AI SDK -jest.mock("@google/generative-ai", () => ({ - GoogleGenerativeAI: jest.fn().mockImplementation(() => ({ - getGenerativeModel: jest.fn().mockReturnValue({ - generateContentStream: jest.fn(), - generateContent: jest.fn().mockResolvedValue({ - response: { - text: () 
=> "Test response", - }, - }), - }), - })), -})) + +import { GeminiHandler } from "../gemini" +import { geminiDefaultModelId } from "../../../shared/api" + +const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219" describe("GeminiHandler", () => { let handler: GeminiHandler beforeEach(() => { + // Create mock functions + const mockGenerateContentStream = jest.fn() + const mockGenerateContent = jest.fn() + const mockGetGenerativeModel = jest.fn() + handler = new GeminiHandler({ apiKey: "test-key", - apiModelId: "gemini-2.0-flash-thinking-exp-1219", + apiModelId: GEMINI_20_FLASH_THINKING_NAME, geminiApiKey: "test-key", }) + + // Replace the client with our mock + handler["client"] = { + models: { + generateContentStream: mockGenerateContentStream, + generateContent: mockGenerateContent, + getGenerativeModel: mockGetGenerativeModel, + }, + } as any }) describe("constructor", () => { it("should initialize with provided config", () => { expect(handler["options"].geminiApiKey).toBe("test-key") - expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219") - }) - - it.skip("should throw if API key is missing", () => { - expect(() => { - new GeminiHandler({ - apiModelId: "gemini-2.0-flash-thinking-exp-1219", - geminiApiKey: "", - }) - }).toThrow("API key is required for Google Gemini") + expect(handler["options"].apiModelId).toBe(GEMINI_20_FLASH_THINKING_NAME) }) }) @@ -58,25 +54,15 @@ describe("GeminiHandler", () => { const systemPrompt = "You are a helpful assistant" it("should handle text messages correctly", async () => { - // Mock the stream response - const mockStream = { - stream: [{ text: () => "Hello" }, { text: () => " world!" }], - response: { - usageMetadata: { - promptTokenCount: 10, - candidatesTokenCount: 5, - }, + // Setup the mock implementation to return an async generator + ;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { text: "Hello" } + yield { text: " world!" 
} + yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } }, - } - - // Setup the mock implementation - const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContentStream: mockGenerateContentStream, }) - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel - const stream = handler.createMessage(systemPrompt, mockMessages) const chunks = [] @@ -100,35 +86,21 @@ describe("GeminiHandler", () => { outputTokens: 5, }) - // Verify the model configuration - expect(mockGetGenerativeModel).toHaveBeenCalledWith( - { - model: "gemini-2.0-flash-thinking-exp-1219", - systemInstruction: systemPrompt, - }, - { - baseUrl: undefined, - }, - ) - - // Verify generation config - expect(mockGenerateContentStream).toHaveBeenCalledWith( + // Verify the call to generateContentStream + expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith( expect.objectContaining({ - generationConfig: { + model: GEMINI_20_FLASH_THINKING_NAME, + config: expect.objectContaining({ temperature: 0, - }, + systemInstruction: systemPrompt, + }), }), ) }) it("should handle API errors", async () => { const mockError = new Error("Gemini API error") - const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContentStream: mockGenerateContentStream, - }) - - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel + ;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(mockError) const stream = handler.createMessage(systemPrompt, mockMessages) @@ -136,35 +108,26 @@ describe("GeminiHandler", () => { for await (const chunk of stream) { // Should throw before yielding any chunks } - }).rejects.toThrow("Gemini API error") + }).rejects.toThrow() }) }) describe("completePrompt", () => { it("should complete prompt successfully", async () => { - const mockGenerateContent = jest.fn().mockResolvedValue({ - response: { - text: () => "Test response", - }, - }) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent, + // Mock the response with text property + ;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({ + text: "Test response", }) - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(mockGetGenerativeModel).toHaveBeenCalledWith( - { - model: "gemini-2.0-flash-thinking-exp-1219", - }, - { - baseUrl: undefined, - }, - ) - expect(mockGenerateContent).toHaveBeenCalledWith({ + + // Verify the call to generateContent + expect(handler["client"].models.generateContent).toHaveBeenCalledWith({ + model: GEMINI_20_FLASH_THINKING_NAME, contents: [{ role: "user", parts: [{ text: "Test prompt" }] }], - generationConfig: { + config: { + httpOptions: undefined, temperature: 0, }, }) @@ -172,11 +135,7 @@ describe("GeminiHandler", () => { it("should handle API errors", async () => { const mockError = new Error("Gemini API error") - const mockGenerateContent = jest.fn().mockRejectedValue(mockError) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent, - }) - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel + ;(handler["client"].models.generateContent as jest.Mock).mockRejectedValue(mockError) await 
expect(handler.completePrompt("Test prompt")).rejects.toThrow( "Gemini completion error: Gemini API error", @@ -184,15 +143,10 @@ describe("GeminiHandler", () => { }) it("should handle empty response", async () => { - const mockGenerateContent = jest.fn().mockResolvedValue({ - response: { - text: () => "", - }, - }) - const mockGetGenerativeModel = jest.fn().mockReturnValue({ - generateContent: mockGenerateContent, + // Mock the response with empty text + ;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({ + text: "", }) - ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel const result = await handler.completePrompt("Test prompt") expect(result).toBe("") @@ -202,7 +156,7 @@ describe("GeminiHandler", () => { describe("getModel", () => { it("should return correct model info", () => { const modelInfo = handler.getModel() - expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219") + expect(modelInfo.id).toBe(GEMINI_20_FLASH_THINKING_NAME) expect(modelInfo.info).toBeDefined() expect(modelInfo.info.maxTokens).toBe(8192) expect(modelInfo.info.contextWindow).toBe(32_767) @@ -214,7 +168,7 @@ describe("GeminiHandler", () => { geminiApiKey: "test-key", }) const modelInfo = invalidHandler.getModel() - expect(modelInfo.id).toBe("gemini-2.0-flash-001") // Default model + expect(modelInfo.id).toBe(geminiDefaultModelId) // Default model }) }) }) diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index a906ad6e7e..9032754ac6 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -23,6 +23,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa const apiKeyFieldName = this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken ? 
"authToken" : "apiKey" + this.client = new Anthropic({ baseURL: this.options.anthropicBaseUrl || undefined, [apiKeyFieldName]: this.options.apiKey, @@ -217,10 +218,10 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa } async completePrompt(prompt: string) { - let { id: modelId, temperature } = this.getModel() + let { id: model, temperature } = this.getModel() const message = await this.client.messages.create({ - model: modelId, + model, max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS, thinking: undefined, temperature, @@ -241,16 +242,11 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa override async countTokens(content: Array): Promise { try { // Use the current model - const actualModelId = this.getModel().id + const { id: model } = this.getModel() const response = await this.client.messages.countTokens({ - model: actualModelId, - messages: [ - { - role: "user", - content: content, - }, - ], + model, + messages: [{ role: "user", content: content }], }) return response.input_tokens diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 98117e99a9..7389611300 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -1,89 +1,142 @@ -import { Anthropic } from "@anthropic-ai/sdk" -import { GoogleGenerativeAI } from "@google/generative-ai" +import type { Anthropic } from "@anthropic-ai/sdk" +import { + GoogleGenAI, + ThinkingConfig, + type GenerateContentResponseUsageMetadata, + type GenerateContentParameters, +} from "@google/genai" + import { SingleCompletionHandler } from "../" -import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api" -import { convertAnthropicMessageToGemini } from "../transform/gemini-format" -import { ApiStream } from "../transform/stream" +import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api" +import { geminiDefaultModelId, geminiModels } from "../../shared/api" +import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format" +import type { ApiStream } from "../transform/stream" import { BaseProvider } from "./base-provider" -const GEMINI_DEFAULT_TEMPERATURE = 0 - export class GeminiHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions - private client: GoogleGenerativeAI + private client: GoogleGenAI constructor(options: ApiHandlerOptions) { super() this.options = options - this.client = new GoogleGenerativeAI(options.geminiApiKey ?? "not-provided") + this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" }) } - override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const model = this.client.getGenerativeModel( - { - model: this.getModel().id, - systemInstruction: systemPrompt, - }, - { - baseUrl: this.options.googleGeminiBaseUrl || undefined, - }, - ) - const result = await model.generateContentStream({ + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + const { id: model, thinkingConfig, maxOutputTokens } = this.getModel() + + const params: GenerateContentParameters = { + model, contents: messages.map(convertAnthropicMessageToGemini), - generationConfig: { - // maxOutputTokens: this.getModel().info.maxTokens, - temperature: this.options.modelTemperature ?? 
GEMINI_DEFAULT_TEMPERATURE, + config: { + thinkingConfig, + maxOutputTokens, + temperature: this.options.modelTemperature ?? 0, + systemInstruction: systemPrompt, }, - }) + } - for await (const chunk of result.stream) { - yield { - type: "text", - text: chunk.text(), + const result = await this.client.models.generateContentStream(params) + + let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined + + for await (const chunk of result) { + if (chunk.text) { + yield { type: "text", text: chunk.text } + } + + if (chunk.usageMetadata) { + lastUsageMetadata = chunk.usageMetadata } } - const response = await result.response - yield { - type: "usage", - inputTokens: response.usageMetadata?.promptTokenCount ?? 0, - outputTokens: response.usageMetadata?.candidatesTokenCount ?? 0, + if (lastUsageMetadata) { + yield { + type: "usage", + inputTokens: lastUsageMetadata.promptTokenCount ?? 0, + outputTokens: lastUsageMetadata.candidatesTokenCount ?? 0, + } } } - override getModel(): { id: GeminiModelId; info: ModelInfo } { - const modelId = this.options.apiModelId - if (modelId && modelId in geminiModels) { - const id = modelId as GeminiModelId - return { id, info: geminiModels[id] } + override getModel(): { + id: GeminiModelId + info: ModelInfo + thinkingConfig?: ThinkingConfig + maxOutputTokens?: number + } { + let id = this.options.apiModelId ? (this.options.apiModelId as GeminiModelId) : geminiDefaultModelId + let info: ModelInfo = geminiModels[id] + let thinkingConfig: ThinkingConfig | undefined = undefined + let maxOutputTokens: number | undefined = undefined + + const thinkingSuffix = ":thinking" + + if (id?.endsWith(thinkingSuffix)) { + id = id.slice(0, -thinkingSuffix.length) as GeminiModelId + info = geminiModels[id] + + thinkingConfig = this.options.modelMaxThinkingTokens + ? { thinkingBudget: this.options.modelMaxThinkingTokens } + : undefined + + maxOutputTokens = this.options.modelMaxTokens ?? info.maxTokens ?? undefined + } + + if (!info) { + id = geminiDefaultModelId + info = geminiModels[geminiDefaultModelId] + thinkingConfig = undefined + maxOutputTokens = undefined } - return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] } + + return { id, info, thinkingConfig, maxOutputTokens } } async completePrompt(prompt: string): Promise { try { - const model = this.client.getGenerativeModel( - { - model: this.getModel().id, - }, - { - baseUrl: this.options.googleGeminiBaseUrl || undefined, - }, - ) + const { id: model } = this.getModel() - const result = await model.generateContent({ + const result = await this.client.models.generateContent({ + model, contents: [{ role: "user", parts: [{ text: prompt }] }], - generationConfig: { - temperature: this.options.modelTemperature ?? GEMINI_DEFAULT_TEMPERATURE, + config: { + httpOptions: this.options.googleGeminiBaseUrl + ? { baseUrl: this.options.googleGeminiBaseUrl } + : undefined, + temperature: this.options.modelTemperature ?? 0, }, }) - return result.response.text() + return result.text ?? 
"" } catch (error) { if (error instanceof Error) { throw new Error(`Gemini completion error: ${error.message}`) } + throw error } } + + override async countTokens(content: Array): Promise { + try { + const { id: model } = this.getModel() + + const response = await this.client.models.countTokens({ + model, + contents: convertAnthropicContentToGemini(content), + }) + + if (response.totalTokens === undefined) { + console.warn("Gemini token counting returned undefined, using fallback") + return super.countTokens(content) + } + + return response.totalTokens + } catch (error) { + console.warn("Gemini token counting failed, using fallback", error) + return super.countTokens(content) + } + } } diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts index c8fc80d769..ee22cff32a 100644 --- a/src/api/transform/gemini-format.ts +++ b/src/api/transform/gemini-format.ts @@ -1,76 +1,71 @@ import { Anthropic } from "@anthropic-ai/sdk" -import { Content, FunctionCallPart, FunctionResponsePart, InlineDataPart, Part, TextPart } from "@google/generative-ai" +import { Content, Part } from "@google/genai" -function convertAnthropicContentToGemini(content: Anthropic.Messages.MessageParam["content"]): Part[] { +export function convertAnthropicContentToGemini(content: string | Anthropic.ContentBlockParam[]): Part[] { if (typeof content === "string") { - return [{ text: content } as TextPart] + return [{ text: content }] } - return content.flatMap((block) => { + return content.flatMap((block): Part | Part[] => { switch (block.type) { case "text": - return { text: block.text } as TextPart + return { text: block.text } case "image": if (block.source.type !== "base64") { throw new Error("Unsupported image source type") } - return { - inlineData: { - data: block.source.data, - mimeType: block.source.media_type, - }, - } as InlineDataPart + + return { inlineData: { data: block.source.data, mimeType: block.source.media_type } } case "tool_use": return { functionCall: { name: block.name, - args: block.input, + args: block.input as Record, }, - } as FunctionCallPart - case "tool_result": - const name = block.tool_use_id.split("-")[0] + } + case "tool_result": { if (!block.content) { return [] } + + // Extract tool name from tool_use_id (e.g., "calculator-123" -> "calculator") + const toolName = block.tool_use_id.split("-")[0] + if (typeof block.content === "string") { return { - functionResponse: { - name, - response: { - name, - content: block.content, - }, - }, - } as FunctionResponsePart - } else { - // The only case when tool_result could be array is when the tool failed and we're providing ie user feedback potentially with images - const textParts = block.content.filter((part) => part.type === "text") - const imageParts = block.content.filter((part) => part.type === "image") - const text = textParts.length > 0 ? textParts.map((part) => part.text).join("\n\n") : "" - const imageText = imageParts.length > 0 ? 
"\n\n(See next part for image)" : "" - return [ - { - functionResponse: { - name, - response: { - name, - content: text + imageText, - }, - }, - } as FunctionResponsePart, - ...imageParts.map( - (part) => - ({ - inlineData: { - data: part.source.data, - mimeType: part.source.media_type, - }, - }) as InlineDataPart, - ), - ] + functionResponse: { name: toolName, response: { name: toolName, content: block.content } }, + } + } + + if (!Array.isArray(block.content)) { + return [] } + + const textParts: string[] = [] + const imageParts: Part[] = [] + + for (const item of block.content) { + if (item.type === "text") { + textParts.push(item.text) + } else if (item.type === "image" && item.source.type === "base64") { + const { data, media_type } = item.source + imageParts.push({ inlineData: { data, mimeType: media_type } }) + } + } + + // Create content text with a note about images if present + const contentText = + textParts.join("\n\n") + (imageParts.length > 0 ? "\n\n(See next part for image)" : "") + + // Return function response followed by any images + return [ + { functionResponse: { name: toolName, response: { name: toolName, content: contentText } } }, + ...imageParts, + ] + } default: - throw new Error(`Unsupported content block type: ${(block as any).type}`) + // Currently unsupported: "thinking" | "redacted_thinking" | "document" + throw new Error(`Unsupported content block type: ${block.type}`) } }) } diff --git a/src/exports/roo-code.d.ts b/src/exports/roo-code.d.ts index 422297ddd3..d2a28a8273 100644 --- a/src/exports/roo-code.d.ts +++ b/src/exports/roo-code.d.ts @@ -31,6 +31,7 @@ type ProviderSettings = { glamaModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -53,6 +54,7 @@ type ProviderSettings = { openRouterModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -95,6 +97,7 @@ type ProviderSettings = { openAiCustomModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -140,6 +143,7 @@ type ProviderSettings = { unboundModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -161,6 +165,7 @@ type ProviderSettings = { requestyModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined diff --git a/src/exports/types.ts b/src/exports/types.ts index df38b929ed..4280511a8d 100644 --- a/src/exports/types.ts +++ b/src/exports/types.ts @@ -32,6 +32,7 @@ type ProviderSettings = { glamaModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -54,6 +55,7 @@ type ProviderSettings = { openRouterModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: 
boolean | undefined @@ -96,6 +98,7 @@ type ProviderSettings = { openAiCustomModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -141,6 +144,7 @@ type ProviderSettings = { unboundModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined @@ -162,6 +166,7 @@ type ProviderSettings = { requestyModelInfo?: | ({ maxTokens?: (number | null) | undefined + maxThinkingTokens?: (number | null) | undefined contextWindow: number supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined diff --git a/src/schemas/index.ts b/src/schemas/index.ts index db9feb54b4..24b224c08b 100644 --- a/src/schemas/index.ts +++ b/src/schemas/index.ts @@ -99,6 +99,7 @@ export type ReasoningEffort = z.infer export const modelInfoSchema = z.object({ maxTokens: z.number().nullish(), + maxThinkingTokens: z.number().nullish(), contextWindow: z.number(), supportsImages: z.boolean().optional(), supportsComputerUse: z.boolean().optional(), diff --git a/src/shared/api.ts b/src/shared/api.ts index 4d71d947ba..ebc0b85c93 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -485,6 +485,16 @@ export const vertexModels = { inputPrice: 0.15, outputPrice: 0.6, }, + "gemini-2.5-flash-preview-04-17:thinking": { + maxTokens: 65_535, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0.15, + outputPrice: 0.6, + thinking: true, + maxThinkingTokens: 24_576, + }, "gemini-2.5-flash-preview-04-17": { maxTokens: 65_535, contextWindow: 1_048_576, @@ -492,6 +502,7 @@ export const vertexModels = { supportsPromptCache: false, inputPrice: 0.15, outputPrice: 0.6, + thinking: false, }, "gemini-2.5-pro-preview-03-25": { maxTokens: 65_535, @@ -640,6 +651,16 @@ export const openAiModelInfoSaneDefaults: ModelInfo = { export type GeminiModelId = keyof typeof geminiModels export const geminiDefaultModelId: GeminiModelId = "gemini-2.0-flash-001" export const geminiModels = { + "gemini-2.5-flash-preview-04-17:thinking": { + maxTokens: 65_535, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0.15, + outputPrice: 0.6, + thinking: true, + maxThinkingTokens: 24_576, + }, "gemini-2.5-flash-preview-04-17": { maxTokens: 65_535, contextWindow: 1_048_576, @@ -647,6 +668,7 @@ export const geminiModels = { supportsPromptCache: false, inputPrice: 0.15, outputPrice: 0.6, + thinking: false, }, "gemini-2.5-pro-exp-03-25": { maxTokens: 65_535, diff --git a/webview-ui/src/components/settings/ThinkingBudget.tsx b/webview-ui/src/components/settings/ThinkingBudget.tsx index e4cb4f0b9c..5123d571b3 100644 --- a/webview-ui/src/components/settings/ThinkingBudget.tsx +++ b/webview-ui/src/components/settings/ThinkingBudget.tsx @@ -1,10 +1,13 @@ -import { useEffect, useMemo } from "react" +import { useEffect } from "react" import { useAppTranslation } from "@/i18n/TranslationContext" import { Slider } from "@/components/ui" import { ApiConfiguration, ModelInfo } from "../../../../src/shared/api" +const DEFAULT_MAX_OUTPUT_TOKENS = 16_384 +const DEFAULT_MAX_THINKING_TOKENS = 8_192 + interface ThinkingBudgetProps { apiConfiguration: ApiConfiguration setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void @@ -13,57 +16,55 @@ interface ThinkingBudgetProps { 
export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, modelInfo }: ThinkingBudgetProps) => { const { t } = useAppTranslation() - const tokens = apiConfiguration?.modelMaxTokens || 16_384 - const tokensMin = 8192 - const tokensMax = modelInfo?.maxTokens || 64_000 - // Get the appropriate thinking tokens based on provider - const thinkingTokens = useMemo(() => { - const value = apiConfiguration?.modelMaxThinkingTokens - return value || Math.min(Math.floor(0.8 * tokens), 8192) - }, [apiConfiguration, tokens]) + const isThinkingModel = modelInfo && modelInfo.thinking && modelInfo.maxTokens + + const customMaxOutputTokens = apiConfiguration.modelMaxTokens || DEFAULT_MAX_OUTPUT_TOKENS + const customMaxThinkingTokens = apiConfiguration.modelMaxThinkingTokens || DEFAULT_MAX_THINKING_TOKENS - const thinkingTokensMin = 1024 - const thinkingTokensMax = Math.floor(0.8 * tokens) + // Dynamically expand or shrink the max thinking budget based on the custom + // max output tokens so that there's always a 20% buffer. + const modelMaxThinkingTokens = modelInfo?.maxThinkingTokens + ? Math.min(modelInfo.maxThinkingTokens, Math.floor(0.8 * customMaxOutputTokens)) + : Math.floor(0.8 * customMaxOutputTokens) + // If the custom max thinking tokens would exceed its limit because the + // custom max output tokens were reduced, then we need to shrink it + // accordingly. useEffect(() => { - if (thinkingTokens > thinkingTokensMax) { - setApiConfigurationField("modelMaxThinkingTokens", thinkingTokensMax) + if (isThinkingModel && customMaxThinkingTokens > modelMaxThinkingTokens) { + setApiConfigurationField("modelMaxThinkingTokens", modelMaxThinkingTokens) } - }, [thinkingTokens, thinkingTokensMax, setApiConfigurationField]) - - if (!modelInfo?.thinking) { - return null - } + }, [isThinkingModel, customMaxThinkingTokens, modelMaxThinkingTokens, setApiConfigurationField]) - return ( + return isThinkingModel ? ( <>
{t("settings:thinkingBudget.maxTokens")}
setApiConfigurationField("modelMaxTokens", value)} /> -
{tokens}
+
{customMaxOutputTokens}
{t("settings:thinkingBudget.maxThinkingTokens")}
setApiConfigurationField("modelMaxThinkingTokens", value)} /> -
{thinkingTokens}
+
{customMaxThinkingTokens}
- ) + ) : null }
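For reference, a minimal sketch (not part of the diff) of how the new `:thinking` model ids added in this PR are meant to be consumed: the suffix is stripped before the request is sent, and the user-configured `modelMaxThinkingTokens` becomes the `thinkingConfig.thinkingBudget` passed to `@google/genai`, mirroring `GeminiHandler.getModel()` and `createMessage()` above. The `resolveThinkingModel` helper, the `demo` function, the `GEMINI_API_KEY` environment variable, and the 8_192-token budget are illustrative assumptions; only the request shape follows the handler code in this PR.

```ts
import { GoogleGenAI, type ThinkingConfig } from "@google/genai"

const THINKING_SUFFIX = ":thinking"

// Illustrative helper: strips the ":thinking" suffix from the UI-facing model id
// and turns the configured budget into a ThinkingConfig, as getModel() does above.
function resolveThinkingModel(apiModelId: string, modelMaxThinkingTokens?: number) {
	let model = apiModelId
	let thinkingConfig: ThinkingConfig | undefined

	if (model.endsWith(THINKING_SUFFIX)) {
		model = model.slice(0, -THINKING_SUFFIX.length)
		thinkingConfig = modelMaxThinkingTokens ? { thinkingBudget: modelMaxThinkingTokens } : undefined
	}

	return { model, thinkingConfig }
}

// Hypothetical usage with the model id registered in geminiModels by this PR
// and an assumed 8_192-token thinking budget.
async function demo() {
	const { model, thinkingConfig } = resolveThinkingModel("gemini-2.5-flash-preview-04-17:thinking", 8_192)
	const client = new GoogleGenAI({ apiKey: process.env.GEMINI_API_KEY ?? "not-provided" })

	const stream = await client.models.generateContentStream({
		model,
		contents: [{ role: "user", parts: [{ text: "Hello" }] }],
		config: { thinkingConfig, temperature: 0 },
	})

	for await (const chunk of stream) {
		if (chunk.text) {
			process.stdout.write(chunk.text)
		}
	}
}
```

In the UI, it is the `thinking: true` flag on the `:thinking` entries in `geminiModels`/`vertexModels` that gates the ThinkingBudget sliders via `modelInfo.thinking`.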