// src/api/providers/custom-openai.ts
import { Anthropic } from "@anthropic-ai/sdk"
import axios, { AxiosInstance, AxiosRequestConfig } from "axios" // Use axios for custom requests

import {
	ApiHandlerOptions,
	ModelInfo,
	openAiModelInfoSaneDefaults, // Use sane defaults initially
} from "../../shared/api"
import { SingleCompletionHandler } from "../index"
import { convertToOpenAiMessages } from "../transform/openai-format" // Reuse message formatting
import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { XmlMatcher } from "../../utils/xml-matcher" // For potential reasoning tags

// Define specific options for the custom provider
export interface CustomOpenAiHandlerOptions extends ApiHandlerOptions {
	customBaseUrl?: string
	customApiKey?: string
	customAuthHeaderName?: string // e.g., 'X-API-Key'
	customAuthHeaderPrefix?: string // e.g., 'Bearer ' or ''
	// URL path options
	useModelInPath?: boolean // Whether to include model in URL path (e.g., /api/v1/chat/model-name)
	customPathPrefix?: string // Custom path prefix (e.g., /api/v1/chat/)
	// Potentially add other OpenAI-compatible options if needed later
	modelTemperature?: number | null // Allow null to match schema
	includeMaxTokens?: boolean
	openAiStreamingEnabled?: boolean // Reuse existing streaming flag?
	openAiModelId?: string // Reuse model ID field
	openAiCustomModelInfo?: ModelInfo | null // Allow null to match schema
}
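
// Illustrative only: a hypothetical options object for an endpoint that expects an
// "X-API-Key" header with no prefix and the model name embedded in the URL path.
// All values below are placeholders, not shipped defaults.
//
// const exampleOptions: CustomOpenAiHandlerOptions = {
// 	customBaseUrl: "https://llm.example.com",
// 	customApiKey: "my-secret-key",
// 	customAuthHeaderName: "X-API-Key",
// 	customAuthHeaderPrefix: "",
// 	useModelInPath: true,
// 	customPathPrefix: "/api/v1/chat/",
// 	openAiModelId: "my-model",
// 	openAiStreamingEnabled: true,
// }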

// Default headers - maybe keep these?
export const defaultHeaders = {
	"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
	"X-Title": "Roo Code",
}

export class CustomOpenAiHandler extends BaseProvider implements SingleCompletionHandler {
	protected options: CustomOpenAiHandlerOptions
	private client: AxiosInstance // Use an axios instance

	constructor(options: CustomOpenAiHandlerOptions) {
		super()
		this.options = options

		const baseURL = this.options.customBaseUrl
		if (!baseURL) {
			throw new Error("Custom OpenAI provider requires 'customBaseUrl' to be set.")
		}
		if (!this.options.customApiKey) {
			console.warn("Custom OpenAI provider initialized without 'customApiKey'.")
		}

		// Prepare authentication header
		const authHeaderName = this.options.customAuthHeaderName || "Authorization" // Default to Authorization
		const authHeaderPrefix =
			this.options.customAuthHeaderPrefix !== undefined ? this.options.customAuthHeaderPrefix : "Bearer " // Default to Bearer prefix
		const apiKey = this.options.customApiKey || "not-provided"
		const authHeaderValue = `${authHeaderPrefix}${apiKey}`.trim() // Handle empty prefix
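		// For example (illustrative values): the defaults above produce "Authorization: Bearer <key>",
		// while customAuthHeaderName "X-API-Key" with an empty prefix produces "X-API-Key: <key>".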

		this.client = axios.create({
			baseURL,
			headers: {
				...defaultHeaders, // Include default Roo headers
				[authHeaderName]: authHeaderValue, // Add the custom auth header
				"Content-Type": "application/json",
			},
		})
	}

	// --- Implementation using axios ---

	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		const modelInfo = this.getModel().info
		const modelId = this.options.openAiModelId ?? "custom-model" // Get model ID from options
		const streamingEnabled = this.options.openAiStreamingEnabled ?? true // Default to streaming

		// Convert messages to OpenAI format
		// Need to import OpenAI types for this
		const systemMessage: { role: "system"; content: string } = {
			role: "system",
			content: systemPrompt,
		}
		const convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]

		// Construct the common payload parts
		const payload: Record<string, any> = {
			model: modelId,
			messages: convertedMessages,
			temperature: this.options.modelTemperature ?? 0, // Default temperature
			stream: streamingEnabled,
		}

		if (streamingEnabled && modelInfo.supportsUsageStream) {
			payload.stream_options = { include_usage: true }
		}
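		// Note: stream_options.include_usage follows the OpenAI streaming API, which then appends a
		// final chunk carrying token usage; servers that don't support it may ignore or reject the field.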

		if (this.options.includeMaxTokens && modelInfo.maxTokens) {
			payload.max_tokens = modelInfo.maxTokens
		}

		// Determine the endpoint based on configuration
		let endpoint = "/chat/completions" // Default OpenAI-compatible endpoint

		// If useModelInPath is true, construct the endpoint with the model in the path
		if (this.options.useModelInPath && modelId) {
			const pathPrefix = this.options.customPathPrefix || "/api/v1/chat/"
			endpoint = `${pathPrefix}${modelId}`
		}
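		// Illustrative example: with useModelInPath=true, customPathPrefix "/api/v1/chat/", and
		// model ID "my-model", requests go to "<customBaseUrl>/api/v1/chat/my-model"; otherwise
		// they go to "<customBaseUrl>/chat/completions".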

		try {
			if (streamingEnabled) {
				const response = await this.client.post(endpoint, payload, {
					responseType: "stream",
				})

				const stream = response.data as NodeJS.ReadableStream
				let buffer = ""
				let lastUsage: any = null
				const matcher = new XmlMatcher(
					"think",
					(chunk) => ({ type: chunk.matched ? "reasoning" : "text", text: chunk.data }) as const,
				)

				for await (const chunk of stream) {
					buffer += chunk.toString()

					// Process buffer line by line (SSE format)
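					// A typical OpenAI-style SSE line looks like:
					//   data: {"choices":[{"delta":{"content":"Hello"}}]}
					// and the stream ends with a literal "data: [DONE]" line.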
					let EOL
					while ((EOL = buffer.indexOf("\n")) >= 0) {
						const line = buffer.substring(0, EOL).trim()
						buffer = buffer.substring(EOL + 1)

						if (line.startsWith("data:")) {
							const data = line.substring(5).trim()
							if (data === "[DONE]") {
								break // Stream finished
							}
							try {
								const parsed = JSON.parse(data)
								const delta = parsed.choices?.[0]?.delta ?? {}

								if (delta.content) {
									for (const contentChunk of matcher.update(delta.content)) {
										yield contentChunk
									}
								}
								// Handle potential reasoning content if supported by the custom model
								if ("reasoning_content" in delta && delta.reasoning_content) {
									yield {
										type: "reasoning",
										text: (delta.reasoning_content as string | undefined) || "",
									}
								}

								if (parsed.usage) {
									lastUsage = parsed.usage
								}
							} catch (e) {
								console.error("Error parsing stream data:", e, "Data:", data)
							}
						}
					}
				}
				// Yield any remaining text from the matcher
				for (const contentChunk of matcher.final()) {
					yield contentChunk
				}

				if (lastUsage) {
					yield this.processUsageMetrics(lastUsage, modelInfo)
				}
			} else {
				// Non-streaming case
				const response = await this.client.post(endpoint, payload)
				const completion = response.data

				yield {
					type: "text",
					text: completion.choices?.[0]?.message?.content || "",
				}
				if (completion.usage) {
					yield this.processUsageMetrics(completion.usage, modelInfo)
				}
			}
		} catch (error: any) {
			console.error("Custom OpenAI API request failed:", error)
			let errorMessage = "Custom OpenAI API request failed."
			if (axios.isAxiosError(error) && error.response) {
				errorMessage += ` Status: ${error.response.status}. Data: ${JSON.stringify(error.response.data)}`
			} else if (error instanceof Error) {
				errorMessage += ` Error: ${error.message}`
			}
			// Yield an error chunk or throw? For now, yield text.
			yield { type: "text", text: `[ERROR: ${errorMessage}]` }
			// Consider throwing an error instead if that's preferred for handling failures
			// throw new Error(errorMessage);
		}
	}

	override getModel(): { id: string; info: ModelInfo } {
		// Reuse existing fields if they make sense for custom providers
		return {
			id: this.options.openAiModelId ?? "custom-model", // Default or configured ID
			info: this.options.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults,
		}
	}

	async completePrompt(prompt: string): Promise<string> {
		// TODO: Implement non-streaming completion if needed (optional for Roo?)
		console.log("Prompt:", prompt)
		return "[Placeholder: CustomOpenAiHandler.completePrompt not implemented]"
	}
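
	// A minimal non-streaming sketch for the TODO above (commented out, not wired in); it
	// assumes the same OpenAI-style /chat/completions response shape used elsewhere in this file:
	//
	// async completePrompt(prompt: string): Promise<string> {
	// 	const response = await this.client.post("/chat/completions", {
	// 		model: this.getModel().id,
	// 		messages: [{ role: "user", content: prompt }],
	// 		stream: false,
	// 	})
	// 	return response.data?.choices?.[0]?.message?.content || ""
	// }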

	// --- Helper methods (potentially reuse/adapt from OpenAiHandler) ---
	protected processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk {
		// Adapt if usage stats format differs
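		// OpenAI-compatible servers typically report usage as
		//   { "prompt_tokens": 123, "completion_tokens": 456, "total_tokens": 579 }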
		return {
			type: "usage",
			inputTokens: usage?.prompt_tokens || 0,
			outputTokens: usage?.completion_tokens || 0,
		}
	}
}
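
// Illustrative usage sketch (commented out; "messages" and all values below are placeholders):
//
// const handler = new CustomOpenAiHandler({
// 	customBaseUrl: "https://llm.example.com",
// 	customApiKey: "my-secret-key",
// 	openAiModelId: "my-model",
// })
// for await (const chunk of handler.createMessage("You are a helpful assistant.", messages)) {
// 	if (chunk.type === "text") process.stdout.write(chunk.text)
// }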

// TODO: Add function to fetch models if the custom endpoint supports a /models route
// export async function getCustomOpenAiModels(...) { ... }
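//
// One possible sketch for that TODO (commented out; the parameters are assumptions, not a
// decided signature), assuming an OpenAI-style GET /models response of the form
// { data: [{ id: "model-a" }, ...] }; adjust to whatever the custom endpoint actually returns:
//
// export async function getCustomOpenAiModels(baseUrl: string, apiKey?: string): Promise<string[]> {
// 	const response = await axios.get(`${baseUrl}/models`, {
// 		headers: apiKey ? { Authorization: `Bearer ${apiKey}` } : undefined,
// 	})
// 	return (response.data?.data ?? []).map((m: { id: string }) => m.id)
// }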