Merged
40 changes: 40 additions & 0 deletions src/api/providers/__tests__/openai.test.ts
@@ -352,4 +352,44 @@ describe("OpenAiHandler", () => {
)
})
})

describe("Grok xAI Provider", () => {
const grokOptions = {
...mockOptions,
openAiBaseUrl: "https://api.x.ai/v1",
openAiModelId: "grok-1",
}

it("should initialize with Grok xAI configuration", () => {
const grokHandler = new OpenAiHandler(grokOptions)
expect(grokHandler).toBeInstanceOf(OpenAiHandler)
expect(grokHandler.getModel().id).toBe(grokOptions.openAiModelId)
})

it("should exclude stream_options when streaming with Grok xAI", async () => {
const grokHandler = new OpenAiHandler(grokOptions)
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = grokHandler.createMessage(systemPrompt, messages)
await stream.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: grokOptions.openAiModelId,
stream: true,
}),
{},
)

const mockCalls = mockCreate.mock.calls
const lastCall = mockCalls[mockCalls.length - 1]
expect(lastCall[0]).not.toHaveProperty("stream_options")
})
})
})
13 changes: 11 additions & 2 deletions src/api/providers/openai.ts
@@ -137,12 +137,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
}
}

+const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
Contributor review comment: Consider extracting the logic for constructing streaming options into a helper function to reduce duplication.
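A minimal sketch of the suggested helper (the name _getStreamOptions and its exact shape are assumptions for illustration, not code from this PR):

private _getStreamOptions(): { stream_options?: { include_usage: boolean } } {
	// Hypothetical helper sketching the reviewer's suggestion: Grok's x.ai
	// endpoint rejects stream_options, so omit it there; every other
	// OpenAI-compatible endpoint keeps include_usage as before.
	return this._isGrokXAI(this.options.openAiBaseUrl)
		? {}
		: { stream_options: { include_usage: true } }
}

Both call sites could then spread ...this._getStreamOptions() into their request params instead of repeating the ternary.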

model: modelId,
temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
messages: convertedMessages,
stream: true as const,
-stream_options: { include_usage: true },
+...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
}
if (this.options.includeMaxTokens) {
requestOptions.max_tokens = modelInfo.maxTokens
@@ -265,6 +267,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
if (this.options.openAiStreamingEnabled ?? true) {
const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)

+const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+
const stream = await this.client.chat.completions.create(
{
model: modelId,
@@ -276,7 +280,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
...convertToOpenAiMessages(messages),
],
stream: true,
-stream_options: { include_usage: true },
+...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
reasoning_effort: this.getModel().info.reasoningEffort,
},
methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
@@ -337,6 +341,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
}
}

+private _isGrokXAI(baseUrl?: string): boolean {
+const urlHost = this._getUrlHost(baseUrl)
+return urlHost.includes("x.ai")
+}
+
private _isAzureAiInference(baseUrl?: string): boolean {
const urlHost = this._getUrlHost(baseUrl)
return urlHost.endsWith(".services.ai.azure.com")
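For readers skimming the diff, the key mechanism is the conditional object spread: spreading {} adds no keys, so stream_options only appears in the request for non-Grok endpoints. A standalone sketch of the pattern (illustrative only, not code from this PR; the values are made up):

// Conditional spread: ...(cond ? {} : obj) merges obj's keys only
// when cond is false.
const isGrokXAI = true
const params = {
	model: "grok-1",
	stream: true,
	...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
}
console.log("stream_options" in params) // false while isGrokXAI is true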