3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -21,6 +21,7 @@ import { UnboundHandler } from "./providers/unbound"
import { RequestyHandler } from "./providers/requesty"
import { HumanRelayHandler } from "./providers/human-relay"
import { FakeAIHandler } from "./providers/fake-ai"
import { XAIHandler } from "./providers/xai"

export interface SingleCompletionHandler {
completePrompt(prompt: string): Promise<string>
@@ -78,6 +79,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new HumanRelayHandler(options)
case "fake-ai":
return new FakeAIHandler(options)
case "xai":
return new XAIHandler(options)
default:
return new AnthropicHandler(options)
}
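For orientation, here is a minimal sketch of how the new "xai" case can be reached from a caller. The apiProvider, xaiApiKey, and apiModelId field names are inferred from the handler options used in the tests below and are not verified against the full ApiConfiguration type, so treat them as assumptions.

import { buildApiHandler, SingleCompletionHandler } from "../../api" // path is illustrative

async function demo() {
	// Field names below are assumptions; the cast keeps the sketch self-contained.
	const handler = buildApiHandler({
		apiProvider: "xai",
		xaiApiKey: process.env.XAI_API_KEY ?? "not-provided",
		apiModelId: "grok-3-mini-beta",
	} as any)

	// SingleCompletionHandler (see the interface above) exposes completePrompt().
	const text = await (handler as unknown as SingleCompletionHandler).completePrompt("Say hello")
	console.log(text)
}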
292 changes: 292 additions & 0 deletions src/api/providers/__tests__/xai.test.ts
@@ -0,0 +1,292 @@
import { XAIHandler } from "../xai"
import { xaiDefaultModelId, xaiModels } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
jest.mock("openai", () => {
const createMock = jest.fn()
return jest.fn(() => ({
chat: {
completions: {
create: createMock,
},
},
}))
})

describe("XAIHandler", () => {
let handler: XAIHandler
let mockCreate: jest.Mock

beforeEach(() => {
// Reset all mocks
jest.clearAllMocks()

// Get the mock create function
mockCreate = (OpenAI as unknown as jest.Mock)().chat.completions.create

// Create handler with mock
handler = new XAIHandler({})
})

test("should use the correct X.AI base URL", () => {
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: "https://api.x.ai/v1",
}),
)
})

test("should use the provided API key", () => {
// Clear mocks before this specific test
jest.clearAllMocks()

// Create a handler with our API key
const xaiApiKey = "test-api-key"
new XAIHandler({ xaiApiKey })

// Verify the OpenAI constructor was called with our API key
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
apiKey: xaiApiKey,
}),
)
})

test("should return default model when no model is specified", () => {
const model = handler.getModel()
expect(model.id).toBe(xaiDefaultModelId)
expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
})

test("should return specified model when valid model is provided", () => {
const testModelId = "grok-2-latest"
const handlerWithModel = new XAIHandler({ apiModelId: testModelId })
const model = handlerWithModel.getModel()

expect(model.id).toBe(testModelId)
expect(model.info).toEqual(xaiModels[testModelId])
})

test("should include reasoning_effort parameter for mini models", async () => {
const miniModelHandler = new XAIHandler({
apiModelId: "grok-3-mini-beta",
reasoningEffort: "high",
})

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

// Start generating a message
const messageGenerator = miniModelHandler.createMessage("test prompt", [])
await messageGenerator.next() // Start the generator

// Check that reasoning_effort was included
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
reasoning_effort: "high",
}),
)
})

test("should not include reasoning_effort parameter for non-mini models", async () => {
const regularModelHandler = new XAIHandler({
apiModelId: "grok-2-latest",
reasoningEffort: "high",
})

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

// Start generating a message
const messageGenerator = regularModelHandler.createMessage("test prompt", [])
await messageGenerator.next() // Start the generator

// Check call args for reasoning_effort
const calls = mockCreate.mock.calls
const lastCall = calls[calls.length - 1][0]
expect(lastCall).not.toHaveProperty("reasoning_effort")
})

test("completePrompt method should return text from OpenAI API", async () => {
const expectedResponse = "This is a test response"

mockCreate.mockResolvedValueOnce({
choices: [
{
message: {
content: expectedResponse,
},
},
],
})

const result = await handler.completePrompt("test prompt")
expect(result).toBe(expectedResponse)
})

test("should handle errors in completePrompt", async () => {
const errorMessage = "API error"
mockCreate.mockRejectedValueOnce(new Error(errorMessage))

await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`)
})

test("createMessage should yield text content from stream", async () => {
const testContent = "This is test content"

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: jest
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [{ delta: { content: testContent } }],
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

// Create and consume the stream
const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

// Verify the content
expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({
type: "text",
text: testContent,
})
})

test("createMessage should yield reasoning content from stream", async () => {
const testReasoning = "Test reasoning content"

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: jest
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [{ delta: { reasoning_content: testReasoning } }],
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

// Create and consume the stream
const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

// Verify the reasoning content
expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({
type: "reasoning",
text: testReasoning,
})
})

test("createMessage should yield usage data from stream", async () => {
// Setup mock for streaming response that includes usage data
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: jest
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [{ delta: {} }], // Needs to have choices array to avoid error
usage: {
prompt_tokens: 10,
completion_tokens: 20,
cache_read_input_tokens: 5,
cache_creation_input_tokens: 15,
},
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

// Create and consume the stream
const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

// Verify the usage data
expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 20,
cacheReadTokens: 5,
cacheWriteTokens: 15,
})
})

test("createMessage should pass correct parameters to OpenAI client", async () => {
// Setup a handler with specific model
const modelId = "grok-2-latest"
const modelInfo = xaiModels[modelId]
const handlerWithModel = new XAIHandler({ apiModelId: modelId })

// Setup mock for streaming response
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

// System prompt and messages
const systemPrompt = "Test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]

// Start generating a message
const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
await messageGenerator.next() // Start the generator

// Check that all parameters were passed correctly
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
max_tokens: modelInfo.maxTokens,
temperature: 0,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
}),
)
})
})
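Taken together, these tests pin down the handler's observable contract: the client is constructed against https://api.x.ai/v1 with the provided key, getModel() falls back to xaiDefaultModelId, reasoning_effort is sent only for the mini models, and streaming chunks are mapped to text, reasoning, and usage events. For review purposes, a rough sketch of a handler that would satisfy these assertions follows; it is assembled from the tests and the shared helpers visible elsewhere in this diff, not the PR's actual xai.ts.

// Sketch only: mirrors what the tests assert; helper usage and option names are inferred.
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandlerOptions, ModelInfo, xaiDefaultModelId, xaiModels } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { SingleCompletionHandler } from ".."
import { REASONING_MODELS } from "./constants"

export class XAIHandler extends BaseProvider implements SingleCompletionHandler {
	private client: OpenAI

	constructor(protected options: ApiHandlerOptions) {
		super()
		this.client = new OpenAI({
			baseURL: "https://api.x.ai/v1",
			apiKey: this.options.xaiApiKey ?? "not-provided",
		})
	}

	override getModel(): { id: string; info: ModelInfo } {
		const id =
			this.options.apiModelId && this.options.apiModelId in xaiModels
				? this.options.apiModelId
				: xaiDefaultModelId
		return { id, info: xaiModels[id as keyof typeof xaiModels] }
	}

	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		const { id: modelId, info } = this.getModel()

		const stream = await this.client.chat.completions.create({
			model: modelId,
			max_tokens: info.maxTokens,
			temperature: 0,
			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
			stream: true,
			stream_options: { include_usage: true },
			// reasoning_effort is only attached for models listed in REASONING_MODELS.
			...(REASONING_MODELS.has(modelId) && this.options.reasoningEffort
				? { reasoning_effort: this.options.reasoningEffort }
				: {}),
		})

		for await (const chunk of stream) {
			const delta = chunk.choices[0]?.delta as { content?: string; reasoning_content?: string } | undefined

			if (delta?.content) {
				yield { type: "text", text: delta.content }
			}
			if (delta?.reasoning_content) {
				yield { type: "reasoning", text: delta.reasoning_content }
			}
			if (chunk.usage) {
				const usage = chunk.usage as OpenAI.CompletionUsage & {
					cache_read_input_tokens?: number
					cache_creation_input_tokens?: number
				}
				yield {
					type: "usage",
					inputTokens: usage.prompt_tokens || 0,
					outputTokens: usage.completion_tokens || 0,
					cacheReadTokens: usage.cache_read_input_tokens,
					cacheWriteTokens: usage.cache_creation_input_tokens,
				}
			}
		}
	}

	async completePrompt(prompt: string): Promise<string> {
		try {
			const response = await this.client.chat.completions.create({
				model: this.getModel().id,
				messages: [{ role: "user", content: prompt }],
			})
			return response.choices[0]?.message.content || ""
		} catch (error) {
			throw new Error(`xAI completion error: ${error instanceof Error ? error.message : String(error)}`)
		}
	}
}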
9 changes: 9 additions & 0 deletions src/api/providers/constants.ts
@@ -1,3 +1,12 @@
export const DEFAULT_HEADERS = {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo Code",
}

export const ANTHROPIC_DEFAULT_MAX_TOKENS = 8192

export const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6

export const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"

export const REASONING_MODELS = new Set(["x-ai/grok-3-mini-beta", "grok-3-mini-beta", "grok-3-mini-fast-beta"])
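The new REASONING_MODELS set is what lets a provider decide whether to forward a reasoning_effort parameter, which is exactly the behavior the mini-model tests above check. A small illustrative helper is shown below; the function itself is hypothetical, only the REASONING_MODELS lookup comes from this PR.

import { REASONING_MODELS } from "./constants"

// Hypothetical helper: include reasoning_effort only for models known to support it.
export function reasoningParams(modelId: string, effort?: "low" | "medium" | "high") {
	return REASONING_MODELS.has(modelId) && effort ? { reasoning_effort: effort } : {}
}

// reasoningParams("grok-3-mini-beta", "high") -> { reasoning_effort: "high" }
// reasoningParams("grok-2-latest", "high")    -> {} (parameter omitted)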
15 changes: 4 additions & 11 deletions src/api/providers/openai.ts
@@ -15,17 +15,10 @@ import { convertToSimpleMessages } from "../transform/simple-format"
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { XmlMatcher } from "../../utils/xml-matcher"
import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"

export const defaultHeaders = {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo Code",
}
import { DEEP_SEEK_DEFAULT_TEMPERATURE, DEFAULT_HEADERS, AZURE_AI_INFERENCE_PATH } from "./constants"

export interface OpenAiHandlerOptions extends ApiHandlerOptions {}

const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"

export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: OpenAiHandlerOptions
private client: OpenAI
@@ -45,7 +38,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
this.client = new OpenAI({
baseURL,
apiKey,
defaultHeaders,
defaultHeaders: DEFAULT_HEADERS,
defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
})
} else if (isAzureOpenAi) {
@@ -56,7 +49,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
apiKey,
apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
defaultHeaders: {
...defaultHeaders,
...DEFAULT_HEADERS,
...(this.options.openAiHostHeader ? { Host: this.options.openAiHostHeader } : {}),
},
})
@@ -65,7 +58,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
baseURL,
apiKey,
defaultHeaders: {
...defaultHeaders,
...DEFAULT_HEADERS,
...(this.options.openAiHostHeader ? { Host: this.options.openAiHostHeader } : {}),
},
})
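This hunk is a pure refactor: the shared Referer/X-Title headers and the Azure inference path move into constants.ts so other providers can import them instead of reaching into openai.ts, with no behavior change. For example, any handler built on the OpenAI client can now do something like the following (endpoint and key are placeholders, not from this PR):

import OpenAI from "openai"
import { DEFAULT_HEADERS } from "./constants"

// Illustrative only: reuse the shared headers instead of redefining them locally.
const client = new OpenAI({
	baseURL: "https://example.invalid/v1",
	apiKey: "not-provided",
	defaultHeaders: DEFAULT_HEADERS,
})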
5 changes: 2 additions & 3 deletions src/api/providers/openrouter.ts
@@ -9,10 +9,9 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
import { convertToR1Format } from "../transform/r1-format"

import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
import { DEFAULT_HEADERS, DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
import { getModelParams, SingleCompletionHandler } from ".."
import { BaseProvider } from "./base-provider"
import { defaultHeaders } from "./openai"

const OPENROUTER_DEFAULT_PROVIDER_NAME = "[default]"

@@ -40,7 +39,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
const baseURL = this.options.openRouterBaseUrl || "https://openrouter.ai/api/v1"
const apiKey = this.options.openRouterApiKey ?? "not-provided"

this.client = new OpenAI({ baseURL, apiKey, defaultHeaders })
this.client = new OpenAI({ baseURL, apiKey, defaultHeaders: DEFAULT_HEADERS })
}

override async *createMessage(