diff --git a/src/api/providers/__tests__/openai.test.ts b/src/api/providers/__tests__/openai.test.ts
index a41a1cc4fa1..950b2165410 100644
--- a/src/api/providers/__tests__/openai.test.ts
+++ b/src/api/providers/__tests__/openai.test.ts
@@ -352,4 +352,44 @@ describe("OpenAiHandler", () => {
 			)
 		})
 	})
+
+	describe("Grok xAI Provider", () => {
+		const grokOptions = {
+			...mockOptions,
+			openAiBaseUrl: "https://api.x.ai/v1",
+			openAiModelId: "grok-1",
+		}
+
+		it("should initialize with Grok xAI configuration", () => {
+			const grokHandler = new OpenAiHandler(grokOptions)
+			expect(grokHandler).toBeInstanceOf(OpenAiHandler)
+			expect(grokHandler.getModel().id).toBe(grokOptions.openAiModelId)
+		})
+
+		it("should exclude stream_options when streaming with Grok xAI", async () => {
+			const grokHandler = new OpenAiHandler(grokOptions)
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = grokHandler.createMessage(systemPrompt, messages)
+			await stream.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: grokOptions.openAiModelId,
+					stream: true,
+				}),
+				{},
+			)
+
+			const mockCalls = mockCreate.mock.calls
+			const lastCall = mockCalls[mockCalls.length - 1]
+			expect(lastCall[0]).not.toHaveProperty("stream_options")
+		})
+	})
 })
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index 4f5477d97d6..fc739b31105 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -137,12 +137,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			}
 		}
 
+		const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
 			temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
 			messages: convertedMessages,
 			stream: true as const,
-			stream_options: { include_usage: true },
+			...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
 		}
 		if (this.options.includeMaxTokens) {
 			requestOptions.max_tokens = modelInfo.maxTokens
@@ -265,6 +267,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		if (this.options.openAiStreamingEnabled ?? true) {
 			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
 
+			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+
 			const stream = await this.client.chat.completions.create(
 				{
 					model: modelId,
@@ -276,7 +280,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 						...convertToOpenAiMessages(messages),
 					],
 					stream: true,
-					stream_options: { include_usage: true },
+					...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
 					reasoning_effort: this.getModel().info.reasoningEffort,
 				},
 				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
@@ -337,6 +341,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		}
 	}
 
+	private _isGrokXAI(baseUrl?: string): boolean {
+		const urlHost = this._getUrlHost(baseUrl)
+		return urlHost.includes("x.ai")
+	}
+
 	private _isAzureAiInference(baseUrl?: string): boolean {
 		const urlHost = this._getUrlHost(baseUrl)
 		return urlHost.endsWith(".services.ai.azure.com")
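
For review context, a minimal self-contained TypeScript sketch of the pattern this diff applies: resolve the configured base URL to a host, detect xAI, and omit `stream_options` via a conditional spread. `getUrlHost` is a stand-in for the private `_getUrlHost` helper that `_isGrokXAI` calls (that helper is not part of this diff, so its body below is an assumption), and `buildStreamingParams` is a hypothetical function used purely for illustration.

```typescript
// Assumed behavior of openai.ts's private _getUrlHost (not shown in this diff):
// parse the base URL and return its host, or "" if the URL is missing/invalid.
function getUrlHost(baseUrl?: string): string {
	try {
		return new URL(baseUrl ?? "").host
	} catch {
		return ""
	}
}

// Mirrors the _isGrokXAI check added by this diff.
function isGrokXAI(baseUrl?: string): boolean {
	return getUrlHost(baseUrl).includes("x.ai")
}

// Hypothetical builder demonstrating the conditional spread: for xAI the
// stream_options key is omitted entirely, not set to undefined.
function buildStreamingParams(baseUrl: string, model: string) {
	return {
		model,
		stream: true as const,
		...(isGrokXAI(baseUrl) ? {} : { stream_options: { include_usage: true } }),
	}
}

console.log("stream_options" in buildStreamingParams("https://api.x.ai/v1", "grok-1")) // false
console.log("stream_options" in buildStreamingParams("https://api.openai.com/v1", "gpt-4o")) // true
```

One design note: the host check is substring-based, so unlike `_isAzureAiInference`'s `endsWith(".services.ai.azure.com")` it matches any host that merely contains `x.ai` (for example a proxy host embedding that substring), not only `api.x.ai`.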