diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index fc69948d476..5b098fb1c00 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -132,6 +132,7 @@ export const providerNames = [
 	"mistral",
 	"moonshot",
 	"minimax",
+	"openai-compatible",
 	"openai-native",
 	"qwen-code",
 	"roo",
@@ -420,6 +421,11 @@ const vercelAiGatewaySchema = baseProviderSettingsSchema.extend({
 	vercelAiGatewayModelId: z.string().optional(),
 })
 
+const openAiCompatibleSchema = apiModelIdProviderModelSchema.extend({
+	openAiCompatibleBaseUrl: z.string().optional(),
+	openAiCompatibleApiKey: z.string().optional(),
+})
+
 const defaultSchema = z.object({
 	apiProvider: z.undefined(),
 })
@@ -462,6 +468,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	qwenCodeSchema.merge(z.object({ apiProvider: z.literal("qwen-code") })),
 	rooSchema.merge(z.object({ apiProvider: z.literal("roo") })),
 	vercelAiGatewaySchema.merge(z.object({ apiProvider: z.literal("vercel-ai-gateway") })),
+	openAiCompatibleSchema.merge(z.object({ apiProvider: z.literal("openai-compatible") })),
 	defaultSchema,
 ])
 
@@ -504,6 +511,7 @@ export const providerSettingsSchema = z.object({
 	...qwenCodeSchema.shape,
 	...rooSchema.shape,
 	...vercelAiGatewaySchema.shape,
+	...openAiCompatibleSchema.shape,
 	...codebaseIndexProviderSchema.shape,
 })
 
@@ -590,6 +598,7 @@ export const modelIdKeysByProvider: Record<ProviderName, ModelIdKey> = {
 	"io-intelligence": "ioIntelligenceModelId",
 	roo: "apiModelId",
 	"vercel-ai-gateway": "vercelAiGatewayModelId",
+	"openai-compatible": "apiModelId",
 }
 
 /**
@@ -626,7 +635,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str
 
  */
 export const MODELS_BY_PROVIDER: Record<
-	Exclude<ProviderName, …>,
+	Exclude<ProviderName, … | "openai-compatible">,
 	{ id: ProviderName; label: string; models: string[] }
 > = {
 	anthropic: {
diff --git a/src/api/index.ts b/src/api/index.ts
index 351f4ef1bef..6cc57a47ebb 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -13,6 +13,7 @@ import {
 	VertexHandler,
 	AnthropicVertexHandler,
 	OpenAiHandler,
+	OpenAiCompatibleHandler,
 	LmStudioHandler,
 	GeminiHandler,
 	OpenAiNativeHandler,
@@ -168,6 +169,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new VercelAiGatewayHandler(options)
 		case "minimax":
 			return new MiniMaxHandler(options)
+		case "openai-compatible":
+			return new OpenAiCompatibleHandler(options)
 		default:
 			apiProvider satisfies "gemini-cli" | undefined
 			return new AnthropicHandler(options)
diff --git a/src/api/providers/__tests__/openai-compatible.spec.ts b/src/api/providers/__tests__/openai-compatible.spec.ts
new file mode 100644
index 00000000000..3c35a894bf5
--- /dev/null
+++ b/src/api/providers/__tests__/openai-compatible.spec.ts
@@ -0,0 +1,259 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import OpenAI from "openai"
+
+import { OpenAiCompatibleHandler } from "../openai-compatible"
+
+vi.mock("openai")
+
+describe("OpenAiCompatibleHandler", () => {
+	let handler: OpenAiCompatibleHandler
+	let mockCreate: any
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockCreate = vi.fn()
+		;(OpenAI as any).mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		}))
+	})
+
+	describe("initialization", () => {
+		it("should create handler with valid configuration", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+
+			expect(handler).toBeInstanceOf(OpenAiCompatibleHandler)
+			expect(OpenAI).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://integrate.api.nvidia.com/v1",
+					apiKey: "test-api-key",
+				}),
+			)
+		})
+
+		it("should throw error when base URL is missing", () => {
+			expect(
+				() =>
+					new OpenAiCompatibleHandler({
+						openAiCompatibleApiKey: "test-api-key",
+					} as any),
+			).toThrow("OpenAI-compatible base URL is required")
+		})
+
+		it("should throw error when API key is missing", () => {
+			expect(
+				() =>
+					new OpenAiCompatibleHandler({
+						openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+					} as any),
+			).toThrow("OpenAI-compatible API key is required")
+		})
+
+		it("should use fallback properties if openAiCompatible ones are not present", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+
+			expect(handler).toBeInstanceOf(OpenAiCompatibleHandler)
+			expect(OpenAI).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://integrate.api.nvidia.com/v1",
+					apiKey: "test-api-key",
+				}),
+			)
+		})
+
+		it("should use default model when apiModelId is not provided", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+			} as any)
+
+			const model = handler.getModel()
+			expect(model.id).toBe("default")
+		})
+
+		it("should support NVIDIA API with MiniMax model", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "nvapi-test-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+
+			const model = handler.getModel()
+			expect(model.id).toBe("minimaxai/minimax-m2")
+			expect(model.info.maxTokens).toBe(128000)
+			expect(model.info.contextWindow).toBe(128000)
+		})
+
+		it("should support any custom OpenAI-compatible endpoint", () => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://custom.api.example.com/v1",
+				openAiCompatibleApiKey: "custom-api-key",
+				apiModelId: "custom-model",
+			} as any)
+
+			const model = handler.getModel()
+			expect(model.id).toBe("custom-model")
+		})
+	})
+
+	describe("getModel", () => {
+		beforeEach(() => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+		})
+
+		it("should return correct model info", () => {
+			const model = handler.getModel()
+			expect(model.id).toBe("minimaxai/minimax-m2")
+			expect(model.info).toMatchObject({
+				maxTokens: 128000,
+				contextWindow: 128000,
+				supportsPromptCache: false,
+				supportsImages: false,
+			})
+		})
+	})
+
+	describe("createMessage", () => {
+		beforeEach(() => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+		})
+
+		it("should create streaming request with correct parameters", async () => {
+			const mockStream = (async function* () {
+				yield {
+					choices: [
+						{
+							delta: { content: "Test response" },
+						},
+					],
+				}
+				yield {
+					choices: [
+						{
+							delta: {},
+						},
+					],
+					usage: {
+						prompt_tokens: 10,
+						completion_tokens: 5,
+					},
+				}
+			})()
+
+			mockCreate.mockReturnValue(mockStream)
+
+			const systemPrompt = "You are a helpful assistant."
+			const messages = [
+				{
+					role: "user" as const,
+					content: "Hello, can you help me?",
+				},
+			]
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "minimaxai/minimax-m2",
+					messages: [
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "Hello, can you help me?" },
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					temperature: 0.7,
+					max_tokens: 25600, // 20% of context window (0.2 × 128000 = 25600)
+				}),
+				undefined,
+			)
+
+			expect(chunks).toContainEqual(
+				expect.objectContaining({
+					type: "text",
+					text: "Test response",
+				}),
+			)
+
+			expect(chunks).toContainEqual(
+				expect.objectContaining({
+					type: "usage",
+					inputTokens: 10,
+					outputTokens: 5,
+				}),
+			)
+		})
+	})
+
+	describe("completePrompt", () => {
+		beforeEach(() => {
+			handler = new OpenAiCompatibleHandler({
+				openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1",
+				openAiCompatibleApiKey: "test-api-key",
+				apiModelId: "minimaxai/minimax-m2",
+			} as any)
+		})
+
+		it("should complete prompt correctly", async () => {
+			const mockResponse = {
+				choices: [
+					{
+						message: { content: "Completed response" },
+					},
+				],
+			}
+
+			mockCreate.mockResolvedValue(mockResponse)
+
+			const result = await handler.completePrompt("Complete this: Hello")
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "minimaxai/minimax-m2",
+					messages: [{ role: "user", content: "Complete this: Hello" }],
+				}),
+			)
+
+			expect(result).toBe("Completed response")
+		})
+
+		it("should return empty string when no content", async () => {
+			const mockResponse = {
+				choices: [
+					{
+						message: { content: null },
+					},
+				],
+			}
+
+			mockCreate.mockResolvedValue(mockResponse)
+
+			const result = await handler.completePrompt("Complete this")
+
+			expect(result).toBe("")
+		})
+	})
+})
diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts
index 533023d0374..27a78dabd58 100644
--- a/src/api/providers/index.ts
+++ b/src/api/providers/index.ts
@@ -20,6 +20,7 @@ export { MistralHandler } from "./mistral"
 export { OllamaHandler } from "./ollama"
 export { OpenAiNativeHandler } from "./openai-native"
 export { OpenAiHandler } from "./openai"
+export { OpenAiCompatibleHandler } from "./openai-compatible"
 export { OpenRouterHandler } from "./openrouter"
 export { QwenCodeHandler } from "./qwen-code"
 export { RequestyHandler } from "./requesty"
diff --git a/src/api/providers/openai-compatible.ts b/src/api/providers/openai-compatible.ts
new file mode 100644
index 00000000000..0625aa323af
--- /dev/null
+++ b/src/api/providers/openai-compatible.ts
@@ -0,0 +1,64 @@
+import type { ApiHandlerOptions } from "../../shared/api"
+import type { ModelInfo } from "@roo-code/types"
+
+import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+
+// Default model configuration for OpenAI-compatible APIs
+const DEFAULT_OPENAI_COMPATIBLE_MODEL: ModelInfo = {
+	maxTokens: 128000,
+	contextWindow: 128000,
+	supportsPromptCache: false,
+	supportsImages: false,
+	supportsReasoningEffort: false,
+	supportsReasoningBinary: false,
+	inputPrice: 0,
+	outputPrice: 0,
+}
+
+// Support any model ID as string for maximum flexibility
+export class OpenAiCompatibleHandler extends BaseOpenAiCompatibleProvider<string> {
+	constructor(options: ApiHandlerOptions) {
+		// Since ApiHandlerOptions doesn't have openAiCompatibleBaseUrl/ApiKey yet,
+		// fall back to openAiBaseUrl and openAiApiKey as a workaround.
+		// This will be fixed properly once the dedicated option types are added.
+		const baseURL = (options as any).openAiCompatibleBaseUrl || options.openAiBaseUrl
+		const apiKey = (options as any).openAiCompatibleApiKey || options.openAiApiKey
+
+		if (!baseURL) {
+			throw new Error("OpenAI-compatible base URL is required")
+		}
+
+		if (!apiKey) {
+			throw new Error("OpenAI-compatible API key is required")
+		}
+
+		// Use the model ID provided or default to a generic one
+		const modelId = options.apiModelId || "default"
+
+		// Create a models object with the single model
+		const providerModels: Record<string, ModelInfo> = {
+			[modelId]: DEFAULT_OPENAI_COMPATIBLE_MODEL,
+		}
+
+		super({
+			...options,
+			providerName: "OpenAI Compatible",
+			baseURL,
+			apiKey,
+			defaultProviderModelId: modelId,
+			providerModels,
+			defaultTemperature: 0.7,
+		})
+	}
+
+	override getModel() {
+		// For OpenAI-compatible APIs, we allow any model ID
+		// and use default configuration if not known
+		const modelId = this.options.apiModelId || "default"
+
+		return {
+			id: modelId,
+			info: DEFAULT_OPENAI_COMPATIBLE_MODEL,
+		}
+	}
+}
diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
index 296b262c373..0836e2ffd9b 100644
--- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts
+++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
@@ -367,6 +367,18 @@ function getSelectedModel({
 			const info = routerModels["vercel-ai-gateway"]?.[id]
 			return { id, info }
 		}
+		case "openai-compatible": {
+			// OpenAI Compatible provider uses custom configuration
+			const id = apiConfiguration.apiModelId ?? "default"
+			// Basic model info for OpenAI compatible APIs
+			const info: ModelInfo = {
+				maxTokens: 128000,
+				contextWindow: 128000,
+				supportsPromptCache: false,
+				supportsImages: false,
+			}
+			return { id, info }
+		}
 		// case "anthropic":
 		// case "human-relay":
 		// case "fake-ai":
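
Note for reviewers: a minimal sketch of how the pieces above are expected to fit together once merged. The settings literal, the `MY_API_KEY` environment variable, and the import paths are illustrative assumptions; only `buildApiHandler`, the `openAiCompatible*` settings fields, and the `text`/`usage` chunk shapes come from the patch itself.

```typescript
import { buildApiHandler } from "../src/api"
import type { ProviderSettings } from "@roo-code/types"

async function main() {
	// Hypothetical configuration; field names match openAiCompatibleSchema above.
	const settings: ProviderSettings = {
		apiProvider: "openai-compatible",
		openAiCompatibleBaseUrl: "https://integrate.api.nvidia.com/v1", // any OpenAI-compatible endpoint
		openAiCompatibleApiKey: process.env.MY_API_KEY ?? "",
		apiModelId: "minimaxai/minimax-m2",
	}

	// buildApiHandler now routes "openai-compatible" to OpenAiCompatibleHandler.
	const handler = buildApiHandler(settings)

	// createMessage yields "text" and "usage" chunks, as the tests above assert.
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hello" },
	])) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log(`\n${chunk.inputTokens} in / ${chunk.outputTokens} out`)
	}
}

main().catch(console.error)
```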
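A quick check of the schema side, assuming `providerSettingsSchemaDiscriminated` is re-exported from the package root like the other schemas in provider-settings.ts; the parsed values are illustrative:

```typescript
import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

// "openai-compatible" is now a legal discriminant, so this parses...
const ok = providerSettingsSchemaDiscriminated.parse({
	apiProvider: "openai-compatible",
	openAiCompatibleBaseUrl: "https://custom.api.example.com/v1",
	openAiCompatibleApiKey: "custom-api-key",
	apiModelId: "custom-model",
})

// ...while an unknown provider name still fails fast.
const bad = providerSettingsSchemaDiscriminated.safeParse({ apiProvider: "no-such-provider" })

console.log(ok.apiProvider, bad.success) // "openai-compatible" false
```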