
Commit a417819

feat: add OpenAI-compatible provider for NVIDIA API support
- Add new OpenAiCompatibleHandler provider for generic OpenAI-compatible APIs
- Add configuration support in provider-settings.ts and API index
- Support custom base URLs and API keys for services like NVIDIA
- Include comprehensive tests for the new provider
- Fixes #8998
1 parent 5c738de commit a417819

File tree

5 files changed (+337, -1 lines)

5 files changed

+337
-1
lines changed

packages/types/src/provider-settings.ts

Lines changed: 10 additions & 1 deletion
@@ -132,6 +132,7 @@ export const providerNames = [
 	"mistral",
 	"moonshot",
 	"minimax",
+	"openai-compatible",
 	"openai-native",
 	"qwen-code",
 	"roo",
@@ -420,6 +421,11 @@ const vercelAiGatewaySchema = baseProviderSettingsSchema.extend({
 	vercelAiGatewayModelId: z.string().optional(),
 })
 
+const openAiCompatibleSchema = apiModelIdProviderModelSchema.extend({
+	openAiCompatibleBaseUrl: z.string().optional(),
+	openAiCompatibleApiKey: z.string().optional(),
+})
+
 const defaultSchema = z.object({
 	apiProvider: z.undefined(),
 })
@@ -462,6 +468,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	qwenCodeSchema.merge(z.object({ apiProvider: z.literal("qwen-code") })),
 	rooSchema.merge(z.object({ apiProvider: z.literal("roo") })),
 	vercelAiGatewaySchema.merge(z.object({ apiProvider: z.literal("vercel-ai-gateway") })),
+	openAiCompatibleSchema.merge(z.object({ apiProvider: z.literal("openai-compatible") })),
 	defaultSchema,
 ])
 
@@ -504,6 +511,7 @@ export const providerSettingsSchema = z.object({
 	...qwenCodeSchema.shape,
 	...rooSchema.shape,
 	...vercelAiGatewaySchema.shape,
+	...openAiCompatibleSchema.shape,
 	...codebaseIndexProviderSchema.shape,
 })
 
@@ -590,6 +598,7 @@ export const modelIdKeysByProvider: Record<TypicalProvider, ModelIdKey> = {
 	"io-intelligence": "ioIntelligenceModelId",
 	roo: "apiModelId",
 	"vercel-ai-gateway": "vercelAiGatewayModelId",
+	"openai-compatible": "apiModelId",
 }
 
 /**
@@ -626,7 +635,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str
  */
 
 export const MODELS_BY_PROVIDER: Record<
-	Exclude<ProviderName, "fake-ai" | "human-relay" | "gemini-cli" | "openai">,
+	Exclude<ProviderName, "fake-ai" | "human-relay" | "gemini-cli" | "openai" | "openai-compatible">,
 	{ id: ProviderName; label: string; models: string[] }
 > = {
 	anthropic: {
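
With the new openAiCompatibleSchema branch in place, a settings object like the one below should pass validation. A minimal sketch, assuming the schema is imported from the published @roo-code/types package (import path illustrative); the endpoint and key values are examples, not part of the commit:

	import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

	// Exercises the new "openai-compatible" branch of the discriminated union.
	const settings = providerSettingsSchemaDiscriminated.parse({
		apiProvider: "openai-compatible",
		openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
		openAiCompatibleApiKey: "nvapi-example-key", // illustrative placeholder
		apiModelId: "minimaxai/minimax-m2",
	})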

src/api/index.ts

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,7 @@ import {
 	VertexHandler,
 	AnthropicVertexHandler,
 	OpenAiHandler,
+	OpenAiCompatibleHandler,
 	LmStudioHandler,
 	GeminiHandler,
 	OpenAiNativeHandler,
@@ -168,6 +169,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new VercelAiGatewayHandler(options)
 		case "minimax":
 			return new MiniMaxHandler(options)
+		case "openai-compatible":
+			return new OpenAiCompatibleHandler(options)
 		default:
 			apiProvider satisfies "gemini-cli" | undefined
 			return new AnthropicHandler(options)
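
With the new switch case, selecting the provider is a matter of passing matching settings to buildApiHandler. A hedged usage sketch; the values are illustrative, and the `as any` cast mirrors the tests below rather than the real ProviderSettings typing:

	import { buildApiHandler } from "../api"

	// Routes to OpenAiCompatibleHandler via the new "openai-compatible" case.
	const handler = buildApiHandler({
		apiProvider: "openai-compatible",
		openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
		openAiCompatibleApiKey: process.env.NVIDIA_API_KEY,
		apiModelId: "minimaxai/minimax-m2",
	} as any)
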
Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import OpenAI from "openai"
+
+import { OpenAiCompatibleHandler } from "../openai-compatible"
+
+vi.mock("openai")
+
+describe("OpenAiCompatibleHandler", () => {
+	let handler: OpenAiCompatibleHandler
+	let mockCreate: any
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockCreate = vi.fn()
+		;(OpenAI as any).mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		}))
+	})
+
+	describe("initialization", () => {
+		it("should create handler with valid configuration", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+
+			expect(handler).toBeInstanceOf(OpenAiCompatibleHandler)
+			expect(OpenAI).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://integrate.api.nvidia.com/v1",
+					apiKey: "test-api-key",
+				}),
+			)
+		})
+
+		it("should throw error when base URL is missing", () => {
+			expect(
+				() =>
+					new OpenAiCompatibleHandler({
+						openAiCompatibleApiKey: "test-api-key",
+					} as any),
+			).toThrow("OpenAI-compatible base URL is required")
+		})
+
+		it("should throw error when API key is missing", () => {
+			expect(
+				() =>
+					new OpenAiCompatibleHandler({
+						openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+					} as any),
+			).toThrow("OpenAI-compatible API key is required")
+		})
+
+		it("should use fallback properties if openAiCompatible ones are not present", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+
+			expect(handler).toBeInstanceOf(OpenAiCompatibleHandler)
+			expect(OpenAI).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://integrate.api.nvidia.com/v1",
+					apiKey: "test-api-key",
+				}),
+			)
+		})
+
+		it("should use default model when apiModelId is not provided", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+			} as any)
+
+			const model = handler.getModel()
+			expect(model.id).toBe("default")
+		})
+
+		it("should support NVIDIA API with MiniMax model", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "nvapi-test-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+
+			const model = handler.getModel()
+			expect(model.id).toBe("minimaxai/minimax-m2")
+			expect(model.info.maxTokens).toBe(128000)
+			expect(model.info.contextWindow).toBe(128000)
+		})
+
+		it("should support any custom OpenAI-compatible endpoint", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://custom.api.example.com/v1",
+				openAiCompatibleApiKey: "custom-api-key",
+				apiModelId: "custom-model",
+			} as any)
+
+			const model = handler.getModel()
+			expect(model.id).toBe("custom-model")
+		})
+	})
+
+	describe("getModel", () => {
+		beforeEach(() => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+		})
+
+		it("should return correct model info", () => {
+			const model = handler.getModel()
+			expect(model.id).toBe("minimaxai/minimax-m2")
+			expect(model.info).toMatchObject({
+				maxTokens: 128000,
+				contextWindow: 128000,
+				supportsPromptCache: false,
+				supportsImages: false,
+			})
+		})
+	})
+
+	describe("createMessage", () => {
+		beforeEach(() => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+		})
+
+		it("should create streaming request with correct parameters", async () => {
+			const mockStream = (async function* () {
+				yield {
+					choices: [
+						{
+							delta: { content: "Test response" },
+						},
+					],
+				}
+				yield {
+					choices: [
+						{
+							delta: {},
+						},
+					],
+					usage: {
+						prompt_tokens: 10,
+						completion_tokens: 5,
+					},
+				}
+			})()
+
+			mockCreate.mockReturnValue(mockStream)
+
+			const systemPrompt = "You are a helpful assistant."
+			const messages = [
+				{
+					role: "user" as const,
+					content: "Hello, can you help me?",
+				},
+			]
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "minimaxai/minimax-m2",
+					messages: [
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "Hello, can you help me?" },
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					temperature: 0.7,
+					max_tokens: 25600, // 20% of context window (128000)
+				}),
+				undefined,
+			)
+
+			expect(chunks).toContainEqual(
+				expect.objectContaining({
+					type: "text",
+					text: "Test response",
+				}),
+			)
+
+			expect(chunks).toContainEqual(
+				expect.objectContaining({
+					type: "usage",
+					inputTokens: 10,
+					outputTokens: 5,
+				}),
+			)
+		})
+	})
+
+	describe("completePrompt", () => {
+		beforeEach(() => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+		})
+
+		it("should complete prompt correctly", async () => {
+			const mockResponse = {
+				choices: [
+					{
+						message: { content: "Completed response" },
+					},
+				],
+			}
+
+			mockCreate.mockResolvedValue(mockResponse)
+
+			const result = await handler.completePrompt("Complete this: Hello")
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "minimaxai/minimax-m2",
+					messages: [{ role: "user", content: "Complete this: Hello" }],
+				}),
+			)
+
+			expect(result).toBe("Completed response")
+		})
+
+		it("should return empty string when no content", async () => {
+			const mockResponse = {
+				choices: [
+					{
+						message: { content: null },
+					},
+				],
+			}
+
+			mockCreate.mockResolvedValue(mockResponse)
+
+			const result = await handler.completePrompt("Complete this")
+
+			expect(result).toBe("")
+		})
+	})
+})
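
The handler implementation itself (the remaining file of the five; its diff is not rendered here) is pinned down by the tests above: credential fallback, required-field errors, and a "default" model id. A minimal sketch of that behavior, assuming the option names used in the tests; the shipped openai-compatible.ts may differ:

	import OpenAI from "openai"

	// Sketch only: reconstructed from the test expectations, not the real file.
	class OpenAiCompatibleHandlerSketch {
		private client: OpenAI

		constructor(private readonly options: Record<string, any>) {
			// Dedicated fields win; otherwise fall back to the generic OpenAI settings.
			const baseURL = options.openAiCompatibleBaseUrl ?? options.openAiBaseUrl
			const apiKey = options.openAiCompatibleApiKey ?? options.openAiApiKey
			if (!baseURL) throw new Error("OpenAI-compatible base URL is required")
			if (!apiKey) throw new Error("OpenAI-compatible API key is required")
			this.client = new OpenAI({ baseURL, apiKey })
		}

		getModel() {
			// Tests expect the id "default" when no apiModelId is configured.
			return { id: this.options.apiModelId ?? "default" }
		}
	}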

src/api/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ export { MistralHandler } from "./mistral"
 export { OllamaHandler } from "./ollama"
 export { OpenAiNativeHandler } from "./openai-native"
 export { OpenAiHandler } from "./openai"
+export { OpenAiCompatibleHandler } from "./openai-compatible"
 export { OpenRouterHandler } from "./openrouter"
 export { QwenCodeHandler } from "./qwen-code"
 export { RequestyHandler } from "./requesty"
