Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { UnboundHandler } from "./providers/unbound"
import { RequestyHandler } from "./providers/requesty"
import { HumanRelayHandler } from "./providers/human-relay"
import { FakeAIHandler } from "./providers/fake-ai"
import { XAIHandler } from "./providers/xai"

export interface SingleCompletionHandler {
completePrompt(prompt: string): Promise<string>
Expand Down Expand Up @@ -78,6 +79,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new HumanRelayHandler(options)
case "fake-ai":
return new FakeAIHandler(options)
case "xai":
return new XAIHandler(options)
default:
return new AnthropicHandler(options)
}
Expand Down
292 changes: 292 additions & 0 deletions src/api/providers/__tests__/xai.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
import { XAIHandler } from "../xai"
import { xaiDefaultModelId, xaiModels } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
jest.mock("openai", () => {
const createMock = jest.fn()
return jest.fn(() => ({
chat: {
completions: {
create: createMock,
},
},
}))
})

describe("XAIHandler", () => {
let handler: XAIHandler
let mockCreate: jest.Mock

beforeEach(() => {
// Reset all mocks
jest.clearAllMocks()

// Get the mock create function
mockCreate = (OpenAI as unknown as jest.Mock)().chat.completions.create

// Create handler with mock
handler = new XAIHandler({})
})

test("should use the correct X.AI base URL", () => {
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: "https://api.x.ai/v1",
}),
)
})

test("should use the provided API key", () => {
// Clear mocks before this specific test
jest.clearAllMocks()

// Create a handler with our API key
const xaiApiKey = "test-api-key"
new XAIHandler({ xaiApiKey })

// Verify the OpenAI constructor was called with our API key
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
apiKey: xaiApiKey,
}),
)
})

test("should return default model when no model is specified", () => {
const model = handler.getModel()
expect(model.id).toBe(xaiDefaultModelId)
expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
})

test("should return specified model when valid model is provided", () => {
const testModelId = "grok-2-latest"
const handlerWithModel = new XAIHandler({ apiModelId: testModelId })
const model = handlerWithModel.getModel()

expect(model.id).toBe(testModelId)
expect(model.info).toEqual(xaiModels[testModelId])
})

test("should include reasoning_effort parameter for mini models", async () => {
const miniModelHandler = new XAIHandler({
apiModelId: "grok-3-mini-beta",
reasoningEffort: "high",
})

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

// Start generating a message
const messageGenerator = miniModelHandler.createMessage("test prompt", [])
await messageGenerator.next() // Start the generator

// Check that reasoning_effort was included
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
reasoning_effort: "high",
}),
)
})

test("should not include reasoning_effort parameter for non-mini models", async () => {
const regularModelHandler = new XAIHandler({
apiModelId: "grok-2-latest",
reasoningEffort: "high",
})

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

// Start generating a message
const messageGenerator = regularModelHandler.createMessage("test prompt", [])
await messageGenerator.next() // Start the generator

// Check call args for reasoning_effort
const calls = mockCreate.mock.calls
const lastCall = calls[calls.length - 1][0]
expect(lastCall).not.toHaveProperty("reasoning_effort")
})

test("completePrompt method should return text from OpenAI API", async () => {
const expectedResponse = "This is a test response"

mockCreate.mockResolvedValueOnce({
choices: [
{
message: {
content: expectedResponse,
},
},
],
})

const result = await handler.completePrompt("test prompt")
expect(result).toBe(expectedResponse)
})

