From fdc8085ed85e69c3ced70ad55aa7cccb04b4cd27 Mon Sep 17 00:00:00 2001
From: amittell
Date: Thu, 10 Apr 2025 23:48:10 -0400
Subject: [PATCH] Add reasoning_effort to openai compatible provider for grok-mini

---
 src/api/providers/__tests__/openai.test.ts     | 132 ++++++++++++++++++
 src/api/providers/openai.ts                    |   8 +-
 src/shared/ExtensionMessage.ts                 |   2 +
 src/shared/WebviewMessage.ts                   |   1 +
 src/shared/api.ts                              |   1 +
 .../src/components/settings/ApiOptions.tsx     |  29 ++++
 .../settings/GrokReasoningSettings.tsx         |  29 ++++
 7 files changed, 201 insertions(+), 1 deletion(-)
 create mode 100644 webview-ui/src/components/settings/GrokReasoningSettings.tsx

diff --git a/src/api/providers/__tests__/openai.test.ts b/src/api/providers/__tests__/openai.test.ts
index 950b2165410..1796a8b0ece 100644
--- a/src/api/providers/__tests__/openai.test.ts
+++ b/src/api/providers/__tests__/openai.test.ts
@@ -392,4 +392,136 @@ describe("OpenAiHandler", () => {
 			expect(lastCall[0]).not.toHaveProperty("stream_options")
 		})
 	})
+
+	describe("Grok 3 Mini models with reasoning", () => {
+		const grokMiniOptions = {
+			...mockOptions,
+			openAiBaseUrl: "https://api.x.ai/v1",
+			openAiModelId: "grok-3-mini-beta",
+			openAiCustomModelInfo: {
+				reasoningEffort: "low" as const,
+				thinking: true,
+				contextWindow: 128_000,
+				supportsPromptCache: false,
+				maxTokens: -1,
+				supportsImages: true,
+				inputPrice: 0,
+				outputPrice: 0,
+			},
+		}
+
+		it("should include reasoning_effort parameter for Grok mini models", async () => {
+			const grokHandler = new OpenAiHandler(grokMiniOptions)
+			const systemPrompt = "You are a helpful assistant."
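+			// A single user turn is enough to trigger one streaming request; the assertion below inspects it.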
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = grokHandler.createMessage(systemPrompt, messages)
+			await stream.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: grokMiniOptions.openAiModelId,
+					stream: true,
+					reasoning_effort: "low",
+				}),
+				{},
+			)
+		})
+
+		it("should use the specified reasoningEffort value", async () => {
+			const grokHandler = new OpenAiHandler({
+				...grokMiniOptions,
+				openAiCustomModelInfo: {
+					...grokMiniOptions.openAiCustomModelInfo,
+					reasoningEffort: "high",
+				},
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = grokHandler.createMessage(systemPrompt, messages)
+			await stream.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: grokMiniOptions.openAiModelId,
+					stream: true,
+					reasoning_effort: "high",
+				}),
+				{},
+			)
+		})
+
+		it("should process reasoning_content from response", async () => {
+			// Update the mock to include reasoning_content in the response
+			mockCreate.mockImplementationOnce(() => ({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: { content: "Test response" },
+								index: 0,
+							},
+						],
+						usage: null,
+					}
+					yield {
+						choices: [
+							{
+								delta: { reasoning_content: "This is reasoning content" },
+								index: 0,
+							},
+						],
+						usage: null,
+					}
+					yield {
+						choices: [
+							{
+								delta: {},
+								index: 0,
+							},
+						],
+						usage: {
+							prompt_tokens: 10,
+							completion_tokens: 5,
+							total_tokens: 15,
+						},
+					}
+				},
+			}))
+
+			const grokHandler = new OpenAiHandler(grokMiniOptions)
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = grokHandler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response")
+
+			expect(reasoningChunks).toHaveLength(1)
+			expect(reasoningChunks[0].text).toBe("This is reasoning content")
+		})
+	})
 })
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index fc739b31105..06b5a8a7269 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -138,6 +138,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		}
 
 		const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+		const isGrokMiniModel = modelId.includes("grok-3-mini")
+		const useGrokReasoning = modelInfo.thinking || (isGrokXAI && isGrokMiniModel)
 
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
@@ -145,6 +147,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			messages: convertedMessages,
 			stream: true as const,
 			...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
+			...(useGrokReasoning ? { reasoning_effort: modelInfo.reasoningEffort || "low" } : {}),
 		}
 		if (this.options.includeMaxTokens) {
 			requestOptions.max_tokens = modelInfo.maxTokens
@@ -267,7 +270,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 
 		if (this.options.openAiStreamingEnabled ?? true) {
 			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+			const modelInfo = this.getModel().info
 			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+			const isGrokMiniModel = modelId.includes("grok-3-mini")
+			const useGrokReasoning = modelInfo.thinking || (isGrokXAI && isGrokMiniModel)
 
 			const stream = await this.client.chat.completions.create(
 				{
@@ -281,7 +287,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					],
 					stream: true,
 					...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
-					reasoning_effort: this.getModel().info.reasoningEffort,
+					...(useGrokReasoning ? { reasoning_effort: this.getModel().info.reasoningEffort || "low" } : {}),
 				},
 				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
 			)
diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts
index 38277a7c2de..f7ea149bbd5 100644
--- a/src/shared/ExtensionMessage.ts
+++ b/src/shared/ExtensionMessage.ts
@@ -199,6 +199,8 @@ export type ExtensionState = Pick<
 	telemetrySetting: TelemetrySetting
 	telemetryKey?: string
 	machineId?: string
+	reasoningEffort?: "low" | "medium" | "high" // The reasoning effort level for models that support reasoning
+	grokReasoningEffort?: "low" | "high" // The reasoning effort level for Grok 3 Mini models
 	renderContext: "sidebar" | "editor"
 	settingsImportedAt?: number
diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts
index 972845959e3..37e5aa10915 100644
--- a/src/shared/WebviewMessage.ts
+++ b/src/shared/WebviewMessage.ts
@@ -120,6 +120,7 @@ export interface WebviewMessage {
 		| "maxReadFileLine"
 		| "searchFiles"
 		| "toggleApiConfigPin"
+		| "reasoningEffort"
 	text?: string
 	disabled?: boolean
 	askResponse?: ClineAskResponse
diff --git a/src/shared/api.ts b/src/shared/api.ts
index cd818fd1a5d..96e54702f64 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -622,6 +622,7 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
 	supportsPromptCache: false,
 	inputPrice: 0,
 	outputPrice: 0,
+	reasoningEffort: "low",
 }
 
 // Gemini
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 55690d48069..79ce79774ef 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -7,6 +7,7 @@ import { LanguageModelChatSelector } from "vscode"
 import { Checkbox } from "vscrui"
 import { VSCodeLink, VSCodeRadio, VSCodeRadioGroup, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
 import { ExternalLinkIcon } from "@radix-ui/react-icons"
+import { GrokReasoningSettings } from "./GrokReasoningSettings"
 
 import {
 	ApiConfiguration,
@@ -129,6 +130,16 @@ const ApiOptions = ({
 		[apiConfiguration],
 	)
 
+	// Check if the current model is a Grok 3 Mini model
+	const isGrokMiniModel = useMemo(() => {
+		return selectedModelId?.includes("grok-3-mini-beta") || selectedModelId?.includes("grok-3-mini-fast-beta")
+	}, [selectedModelId])
+
+	// Check if the endpoint is x.ai
+	const isXaiEndpoint = useMemo(() => {
+		return apiConfiguration?.openAiBaseUrl?.includes("x.ai") || false
+	}, [apiConfiguration?.openAiBaseUrl])
+
 	// Debounced refresh model updates, only executed 250ms after the user
 	// stops typing.
 	useDebounce(
@@ -800,6 +811,24 @@ const ApiOptions = ({
 								onChange={handleInputChange("openAiStreamingEnabled", noTransform)}>
 								{t("settings:modelInfo.enableStreaming")}
 							</Checkbox>
+
+							{/* Grok 3 Reasoning Settings - Only show for Grok Mini models */}
+							{isGrokMiniModel && isXaiEndpoint && (
+								<GrokReasoningSettings
+									reasoningEffort={(apiConfiguration?.openAiCustomModelInfo?.reasoningEffort as "low" | "high") ?? "low"}
+									setReasoningEffort={(value) => {
+										setApiConfigurationField("openAiCustomModelInfo", {
+											...(apiConfiguration?.openAiCustomModelInfo || openAiModelInfoSaneDefaults),
+											reasoningEffort: value,
+										})
+									}}
+								/>
+							)}
+
diff --git a/webview-ui/src/components/settings/GrokReasoningSettings.tsx b/webview-ui/src/components/settings/GrokReasoningSettings.tsx
new file mode 100644
index 00000000000..0a4c618a114
--- /dev/null
+++ b/webview-ui/src/components/settings/GrokReasoningSettings.tsx
@@ -0,0 +1,29 @@
+import React from "react"
+import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui"
+
+interface GrokReasoningSettingsProps {
+	reasoningEffort: "low" | "high"
+	setReasoningEffort: (value: "low" | "high") => void
+}
+
+export const GrokReasoningSettings: React.FC<GrokReasoningSettingsProps> = ({
+	reasoningEffort,
+	setReasoningEffort,
+}) => {
+	return (
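+		// Two-option dropdown; "low" mirrors the provider-side default for reasoning_effort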
+		<div>
+			<label>Reasoning Effort</label>
+			<Select value={reasoningEffort} onValueChange={(value) => setReasoningEffort(value as "low" | "high")}>
+				<SelectTrigger>
+					<SelectValue placeholder="Select reasoning effort" />
+				</SelectTrigger>
+				<SelectContent>
+					<SelectItem value="low">Low</SelectItem>
+					<SelectItem value="high">High</SelectItem>
+				</SelectContent>
+			</Select>
+		</div>
+	)
+}
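
Usage sketch (illustrative only, not part of the diff): the snippet below shows the request shape the provider ends up sending once this patch is applied. It assumes the official "openai" npm package and a hypothetical XAI_API_KEY environment variable; the gating names mirror the patch, but the client setup and the reasoning_content cast are a sketch, not the provider implementation.

// sketch.ts
import OpenAI from "openai"

async function main() {
	// Same gate the patch applies: Grok 3 Mini models on the x.ai endpoint
	// (models whose info sets thinking: true also qualify in the patch).
	const baseURL = "https://api.x.ai/v1"
	const modelId = "grok-3-mini-beta"
	const isGrokXAI = baseURL.includes("x.ai")
	const isGrokMiniModel = modelId.includes("grok-3-mini")
	const useGrokReasoning = isGrokXAI && isGrokMiniModel

	const client = new OpenAI({ baseURL, apiKey: process.env.XAI_API_KEY })
	const stream = await client.chat.completions.create({
		model: modelId,
		messages: [{ role: "user", content: "Hello!" }],
		stream: true,
		// Only attached when the gate fires; "low" is the fallback the patch uses.
		...(useGrokReasoning ? { reasoning_effort: "low" as const } : {}),
	})

	for await (const chunk of stream) {
		// Grok mini streams its reasoning separately; the provider surfaces these
		// deltas as "reasoning" chunks. reasoning_content is not in the SDK types,
		// hence the cast (an assumption based on the test fixtures above).
		const delta = chunk.choices[0]?.delta as { content?: string; reasoning_content?: string }
		if (delta?.reasoning_content) console.log("[reasoning]", delta.reasoning_content)
		if (delta?.content) console.log(delta.content)
	}
}

main()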