From ec183802e63bf36175a35256b278b07852f279e4 Mon Sep 17 00:00:00 2001 From: Capcy Date: Wed, 9 Apr 2025 21:43:36 -0500 Subject: [PATCH 1/2] feat(custom-openai-compatible-api-provider): extended api provider to connect roo to openai compatible apis that have non standard configs ustom auth headers instead of `authorization` custom or empty auth header prefixes and url structures where model is included in the path --- src/api/index.ts | 3 + .../providers/__tests__/custom-openai.test.ts | 88 +++++++ src/api/providers/custom-openai.ts | 227 ++++++++++++++++++ src/exports/api.ts | 2 +- src/exports/roo-code.d.ts | 13 + src/exports/types.ts | 13 + src/schemas/index.ts | 16 ++ .../src/components/settings/ApiOptions.tsx | 101 ++++++++ .../src/components/settings/constants.ts | 1 + 9 files changed, 463 insertions(+), 1 deletion(-) create mode 100644 src/api/providers/__tests__/custom-openai.test.ts create mode 100644 src/api/providers/custom-openai.ts diff --git a/src/api/index.ts b/src/api/index.ts index 0880f422182..12dc03df4c3 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -21,6 +21,7 @@ import { UnboundHandler } from "./providers/unbound" import { RequestyHandler } from "./providers/requesty" import { HumanRelayHandler } from "./providers/human-relay" import { FakeAIHandler } from "./providers/fake-ai" +import { CustomOpenAiHandler } from "./providers/custom-openai" // Import the new handler export interface SingleCompletionHandler { completePrompt(prompt: string): Promise @@ -56,6 +57,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler { return new VertexHandler(options) case "openai": return new OpenAiHandler(options) + case "custom-openai": // Add case for the new handler + return new CustomOpenAiHandler(options) case "ollama": return new OllamaHandler(options) case "lmstudio": diff --git a/src/api/providers/__tests__/custom-openai.test.ts b/src/api/providers/__tests__/custom-openai.test.ts new file mode 100644 index 00000000000..01629aa67df --- /dev/null +++ b/src/api/providers/__tests__/custom-openai.test.ts @@ -0,0 +1,88 @@ +import { CustomOpenAiHandler } from "../custom-openai" +import { openAiModelInfoSaneDefaults } from "../../../shared/api" + +describe("CustomOpenAiHandler", () => { + it("should construct with required options", () => { + const handler = new CustomOpenAiHandler({ + customBaseUrl: "https://api.example.com", + customApiKey: "test-key", + customAuthHeaderName: "X-API-Key", + customAuthHeaderPrefix: "", + }) + + expect(handler).toBeDefined() + }) + + it("should throw error if customBaseUrl is not provided", () => { + expect(() => { + new CustomOpenAiHandler({ + customApiKey: "test-key", + }) + }).toThrow("Custom OpenAI provider requires 'customBaseUrl' to be set.") + }) + + it("should use model in path when useModelInPath is true", async () => { + const handler = new CustomOpenAiHandler({ + customBaseUrl: "https://api.example.com", + customApiKey: "test-key", + useModelInPath: true, + customPathPrefix: "/api/v1/chat/", + openAiModelId: "gpt-3.5-turbo", + openAiCustomModelInfo: openAiModelInfoSaneDefaults, + }) + + // Mock the client.post method + const mockPost = jest.fn().mockResolvedValue({ + data: { + choices: [{ message: { content: "Test response" } }], + usage: { prompt_tokens: 10, completion_tokens: 20 }, + }, + }) + + // @ts-ignore - Replace the client with our mock + handler.client = { post: mockPost } + + // Call createMessage to trigger the endpoint construction + const stream = handler.createMessage("Test system prompt", 
[{ role: "user", content: "Test message" }]) + + // Consume the stream to ensure the post method is called + for await (const _ of stream) { + // Just consume the stream + } + + // Verify the endpoint used in the post call + expect(mockPost).toHaveBeenCalledWith("/api/v1/chat/gpt-3.5-turbo", expect.any(Object), expect.any(Object)) + }) + + it("should use standard endpoint when useModelInPath is false", async () => { + const handler = new CustomOpenAiHandler({ + customBaseUrl: "https://api.example.com", + customApiKey: "test-key", + useModelInPath: false, + openAiModelId: "gpt-3.5-turbo", + openAiCustomModelInfo: openAiModelInfoSaneDefaults, + }) + + // Mock the client.post method + const mockPost = jest.fn().mockResolvedValue({ + data: { + choices: [{ message: { content: "Test response" } }], + usage: { prompt_tokens: 10, completion_tokens: 20 }, + }, + }) + + // @ts-ignore - Replace the client with our mock + handler.client = { post: mockPost } + + // Call createMessage to trigger the endpoint construction + const stream = handler.createMessage("Test system prompt", [{ role: "user", content: "Test message" }]) + + // Consume the stream to ensure the post method is called + for await (const _ of stream) { + // Just consume the stream + } + + // Verify the endpoint used in the post call + expect(mockPost).toHaveBeenCalledWith("/chat/completions", expect.any(Object), expect.any(Object)) + }) +}) diff --git a/src/api/providers/custom-openai.ts b/src/api/providers/custom-openai.ts new file mode 100644 index 00000000000..963417910d4 --- /dev/null +++ b/src/api/providers/custom-openai.ts @@ -0,0 +1,227 @@ +// src/api/providers/custom-openai.ts +import { Anthropic } from "@anthropic-ai/sdk" +import axios, { AxiosInstance, AxiosRequestConfig } from "axios" // Use axios for custom requests + +import { + ApiHandlerOptions, + ModelInfo, + openAiModelInfoSaneDefaults, // Use sane defaults initially +} from "../../shared/api" +import { SingleCompletionHandler } from "../index" +import { convertToOpenAiMessages } from "../transform/openai-format" // Reuse message formatting +import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream" +import { BaseProvider } from "./base-provider" +import { XmlMatcher } from "../../utils/xml-matcher" // For potential reasoning tags + +// Define specific options for the custom provider +export interface CustomOpenAiHandlerOptions extends ApiHandlerOptions { + customBaseUrl?: string + customApiKey?: string + customAuthHeaderName?: string // e.g., 'X-API-Key' + customAuthHeaderPrefix?: string // e.g., 'Bearer ' or '' + // URL path options + useModelInPath?: boolean // Whether to include model in URL path (e.g., /api/v1/chat/model-name) + customPathPrefix?: string // Custom path prefix (e.g., /api/v1/chat/) + // Potentially add other OpenAI-compatible options if needed later + modelTemperature?: number | null // Allow null to match schema + includeMaxTokens?: boolean + openAiStreamingEnabled?: boolean // Reuse existing streaming flag? + openAiModelId?: string // Reuse model ID field + openAiCustomModelInfo?: ModelInfo | null // Allow null to match schema +} + +// Default headers - maybe keep these? 
+export const defaultHeaders = { + "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", + "X-Title": "Roo Code", +} + +export class CustomOpenAiHandler extends BaseProvider implements SingleCompletionHandler { + protected options: CustomOpenAiHandlerOptions + private client: AxiosInstance // Use an axios instance + + constructor(options: CustomOpenAiHandlerOptions) { + super() + this.options = options + + const baseURL = this.options.customBaseUrl + if (!baseURL) { + throw new Error("Custom OpenAI provider requires 'customBaseUrl' to be set.") + } + if (!this.options.customApiKey) { + console.warn("Custom OpenAI provider initialized without 'customApiKey'.") + } + + // Prepare authentication header + const authHeaderName = this.options.customAuthHeaderName || "Authorization" // Default to Authorization + const authHeaderPrefix = + this.options.customAuthHeaderPrefix !== undefined ? this.options.customAuthHeaderPrefix : "Bearer " // Default to Bearer prefix + const apiKey = this.options.customApiKey || "not-provided" + const authHeaderValue = `${authHeaderPrefix}${apiKey}`.trim() // Handle empty prefix + + this.client = axios.create({ + baseURL, + headers: { + ...defaultHeaders, // Include default Roo headers + [authHeaderName]: authHeaderValue, // Add the custom auth header + "Content-Type": "application/json", + }, + }) + } + + // --- Implementation using axios --- + + override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + const modelInfo = this.getModel().info + const modelId = this.options.openAiModelId ?? "custom-model" // Get model ID from options + const streamingEnabled = this.options.openAiStreamingEnabled ?? true // Default to streaming + + // Convert messages to OpenAI format + // Need to import OpenAI types for this + const systemMessage: { role: "system"; content: string } = { + role: "system", + content: systemPrompt, + } + const convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)] + + // Construct the common payload parts + const payload: Record = { + model: modelId, + messages: convertedMessages, + temperature: this.options.modelTemperature ?? 0, // Default temperature + stream: streamingEnabled, + } + + if (streamingEnabled && modelInfo.supportsUsageStream) { + payload.stream_options = { include_usage: true } + } + + if (this.options.includeMaxTokens && modelInfo.maxTokens) { + payload.max_tokens = modelInfo.maxTokens + } + // Determine the endpoint based on configuration + let endpoint = "/chat/completions" // Default OpenAI-compatible endpoint + + // If useModelInPath is true, construct the endpoint with the model in the path + if (this.options.useModelInPath && modelId) { + const pathPrefix = this.options.customPathPrefix || "/api/v1/chat/" + endpoint = `${pathPrefix}${modelId}` + } + + try { + if (streamingEnabled) { + const response = await this.client.post(endpoint, payload, { + responseType: "stream", + }) + + const stream = response.data as NodeJS.ReadableStream + let buffer = "" + let lastUsage: any = null + const matcher = new XmlMatcher( + "think", + (chunk) => ({ type: chunk.matched ? 
"reasoning" : "text", text: chunk.data }) as const, + ) + + for await (const chunk of stream) { + buffer += chunk.toString() + + // Process buffer line by line (SSE format) + let EOL + while ((EOL = buffer.indexOf("\n")) >= 0) { + const line = buffer.substring(0, EOL).trim() + buffer = buffer.substring(EOL + 1) + + if (line.startsWith("data:")) { + const data = line.substring(5).trim() + if (data === "[DONE]") { + break // Stream finished + } + try { + const parsed = JSON.parse(data) + const delta = parsed.choices?.[0]?.delta ?? {} + + if (delta.content) { + for (const contentChunk of matcher.update(delta.content)) { + yield contentChunk + } + } + // Handle potential reasoning content if supported by the custom model + if ("reasoning_content" in delta && delta.reasoning_content) { + yield { + type: "reasoning", + text: (delta.reasoning_content as string | undefined) || "", + } + } + + if (parsed.usage) { + lastUsage = parsed.usage + } + } catch (e) { + console.error("Error parsing stream data:", e, "Data:", data) + } + } + } + } + // Yield any remaining text from the matcher + for (const contentChunk of matcher.final()) { + yield contentChunk + } + + if (lastUsage) { + yield this.processUsageMetrics(lastUsage, modelInfo) + } + } else { + // Non-streaming case + const response = await this.client.post(endpoint, payload) + const completion = response.data + + yield { + type: "text", + text: completion.choices?.[0]?.message?.content || "", + } + if (completion.usage) { + yield this.processUsageMetrics(completion.usage, modelInfo) + } + } + } catch (error: any) { + console.error("Custom OpenAI API request failed:", error) + let errorMessage = "Custom OpenAI API request failed." + if (axios.isAxiosError(error) && error.response) { + errorMessage += ` Status: ${error.response.status}. Data: ${JSON.stringify(error.response.data)}` + } else if (error instanceof Error) { + errorMessage += ` Error: ${error.message}` + } + // Yield an error chunk or throw? For now, yield text. + yield { type: "text", text: `[ERROR: ${errorMessage}]` } + // Consider throwing an error instead if that's preferred for handling failures + // throw new Error(errorMessage); + } + } + + override getModel(): { id: string; info: ModelInfo } { + // Reuse existing fields if they make sense for custom providers + return { + id: this.options.openAiModelId ?? "custom-model", // Default or configured ID + info: this.options.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults, + } + } + + async completePrompt(prompt: string): Promise { + // TODO: Implement non-streaming completion if needed (optional for Roo?) + console.log("Prompt:", prompt) + return "[Placeholder: CustomOpenAiHandler.completePrompt not implemented]" + } + + // --- Helper methods (potentially reuse/adapt from OpenAiHandler) --- + protected processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk { + // Adapt if usage stats format differs + return { + type: "usage", + inputTokens: usage?.prompt_tokens || 0, + outputTokens: usage?.completion_tokens || 0, + } + } +} + +// TODO: Add function to fetch models if the custom endpoint supports a /models route +// export async function getCustomOpenAiModels(...) { ... 
} diff --git a/src/exports/api.ts b/src/exports/api.ts index 2da90a84a5e..717c5d8499c 100644 --- a/src/exports/api.ts +++ b/src/exports/api.ts @@ -88,7 +88,7 @@ export class API extends EventEmitter implements RooCodeAPI { images, newTab, }: { - configuration: RooCodeSettings + configuration?: RooCodeSettings // Make configuration optional text?: string images?: string[] newTab?: boolean diff --git a/src/exports/roo-code.d.ts b/src/exports/roo-code.d.ts index 40939e4e32a..3a7f1d68bae 100644 --- a/src/exports/roo-code.d.ts +++ b/src/exports/roo-code.d.ts @@ -9,6 +9,7 @@ type ProviderSettings = { | "bedrock" | "vertex" | "openai" + | "custom-openai" | "ollama" | "vscode-lm" | "lmstudio" @@ -33,6 +34,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -55,6 +57,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -97,6 +100,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -142,6 +146,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -163,6 +168,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -181,6 +187,12 @@ type ProviderSettings = { includeMaxTokens?: boolean | undefined rateLimitSeconds?: number | undefined fakeAi?: unknown | undefined + customBaseUrl?: string | undefined + customApiKey?: string | undefined + customAuthHeaderName?: string | undefined + customAuthHeaderPrefix?: string | undefined + useModelInPath?: boolean | undefined + customPathPrefix?: string | undefined } type GlobalSettings = { @@ -197,6 +209,7 @@ type GlobalSettings = { | "bedrock" | "vertex" | "openai" + | "custom-openai" | "ollama" | "vscode-lm" | "lmstudio" diff --git a/src/exports/types.ts b/src/exports/types.ts index 64a955554e9..535c617954f 100644 --- a/src/exports/types.ts +++ b/src/exports/types.ts @@ -10,6 +10,7 @@ type ProviderSettings = { | "bedrock" | "vertex" | "openai" + | "custom-openai" | "ollama" | "vscode-lm" | "lmstudio" @@ -34,6 +35,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -56,6 +58,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | 
undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -98,6 +101,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -143,6 +147,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -164,6 +169,7 @@ type ProviderSettings = { supportsImages?: boolean | undefined supportsComputerUse?: boolean | undefined supportsPromptCache: boolean + supportsUsageStream?: boolean | undefined inputPrice?: number | undefined outputPrice?: number | undefined cacheWritesPrice?: number | undefined @@ -182,6 +188,12 @@ type ProviderSettings = { includeMaxTokens?: boolean | undefined rateLimitSeconds?: number | undefined fakeAi?: unknown | undefined + customBaseUrl?: string | undefined + customApiKey?: string | undefined + customAuthHeaderName?: string | undefined + customAuthHeaderPrefix?: string | undefined + useModelInPath?: boolean | undefined + customPathPrefix?: string | undefined } export type { ProviderSettings } @@ -200,6 +212,7 @@ type GlobalSettings = { | "bedrock" | "vertex" | "openai" + | "custom-openai" | "ollama" | "vscode-lm" | "lmstudio" diff --git a/src/schemas/index.ts b/src/schemas/index.ts index d2471882ecc..277d7bfdefa 100644 --- a/src/schemas/index.ts +++ b/src/schemas/index.ts @@ -17,6 +17,7 @@ export const providerNames = [ "bedrock", "vertex", "openai", + "custom-openai", // Added custom provider "ollama", "vscode-lm", "lmstudio", @@ -105,6 +106,7 @@ export const modelInfoSchema = z.object({ supportsImages: z.boolean().optional(), supportsComputerUse: z.boolean().optional(), supportsPromptCache: z.boolean(), + supportsUsageStream: z.boolean().optional(), inputPrice: z.number().optional(), outputPrice: z.number().optional(), cacheWritesPrice: z.number().optional(), @@ -391,6 +393,13 @@ export const providerSettingsSchema = z.object({ rateLimitSeconds: z.number().optional(), // Fake AI fakeAi: z.unknown().optional(), + // Custom OpenAI Compatible + customBaseUrl: z.string().optional(), + customApiKey: z.string().optional(), + customAuthHeaderName: z.string().optional(), + customAuthHeaderPrefix: z.string().optional(), + useModelInPath: z.boolean().optional(), + customPathPrefix: z.string().optional(), }) export type ProviderSettings = z.infer @@ -471,6 +480,13 @@ const providerSettingsRecord: ProviderSettingsRecord = { requestyModelInfo: undefined, // Claude 3.7 Sonnet Thinking modelTemperature: undefined, + // Custom OpenAI Compatible + customBaseUrl: undefined, + customApiKey: undefined, + customAuthHeaderName: undefined, + customAuthHeaderPrefix: undefined, + useModelInPath: undefined, + customPathPrefix: undefined, modelMaxTokens: undefined, modelMaxThinkingTokens: undefined, // Generic diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index fb633df155b..a0fdc18f0d3 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -1444,6 +1444,77 @@ const ApiOptions = ({ )} )} + {selectedProvider === "custom-openai" && ( + <> + + + + + + +
+ {t("settings:providers.apiKeyStorageNotice")} +
+ + + + + + + + {/* Custom-specific options */} +
+
Custom API Options
+ + Use model name in URL path + +
+ Enable this for APIs that include the model name in the URL path (e.g., + /api/v1/chat/model-name) +
+
+ + {apiConfiguration?.useModelInPath && ( + + + + )} + +
+ Configure a custom OpenAI-compatible API endpoint. For services that use a different + authentication header (like 'X-API-Key' instead of 'Authorization: Bearer'), specify the header + name and prefix accordingly. +
+ + )} {selectedProvider === "human-relay" && ( <> @@ -1519,6 +1590,30 @@ const ApiOptions = ({ )} + {selectedProvider === "custom-openai" && ( + <> + + + setApiConfigurationField("rateLimitSeconds", value)} + /> + + )} + {selectedProvider === "glama" && ( a.label.localeCompare(b.label)) export const VERTEX_REGIONS = [ From e473d67188c6b2060fc1e4665a2cfde8d94d9002 Mon Sep 17 00:00:00 2001 From: Capcy Date: Wed, 9 Apr 2025 22:54:33 -0500 Subject: [PATCH 2/2] fix(custom-openai): sanitize error msg to prevent circular error addressing PR feedback --- src/api/providers/custom-openai.ts | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/api/providers/custom-openai.ts b/src/api/providers/custom-openai.ts index 963417910d4..c5766e17e73 100644 --- a/src/api/providers/custom-openai.ts +++ b/src/api/providers/custom-openai.ts @@ -184,15 +184,17 @@ export class CustomOpenAiHandler extends BaseProvider implements SingleCompletio } } } catch (error: any) { - console.error("Custom OpenAI API request failed:", error) - let errorMessage = "Custom OpenAI API request failed." - if (axios.isAxiosError(error) && error.response) { - errorMessage += ` Status: ${error.response.status}. Data: ${JSON.stringify(error.response.data)}` + console.error("Custom OpenAI API request failed:", error?.message || error) // Log basic error message + let simpleErrorMessage = "Custom OpenAI API request failed." + if (axios.isAxiosError(error)) { + simpleErrorMessage += ` Status: ${error.response?.status || "unknown"}.` + // Avoid logging potentially large/circular response data + // console.error("Error Response Data:", error.response?.data); // Optional: Log only during debugging if needed } else if (error instanceof Error) { - errorMessage += ` Error: ${error.message}` + simpleErrorMessage += ` Error: ${error.message}` } - // Yield an error chunk or throw? For now, yield text. - yield { type: "text", text: `[ERROR: ${errorMessage}]` } + // Yield a simplified error message + yield { type: "text", text: `[ERROR: ${simpleErrorMessage}]` } // Consider throwing an error instead if that's preferred for handling failures // throw new Error(errorMessage); }
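
---

Usage sketch (illustrative, not part of either patch): one way the new provider could be exercised through buildApiHandler, assuming the configuration object carries the provider under `apiProvider` as with the existing providers. The base URL, key, import path, and model name below are placeholders.

// Hypothetical configuration for an OpenAI-compatible service that authenticates with
// `X-API-Key: <key>` (empty prefix, so the header value is the bare key) and expects the
// model name in the URL path, i.e. POST {customBaseUrl}/api/v1/chat/{model}.
import { buildApiHandler } from "../../api" // import path is illustrative

async function demo() {
	const handler = buildApiHandler({
		apiProvider: "custom-openai",
		customBaseUrl: "https://llm.internal.example.com",
		customApiKey: process.env.CUSTOM_LLM_KEY ?? "not-provided",
		customAuthHeaderName: "X-API-Key",
		customAuthHeaderPrefix: "",
		useModelInPath: true,
		customPathPrefix: "/api/v1/chat/",
		openAiModelId: "my-internal-model",
		openAiStreamingEnabled: true,
	})

	// createMessage yields the same text / reasoning / usage chunks as the other providers.
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hello" },
	])) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		}
	}
}

With useModelInPath left off, the same configuration hits the default /chat/completions endpoint, which matches the second test case above.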
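For reference, the streaming branch of createMessage() parses a plain OpenAI-style SSE body line by line. A response shaped like the following (values made up) is what the parser above handles, including the optional reasoning_content field and the trailing usage event when the server honours stream_options.include_usage.

// Illustrative SSE payload accepted by the stream parser in createMessage().
const exampleSseBody = [
	'data: {"choices":[{"delta":{"content":"Hel"}}]}',
	'data: {"choices":[{"delta":{"content":"lo"}}]}',
	// Only emitted by models that stream reasoning separately.
	'data: {"choices":[{"delta":{"reasoning_content":"thinking..."}}]}',
	// Only present when the server honours stream_options.include_usage.
	'data: {"choices":[{"delta":{}}],"usage":{"prompt_tokens":10,"completion_tokens":20}}',
	"data: [DONE]",
].join("\n")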
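On the trailing TODO about fetching models: if the custom endpoint exposes an OpenAI-style /models route, a possible shape is sketched below. Both the route and the { data: [{ id: ... }] } response format are assumptions rather than anything this patch guarantees, which is presumably why the TODO was left open.

// Hypothetical sketch for getCustomOpenAiModels(); the /models route and the OpenAI-style
// response shape are assumptions, not part of this patch.
import axios from "axios"

export async function getCustomOpenAiModels(options: {
	customBaseUrl?: string
	customApiKey?: string
	customAuthHeaderName?: string
	customAuthHeaderPrefix?: string
}): Promise<string[]> {
	if (!options.customBaseUrl) {
		return []
	}

	// Mirror the auth-header construction used by the handler's constructor.
	const headerName = options.customAuthHeaderName || "Authorization"
	const prefix = options.customAuthHeaderPrefix !== undefined ? options.customAuthHeaderPrefix : "Bearer "
	const headers = { [headerName]: `${prefix}${options.customApiKey || "not-provided"}`.trim() }

	try {
		const response = await axios.get(`${options.customBaseUrl}/models`, { headers })
		const models = response.data?.data
		return Array.isArray(models) ? models.map((m: any) => m?.id).filter(Boolean) : []
	} catch (error) {
		console.error("Failed to fetch custom OpenAI models:", error instanceof Error ? error.message : error)
		return []
	}
}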