packages/types/src/provider-settings.ts (10 additions & 1 deletion)
@@ -132,6 +132,7 @@ export const providerNames = [
 	"mistral",
 	"moonshot",
 	"minimax",
+	"openai-compatible",
 	"openai-native",
 	"qwen-code",
 	"roo",
@@ -420,6 +421,11 @@ const vercelAiGatewaySchema = baseProviderSettingsSchema.extend({
 	vercelAiGatewayModelId: z.string().optional(),
 })
 
+const openAiCompatibleSchema = apiModelIdProviderModelSchema.extend({
+	openAiCompatibleBaseUrl: z.string().optional(),
+	openAiCompatibleApiKey: z.string().optional(),
+})
+
 const defaultSchema = z.object({
 	apiProvider: z.undefined(),
 })
@@ -462,6 +468,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	qwenCodeSchema.merge(z.object({ apiProvider: z.literal("qwen-code") })),
 	rooSchema.merge(z.object({ apiProvider: z.literal("roo") })),
 	vercelAiGatewaySchema.merge(z.object({ apiProvider: z.literal("vercel-ai-gateway") })),
+	openAiCompatibleSchema.merge(z.object({ apiProvider: z.literal("openai-compatible") })),
 	defaultSchema,
 ])

@@ -504,6 +511,7 @@ export const providerSettingsSchema = z.object({
 	...qwenCodeSchema.shape,
 	...rooSchema.shape,
 	...vercelAiGatewaySchema.shape,
+	...openAiCompatibleSchema.shape,
 	...codebaseIndexProviderSchema.shape,
 })

@@ -590,6 +598,7 @@ export const modelIdKeysByProvider: Record<TypicalProvider, ModelIdKey> = {
 	"io-intelligence": "ioIntelligenceModelId",
 	roo: "apiModelId",
 	"vercel-ai-gateway": "vercelAiGatewayModelId",
+	"openai-compatible": "apiModelId",
 }
 
 /**
@@ -626,7 +635,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str
  */
 
 export const MODELS_BY_PROVIDER: Record<
-	Exclude<ProviderName, "fake-ai" | "human-relay" | "gemini-cli" | "openai">,
+	Exclude<ProviderName, "fake-ai" | "human-relay" | "gemini-cli" | "openai" | "openai-compatible">,
 	{ id: ProviderName; label: string; models: string[] }
 > = {
 	anthropic: {
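For reference, here is what a configuration accepted by the new schema looks like once these changes land. This is a minimal sketch, assuming the schema is consumed from the types package (the @roo-code/types import path is an assumption; only provider-settings.ts itself appears in this diff):

import { providerSettingsSchemaDiscriminated } from "@roo-code/types" // assumed import path

// Validates because openAiCompatibleSchema now participates in the
// discriminated union under apiProvider: "openai-compatible".
const settings = providerSettingsSchemaDiscriminated.parse({
	apiProvider: "openai-compatible",
	openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
	openAiCompatibleApiKey: "example-api-key",
	apiModelId: "minimaxai/minimax-m2", // field comes from apiModelIdProviderModelSchema
})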
src/api/index.ts (3 additions & 0 deletions)
@@ -13,6 +13,7 @@ import {
 	VertexHandler,
 	AnthropicVertexHandler,
 	OpenAiHandler,
+	OpenAiCompatibleHandler,
 	LmStudioHandler,
 	GeminiHandler,
 	OpenAiNativeHandler,
@@ -168,6 +169,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new VercelAiGatewayHandler(options)
 		case "minimax":
 			return new MiniMaxHandler(options)
+		case "openai-compatible":
+			return new OpenAiCompatibleHandler(options)
 		default:
 			apiProvider satisfies "gemini-cli" | undefined
 			return new AnthropicHandler(options)
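Downstream, the only dispatch change is the new case above. A minimal usage sketch (both import paths are assumptions; the example values mirror the tests below):

import type { ProviderSettings } from "@roo-code/types" // assumed import path
import { buildApiHandler } from "../api" // assumed relative path

// With apiProvider set to "openai-compatible", buildApiHandler returns
// an OpenAiCompatibleHandler wired to the configured endpoint.
const config: ProviderSettings = {
	apiProvider: "openai-compatible",
	openAiCompatibleBaseUrl: "https://custom.api.example.com/v1",
	openAiCompatibleApiKey: "custom-api-key",
	apiModelId: "custom-model",
}
const handler = buildApiHandler(config)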
src/api/providers/__tests__/openai-compatible.spec.ts (259 additions & 0 deletions)
@@ -0,0 +1,259 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import OpenAI from "openai"

import { OpenAiCompatibleHandler } from "../openai-compatible"

vi.mock("openai")

describe("OpenAiCompatibleHandler", () => {
let handler: OpenAiCompatibleHandler
let mockCreate: any

beforeEach(() => {
vi.clearAllMocks()
mockCreate = vi.fn()
;(OpenAI as any).mockImplementation(() => ({
chat: {
completions: {
create: mockCreate,
},
},
}))
})

describe("initialization", () => {
it("should create handler with valid configuration", () => {
handler = new OpenAiCompatibleHandler({
openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
openAiCompatibleApiKey: "test-api-key",
apiModelId: "minimaxai/minimax-m2",
} as any)

expect(handler).toBeInstanceOf(OpenAiCompatibleHandler)
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: "https://integrate.api.nvidia.com/v1",
apiKey: "test-api-key",
}),
)
})

it("should throw error when base URL is missing", () => {
expect(
() =>
new OpenAiCompatibleHandler({
openAiCompatibleApiKey: "test-api-key",
} as any),
).toThrow("OpenAI-compatible base URL is required")
})

it("should throw error when API key is missing", () => {
expect(
() =>
new OpenAiCompatibleHandler({
openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
} as any),
).toThrow("OpenAI-compatible API key is required")
})

it("should use fallback properties if openAiCompatible ones are not present", () => {
handler = new OpenAiCompatibleHandler({
openAiBaseUrl: "https://integrate.api.nvidia.com/v1",
openAiApiKey: "test-api-key",
apiModelId: "minimaxai/minimax-m2",
} as any)

expect(handler).toBeInstanceOf(OpenAiCompatibleHandler)
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: "https://integrate.api.nvidia.com/v1",
apiKey: "test-api-key",
}),
)
})

it("should use default model when apiModelId is not provided", () => {
handler = new OpenAiCompatibleHandler({
openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
openAiCompatibleApiKey: "test-api-key",
} as any)

const model = handler.getModel()
expect(model.id).toBe("default")
})

it("should support NVIDIA API with MiniMax model", () => {
handler = new OpenAiCompatibleHandler({
openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
openAiCompatibleApiKey: "nvapi-test-key",
apiModelId: "minimaxai/minimax-m2",
} as any)

const model = handler.getModel()
expect(model.id).toBe("minimaxai/minimax-m2")
expect(model.info.maxTokens).toBe(128000)
expect(model.info.contextWindow).toBe(128000)
})

it("should support any custom OpenAI-compatible endpoint", () => {
handler = new OpenAiCompatibleHandler({
openAiCompatibleBaseUrl: "https://custom.api.example.com/v1",
openAiCompatibleApiKey: "custom-api-key",
apiModelId: "custom-model",
} as any)

const model = handler.getModel()
expect(model.id).toBe("custom-model")
})
})

describe("getModel", () => {
beforeEach(() => {
handler = new OpenAiCompatibleHandler({
openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
openAiCompatibleApiKey: "test-api-key",
apiModelId: "minimaxai/minimax-m2",
} as any)
})

it("should return correct model info", () => {
const model = handler.getModel()
expect(model.id).toBe("minimaxai/minimax-m2")
expect(model.info).toMatchObject({
maxTokens: 128000,
contextWindow: 128000,
supportsPromptCache: false,
supportsImages: false,
})
})
})

describe("createMessage", () => {
beforeEach(() => {
handler = new OpenAiCompatibleHandler({
openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
openAiCompatibleApiKey: "test-api-key",
apiModelId: "minimaxai/minimax-m2",
} as any)
})

it("should create streaming request with correct parameters", async () => {
const mockStream = (async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
},
],
}
yield {
choices: [
{
delta: {},
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
}
})()

mockCreate.mockReturnValue(mockStream)

const systemPrompt = "You are a helpful assistant."
const messages = [
{
role: "user" as const,
content: "Hello, can you help me?",
},
]

const stream = handler.createMessage(systemPrompt, messages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "minimaxai/minimax-m2",
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello, can you help me?" },
],
stream: true,
stream_options: { include_usage: true },
temperature: 0.7,
max_tokens: 25600, // 20% of context window (128000)
}),
undefined,
)

expect(chunks).toContainEqual(
expect.objectContaining({
type: "text",
text: "Test response",
}),
)

expect(chunks).toContainEqual(
expect.objectContaining({
type: "usage",
inputTokens: 10,
outputTokens: 5,
}),
)
})
})

describe("completePrompt", () => {
beforeEach(() => {
handler = new OpenAiCompatibleHandler({
openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
openAiCompatibleApiKey: "test-api-key",
apiModelId: "minimaxai/minimax-m2",
} as any)
})

it("should complete prompt correctly", async () => {
const mockResponse = {
choices: [
{
message: { content: "Completed response" },
},
],
}

mockCreate.mockResolvedValue(mockResponse)

const result = await handler.completePrompt("Complete this: Hello")

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "minimaxai/minimax-m2",
messages: [{ role: "user", content: "Complete this: Hello" }],
}),
)

expect(result).toBe("Completed response")
})

it("should return empty string when no content", async () => {
const mockResponse = {
choices: [
{
message: { content: null },
},
],
}

mockCreate.mockResolvedValue(mockResponse)

const result = await handler.completePrompt("Complete this")

expect(result).toBe("")
})
})
})
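The handler implementation (src/api/providers/openai-compatible.ts) is not part of this view. As a hypothetical sketch only, inferred from the assertions in the spec above rather than copied from the PR, its constructor and getModel would behave roughly like this:

import OpenAI from "openai"

export class OpenAiCompatibleHandler {
	private client: OpenAI
	private modelId: string

	constructor(options: Record<string, any>) {
		// Fall back to the generic openAi* settings when the
		// openAiCompatible* ones are absent (see the fallback test above).
		const baseURL = options.openAiCompatibleBaseUrl ?? options.openAiBaseUrl
		const apiKey = options.openAiCompatibleApiKey ?? options.openAiApiKey
		if (!baseURL) throw new Error("OpenAI-compatible base URL is required")
		if (!apiKey) throw new Error("OpenAI-compatible API key is required")
		this.client = new OpenAI({ baseURL, apiKey })
		this.modelId = options.apiModelId ?? "default"
	}

	getModel() {
		// The spec pins generic 128k defaults regardless of model id.
		return {
			id: this.modelId,
			info: {
				maxTokens: 128000,
				contextWindow: 128000,
				supportsPromptCache: false,
				supportsImages: false,
			},
		}
	}
}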
src/api/providers/index.ts (1 addition & 0 deletions)
@@ -20,6 +20,7 @@ export { MistralHandler } from "./mistral"
 export { OllamaHandler } from "./ollama"
 export { OpenAiNativeHandler } from "./openai-native"
 export { OpenAiHandler } from "./openai"
+export { OpenAiCompatibleHandler } from "./openai-compatible"
 export { OpenRouterHandler } from "./openrouter"
 export { QwenCodeHandler } from "./qwen-code"
 export { RequestyHandler } from "./requesty"