diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index a66aae08a24..ceceff119d5 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -88,7 +88,7 @@ export const isInternalProvider = (key: string): key is InternalProvider =>
  * Custom providers are completely configurable within Roo Code settings.
  */
 
-export const customProviders = ["openai"] as const
+export const customProviders = ["openai", "openai-compatible"] as const
 
 export type CustomProvider = (typeof customProviders)[number]
 
@@ -138,6 +138,7 @@ export const providerNames = [
 	"vertex",
 	"xai",
 	"zai",
+	"openai-compatible",
 ] as const
 
 export const providerNamesSchema = z.enum(providerNames)
@@ -424,6 +425,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
 	bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })),
 	vertexSchema.merge(z.object({ apiProvider: z.literal("vertex") })),
 	openAiSchema.merge(z.object({ apiProvider: z.literal("openai") })),
+	openAiSchema.merge(z.object({ apiProvider: z.literal("openai-compatible") })),
 	ollamaSchema.merge(z.object({ apiProvider: z.literal("ollama") })),
 	vsCodeLmSchema.merge(z.object({ apiProvider: z.literal("vscode-lm") })),
 	lmStudioSchema.merge(z.object({ apiProvider: z.literal("lmstudio") })),
@@ -610,7 +612,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: string) => {
  */
 
 export const MODELS_BY_PROVIDER: Record<
-	Exclude<ProviderName, "fake-ai" | "human-relay">,
+	Exclude<ProviderName, "fake-ai" | "human-relay" | "openai-compatible">,
 	{ id: ProviderName; label: string; models: string[] }
 > = {
 	anthropic: {
diff --git a/src/api/index.ts b/src/api/index.ts
index ac009676762..393dea74dba 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -106,6 +106,7 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 				? new AnthropicVertexHandler(options)
 				: new VertexHandler(options)
 		case "openai":
+		case "openai-compatible":
 			return new OpenAiHandler(options)
 		case "ollama":
 			return new NativeOllamaHandler(options)
diff --git a/src/api/providers/__tests__/openai-compatible.spec.ts b/src/api/providers/__tests__/openai-compatible.spec.ts
new file mode 100644
index 00000000000..e338fde1923
--- /dev/null
+++ b/src/api/providers/__tests__/openai-compatible.spec.ts
@@ -0,0 +1,140 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import { buildApiHandler } from "../../index"
+import { OpenAiHandler } from "../openai"
+
+vi.mock("openai", () => {
+	const mockCreate = vi.fn()
+	return {
+		default: vi.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		})),
+		OpenAI: vi.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		})),
+		AzureOpenAI: vi.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		})),
+	}
+})
+
+describe("OpenAI Compatible Provider", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+	})
+
+	it("should create OpenAiHandler when apiProvider is 'openai-compatible'", () => {
+		const handler = buildApiHandler({
+			apiProvider: "openai-compatible",
+			openAiApiKey: "test-key",
+			openAiBaseUrl: "https://api.example.com/v1",
+			openAiModelId: "test-model",
+		})
+
+		expect(handler).toBeInstanceOf(OpenAiHandler)
+	})
+
+	it("should handle token usage correctly for openai-compatible provider", async () => {
+		const mockStream = {
+			async *[Symbol.asyncIterator]() {
+				yield {
+					choices: [{ delta: { content: "Hello" } }],
+				}
+				yield {
+					choices: [{ delta: { content: " world" } }],
+				}
+				yield {
+					choices: [{ delta: {} }],
+					usage: {
+						prompt_tokens: 10,
+						completion_tokens: 5,
+						total_tokens: 15,
+					},
+				}
+			},
+		}
+
+		const OpenAI = (await import("openai")).default
+		const mockCreate = vi.fn().mockResolvedValue(mockStream)
+		;(OpenAI as any).mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		}))
+
+		const handler = buildApiHandler({
+			apiProvider: "openai-compatible",
+			openAiApiKey: "test-key",
+			openAiBaseUrl: "https://api.example.com/v1",
+			openAiModelId: "test-model",
+		})
+
+		const messages = [{ role: "user" as const, content: "Test message" }]
+		const stream = handler.createMessage("System prompt", messages)
+
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		// Check that we got text chunks
+		const textChunks = chunks.filter((c) => c.type === "text")
+		expect(textChunks).toHaveLength(2)
+		expect(textChunks[0].text).toBe("Hello")
+		expect(textChunks[1].text).toBe(" world")
+
+		// Check that we got usage data
+		const usageChunk = chunks.find((c) => c.type === "usage")
+		expect(usageChunk).toBeDefined()
+		expect(usageChunk).toEqual({
+			type: "usage",
+			inputTokens: 10,
+			outputTokens: 5,
+		})
+	})
+
+	it("should use the same configuration as openai provider", () => {
+		const config = {
+			openAiApiKey: "test-key",
+			openAiBaseUrl: "https://api.example.com/v1",
+			openAiModelId: "test-model",
+			openAiCustomModelInfo: {
+				maxTokens: 4096,
+				contextWindow: 8192,
+				supportsPromptCache: false,
+				inputPrice: 0.001,
+				outputPrice: 0.002,
+			},
+		}
+
+		const openaiHandler = buildApiHandler({
+			apiProvider: "openai",
+			...config,
+		})
+
+		const openaiCompatibleHandler = buildApiHandler({
+			apiProvider: "openai-compatible",
+			...config,
+		})
+
+		// Both should be instances of OpenAiHandler
+		expect(openaiHandler).toBeInstanceOf(OpenAiHandler)
+		expect(openaiCompatibleHandler).toBeInstanceOf(OpenAiHandler)
+
+		// Both should have the same model configuration
+		expect(openaiHandler.getModel()).toEqual(openaiCompatibleHandler.getModel())
+	})
+})
diff --git a/src/shared/ProfileValidator.ts b/src/shared/ProfileValidator.ts
index 78ff6ed9fe1..8e699b94fce 100644
--- a/src/shared/ProfileValidator.ts
+++ b/src/shared/ProfileValidator.ts
@@ -56,6 +56,7 @@ export class ProfileValidator {
 	private static getModelIdFromProfile(profile: ProviderSettings): string | undefined {
 		switch (profile.apiProvider) {
 			case "openai":
+			case "openai-compatible":
 				return profile.openAiModelId
 			case "anthropic":
 			case "openai-native":
diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
index 0d0514b4d66..d176f2738ae 100644
--- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts
+++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
@@ -255,7 +255,8 @@ function getSelectedModel({
 			const info = mistralModels[id as keyof typeof mistralModels]
 			return { id, info }
 		}
-		case "openai": {
+		case "openai":
+		case "openai-compatible": {
 			const id = apiConfiguration.openAiModelId ?? ""
 			const info = apiConfiguration?.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults
 			return { id, info }
@@ -360,7 +361,7 @@ function getSelectedModel({
 		// case "human-relay":
 		// case "fake-ai":
 		default: {
-			provider satisfies "anthropic" | "gemini-cli" | "qwen-code" | "human-relay" | "fake-ai"
+			provider satisfies "anthropic" | "gemini-cli" | "human-relay" | "fake-ai"
 			const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId
 			const baseInfo = anthropicModels[id as keyof typeof anthropicModels]