Skip to content

Commit 31c855e

Browse files
committed
Add xAI provider
1 parent 3b19d7a commit 31c855e

File tree

17 files changed

+574
-0
lines changed

17 files changed

+574
-0
lines changed

src/api/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { UnboundHandler } from "./providers/unbound"
2121
import { RequestyHandler } from "./providers/requesty"
2222
import { HumanRelayHandler } from "./providers/human-relay"
2323
import { FakeAIHandler } from "./providers/fake-ai"
24+
import { XAIHandler } from "./providers/xai"
2425

2526
export interface SingleCompletionHandler {
2627
completePrompt(prompt: string): Promise<string>
@@ -78,6 +79,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
7879
return new HumanRelayHandler(options)
7980
case "fake-ai":
8081
return new FakeAIHandler(options)
82+
case "xai":
83+
return new XAIHandler(options)
8184
default:
8285
return new AnthropicHandler(options)
8386
}
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import { XAIHandler } from "../xai"
2+
import { xaiDefaultModelId, xaiModels } from "../../../shared/api"
3+
import OpenAI from "openai"
4+
import { Anthropic } from "@anthropic-ai/sdk"
5+
6+
// Mock OpenAI client
7+
jest.mock("openai", () => {
8+
const createMock = jest.fn()
9+
return jest.fn(() => ({
10+
chat: {
11+
completions: {
12+
create: createMock,
13+
},
14+
},
15+
}))
16+
})
17+
18+
describe("XAIHandler", () => {
19+
let handler: XAIHandler
20+
let mockCreate: jest.Mock
21+
22+
beforeEach(() => {
23+
// Reset all mocks
24+
jest.clearAllMocks()
25+
26+
// Get the mock create function
27+
mockCreate = (OpenAI as unknown as jest.Mock)().chat.completions.create
28+
29+
// Create handler with mock
30+
handler = new XAIHandler({})
31+
})
32+
33+
test("should use the correct X.AI base URL", () => {
34+
expect(OpenAI).toHaveBeenCalledWith(
35+
expect.objectContaining({
36+
baseURL: "https://api.x.ai/v1",
37+
}),
38+
)
39+
})
40+
41+
test("should use the provided API key", () => {
42+
// Clear mocks before this specific test
43+
jest.clearAllMocks()
44+
45+
// Create a handler with our API key
46+
const xaiApiKey = "test-api-key"
47+
new XAIHandler({ xaiApiKey })
48+
49+
// Verify the OpenAI constructor was called with our API key
50+
expect(OpenAI).toHaveBeenCalledWith(
51+
expect.objectContaining({
52+
apiKey: xaiApiKey,
53+
}),
54+
)
55+
})
56+
57+
test("should return default model when no model is specified", () => {
58+
const model = handler.getModel()
59+
expect(model.id).toBe(xaiDefaultModelId)
60+
expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
61+
})
62+
63+
test("should return specified model when valid model is provided", () => {
64+
const testModelId = "grok-2-latest"
65+
const handlerWithModel = new XAIHandler({ apiModelId: testModelId })
66+
const model = handlerWithModel.getModel()
67+
68+
expect(model.id).toBe(testModelId)
69+
expect(model.info).toEqual(xaiModels[testModelId])
70+
})
71+
72+
test("should include reasoning_effort parameter for mini models", async () => {
73+
const miniModelHandler = new XAIHandler({
74+
apiModelId: "grok-3-mini-beta",
75+
reasoningEffort: "high",
76+
})
77+
78+
// Setup mock for streaming response
79+
mockCreate.mockImplementationOnce(() => {
80+
return {
81+
[Symbol.asyncIterator]: () => ({
82+
async next() {
83+
return { done: true }
84+
},
85+
}),
86+
}
87+
})
88+
89+
// Start generating a message
90+
const messageGenerator = miniModelHandler.createMessage("test prompt", [])
91+
await messageGenerator.next() // Start the generator
92+
93+
// Check that reasoning_effort was included
94+
expect(mockCreate).toHaveBeenCalledWith(
95+
expect.objectContaining({
96+
reasoning_effort: "high",
97+
}),
98+
)
99+
})
100+
101+
test("should not include reasoning_effort parameter for non-mini models", async () => {
102+
const regularModelHandler = new XAIHandler({
103+
apiModelId: "grok-2-latest",
104+
reasoningEffort: "high",
105+
})
106+
107+
// Setup mock for streaming response
108+
mockCreate.mockImplementationOnce(() => {
109+
return {
110+
[Symbol.asyncIterator]: () => ({
111+
async next() {
112+
return { done: true }
113+
},
114+
}),
115+
}
116+
})
117+
118+
// Start generating a message
119+
const messageGenerator = regularModelHandler.createMessage("test prompt", [])
120+
await messageGenerator.next() // Start the generator
121+
122+
// Check call args for reasoning_effort
123+
const calls = mockCreate.mock.calls
124+
const lastCall = calls[calls.length - 1][0]
125+
expect(lastCall).not.toHaveProperty("reasoning_effort")
126+
})
127+
128+
test("completePrompt method should return text from OpenAI API", async () => {
129+
const expectedResponse = "This is a test response"
130+
131+
mockCreate.mockResolvedValueOnce({
132+
choices: [
133+
{
134+
message: {
135+
content: expectedResponse,
136+
},
137+
},
138+
],
139+
})
140+
141+
const result = await handler.completePrompt("test prompt")
142+
expect(result).toBe(expectedResponse)
143+
})
144+
145+
test("should handle errors in completePrompt", async () => {
146+
const errorMessage = "API error"
147+
mockCreate.mockRejectedValueOnce(new Error(errorMessage))
148+
149+
await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`)
150+
})
151+
152+
test("createMessage should yield text content from stream", async () => {
153+
const testContent = "This is test content"
154+
155+
// Setup mock for streaming response
156+
mockCreate.mockImplementationOnce(() => {
157+
return {
158+
[Symbol.asyncIterator]: () => ({
159+
next: jest
160+
.fn()
161+
.mockResolvedValueOnce({
162+
done: false,
163+
value: {
164+
choices: [{ delta: { content: testContent } }],
165+
},
166+
})
167+
.mockResolvedValueOnce({ done: true }),
168+
}),
169+
}
170+
})
171+
172+
// Create and consume the stream
173+
const stream = handler.createMessage("system prompt", [])
174+
const firstChunk = await stream.next()
175+
176+
// Verify the content
177+
expect(firstChunk.done).toBe(false)
178+
expect(firstChunk.value).toEqual({
179+
type: "text",
180+
text: testContent,
181+
})
182+
})
183+
184+
test("createMessage should yield reasoning content from stream", async () => {
185+
const testReasoning = "Test reasoning content"
186+
187+
// Setup mock for streaming response
188+
mockCreate.mockImplementationOnce(() => {
189+
return {
190+
[Symbol.asyncIterator]: () => ({
191+
next: jest
192+
.fn()
193+
.mockResolvedValueOnce({
194+
done: false,
195+
value: {
196+
choices: [{ delta: { reasoning_content: testReasoning } }],
197+
},
198+
})
199+
.mockResolvedValueOnce({ done: true }),
200+
}),
201+
}
202+
})
203+
204+
// Create and consume the stream
205+
const stream = handler.createMessage("system prompt", [])
206+
const firstChunk = await stream.next()
207+
208+
// Verify the reasoning content
209+
expect(firstChunk.done).toBe(false)
210+
expect(firstChunk.value).toEqual({
211+
type: "reasoning",
212+
text: testReasoning,
213+
})
214+
})
215+
216+
test("createMessage should yield usage data from stream", async () => {
217+
// Setup mock for streaming response that includes usage data
218+
mockCreate.mockImplementationOnce(() => {
219+
return {
220+
[Symbol.asyncIterator]: () => ({
221+
next: jest
222+
.fn()
223+
.mockResolvedValueOnce({
224+
done: false,
225+
value: {
226+
choices: [{ delta: {} }], // Needs to have choices array to avoid error
227+
usage: {
228+
prompt_tokens: 10,
229+
completion_tokens: 20,
230+
cache_read_input_tokens: 5,
231+
cache_creation_input_tokens: 15,
232+
},
233+
},
234+
})
235+
.mockResolvedValueOnce({ done: true }),
236+
}),
237+
}
238+
})
239+
240+
// Create and consume the stream
241+
const stream = handler.createMessage("system prompt", [])
242+
const firstChunk = await stream.next()
243+
244+
// Verify the usage data
245+
expect(firstChunk.done).toBe(false)
246+
expect(firstChunk.value).toEqual({
247+
type: "usage",
248+
inputTokens: 10,
249+
outputTokens: 20,
250+
cacheReadTokens: 5,
251+
cacheWriteTokens: 15,
252+
})
253+
})
254+
255+
test("createMessage should pass correct parameters to OpenAI client", async () => {
256+
// Setup a handler with specific model
257+
const modelId = "grok-2-latest"
258+
const modelInfo = xaiModels[modelId]
259+
const handlerWithModel = new XAIHandler({ apiModelId: modelId })
260+
261+
// Setup mock for streaming response
262+
mockCreate.mockImplementationOnce(() => {
263+
return {
264+
[Symbol.asyncIterator]: () => ({
265+
async next() {
266+
return { done: true }
267+
},
268+
}),
269+
}
270+
})
271+
272+
// System prompt and messages
273+
const systemPrompt = "Test system prompt"
274+
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
275+
276+
// Start generating a message
277+
const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
278+
await messageGenerator.next() // Start the generator
279+
280+
// Check that all parameters were passed correctly
281+
expect(mockCreate).toHaveBeenCalledWith(
282+
expect.objectContaining({
283+
model: modelId,
284+
max_tokens: modelInfo.maxTokens,
285+
temperature: 0,
286+
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
287+
stream: true,
288+
stream_options: { include_usage: true },
289+
}),
290+
)
291+
})
292+
})

0 commit comments

Comments
 (0)