test("should handle errors in completePrompt", async () => {
const errorMessage = "API error"
mockCreate.mockRejectedValueOnce(new Error(errorMessage))

await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`)
})

test("createMessage should yield text content from stream", async () => {
const testContent = "This is test content"

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: jest
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [{ delta: { content: testContent } }],
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

// Create and consume the stream
const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

// Verify the content
expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({
type: "text",
text: testContent,
})
})

test("createMessage should yield reasoning content from stream", async () => {
const testReasoning = "Test reasoning content"

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: jest
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [{ delta: { reasoning_content: testReasoning } }],
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

// Create and consume the stream
const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

// Verify the reasoning content
expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({
type: "reasoning",
text: testReasoning,
})
})

test("createMessage should yield usage data from stream", async () => {
// Setup mock for streaming response that includes usage data
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: jest
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [{ delta: {} }], // Needs to have choices array to avoid error
usage: {
prompt_tokens: 10,
completion_tokens: 20,
cache_read_input_tokens: 5,
cache_creation_input_tokens: 15,
},
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

// Create and consume the stream
const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

// Verify the usage data
expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 20,
cacheReadTokens: 5,
cacheWriteTokens: 15,
})
})

test("createMessage should pass correct parameters to OpenAI client", async () => {
// Setup a handler with specific model
const modelId = "grok-2-latest"
const modelInfo = xaiModels[modelId]
const handlerWithModel = new XAIHandler({ apiModelId: modelId })

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

// System prompt and messages
const systemPrompt = "Test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]

// Start generating a message
const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
await messageGenerator.next() // Start the generator

// Check that all parameters were passed correctly
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
max_tokens: modelInfo.maxTokens,
temperature: 0,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
}),
)
})
})
9 changes: 9 additions & 0 deletions src/api/providers/constants.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
export const DEFAULT_HEADERS = {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo Code",
}

export const ANTHROPIC_DEFAULT_MAX_TOKENS = 8192

export const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6

export const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"

export const REASONING_MODELS = new Set(["x-ai/grok-3-mini-beta", "grok-3-mini-beta", "grok-3-mini-fast-beta"])
15 changes: 4 additions & 11 deletions src/api/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,10 @@ import { convertToSimpleMessages } from "../transform/simple-format"
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { XmlMatcher } from "../../utils/xml-matcher"
import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"

export const defaultHeaders = {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo Code",
}
import { DEEP_SEEK_DEFAULT_TEMPERATURE, DEFAULT_HEADERS, AZURE_AI_INFERENCE_PATH } from "./constants"

export interface OpenAiHandlerOptions extends ApiHandlerOptions {}

const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"

export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: OpenAiHandlerOptions
private client: OpenAI
Expand All @@ -45,7 +38,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
this.client = new OpenAI({
baseURL,
apiKey,
defaultHeaders,
defaultHeaders: DEFAULT_HEADERS,
defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
})
} else if (isAzureOpenAi) {
Expand All @@ -56,7 +49,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
apiKey,
apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
defaultHeaders: {
...defaultHeaders,
...DEFAULT_HEADERS,
...(this.options.openAiHostHeader ? { Host: this.options.openAiHostHeader } : {}),
},
})
Expand All @@ -65,7 +58,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
baseURL,
apiKey,
defaultHeaders: {
...defaultHeaders,
...DEFAULT_HEADERS,
...(this.options.openAiHostHeader ? { Host: this.options.openAiHostHeader } : {}),
},
})
Expand Down
5 changes: 2 additions & 3 deletions src/api/providers/openrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
import { convertToR1Format } from "../transform/r1-format"

import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
import { DEFAULT_HEADERS, DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
import { getModelParams, SingleCompletionHandler } from ".."
import { BaseProvider } from "./base-provider"
import { defaultHeaders } from "./openai"

const OPENROUTER_DEFAULT_PROVIDER_NAME = "[default]"

Expand Down Expand Up @@ -40,7 +39,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
const baseURL = this.options.openRouterBaseUrl || "https://openrouter.ai/api/v1"
const apiKey = this.options.openRouterApiKey ?? "not-provided"

this.client = new OpenAI({ baseURL, apiKey, defaultHeaders })
this.client = new OpenAI({ baseURL, apiKey, defaultHeaders: DEFAULT_HEADERS })
}

override async *createMessage(
Expand Down
Loading