
Commit ec18380

feat(custom-openai-compatible-api-provider): extend the API provider to connect Roo to OpenAI-compatible APIs that have non-standard configs:
custom auth headers instead of `Authorization`, custom or empty auth header prefixes, and URL structures where the model is included in the path.
1 parent: 5fa555e
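
In practice, the non-standard setups this commit targets combine as in the sketch below (option names come from the `CustomOpenAiHandlerOptions` interface added in this commit; endpoint, key, and model values are placeholders):

```ts
// A provider that expects `X-API-Key: <key>` (no "Bearer " prefix) and routes
// requests as POST /api/v1/chat/<model> rather than POST /chat/completions.
const options = {
	customBaseUrl: "https://llm.internal.example.com", // placeholder
	customApiKey: "sk-local-123", // placeholder
	customAuthHeaderName: "X-API-Key", // instead of `Authorization`
	customAuthHeaderPrefix: "", // empty prefix: the bare key is sent
	useModelInPath: true,
	customPathPrefix: "/api/v1/chat/",
	openAiModelId: "my-model", // placeholder
}
// Resulting request: POST https://llm.internal.example.com/api/v1/chat/my-model
// with header `X-API-Key: sk-local-123`
```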

9 files changed (+463 / -1 lines)

src/api/index.ts

Lines changed: 3 additions & 0 deletions
```diff
@@ -21,6 +21,7 @@ import { UnboundHandler } from "./providers/unbound"
 import { RequestyHandler } from "./providers/requesty"
 import { HumanRelayHandler } from "./providers/human-relay"
 import { FakeAIHandler } from "./providers/fake-ai"
+import { CustomOpenAiHandler } from "./providers/custom-openai" // Import the new handler

 export interface SingleCompletionHandler {
 	completePrompt(prompt: string): Promise<string>
@@ -56,6 +57,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new VertexHandler(options)
 		case "openai":
 			return new OpenAiHandler(options)
+		case "custom-openai": // Add case for the new handler
+			return new CustomOpenAiHandler(options)
 		case "ollama":
 			return new OllamaHandler(options)
 		case "lmstudio":
```
Lines changed: 88 additions & 0 deletions
```ts
import { CustomOpenAiHandler } from "../custom-openai"
import { openAiModelInfoSaneDefaults } from "../../../shared/api"

describe("CustomOpenAiHandler", () => {
	it("should construct with required options", () => {
		const handler = new CustomOpenAiHandler({
			customBaseUrl: "https://api.example.com",
			customApiKey: "test-key",
			customAuthHeaderName: "X-API-Key",
			customAuthHeaderPrefix: "",
		})

		expect(handler).toBeDefined()
	})

	it("should throw error if customBaseUrl is not provided", () => {
		expect(() => {
			new CustomOpenAiHandler({
				customApiKey: "test-key",
			})
		}).toThrow("Custom OpenAI provider requires 'customBaseUrl' to be set.")
	})

	it("should use model in path when useModelInPath is true", async () => {
		const handler = new CustomOpenAiHandler({
			customBaseUrl: "https://api.example.com",
			customApiKey: "test-key",
			useModelInPath: true,
			customPathPrefix: "/api/v1/chat/",
			openAiModelId: "gpt-3.5-turbo",
			openAiCustomModelInfo: openAiModelInfoSaneDefaults,
		})

		// Mock the client.post method
		const mockPost = jest.fn().mockResolvedValue({
			data: {
				choices: [{ message: { content: "Test response" } }],
				usage: { prompt_tokens: 10, completion_tokens: 20 },
			},
		})

		// @ts-ignore - Replace the client with our mock
		handler.client = { post: mockPost }

		// Call createMessage to trigger the endpoint construction
		const stream = handler.createMessage("Test system prompt", [{ role: "user", content: "Test message" }])

		// Consume the stream to ensure the post method is called
		for await (const _ of stream) {
			// Just consume the stream
		}

		// Verify the endpoint used in the post call
		expect(mockPost).toHaveBeenCalledWith("/api/v1/chat/gpt-3.5-turbo", expect.any(Object), expect.any(Object))
	})

	it("should use standard endpoint when useModelInPath is false", async () => {
		const handler = new CustomOpenAiHandler({
			customBaseUrl: "https://api.example.com",
			customApiKey: "test-key",
			useModelInPath: false,
			openAiModelId: "gpt-3.5-turbo",
			openAiCustomModelInfo: openAiModelInfoSaneDefaults,
		})

		// Mock the client.post method
		const mockPost = jest.fn().mockResolvedValue({
			data: {
				choices: [{ message: { content: "Test response" } }],
				usage: { prompt_tokens: 10, completion_tokens: 20 },
			},
		})

		// @ts-ignore - Replace the client with our mock
		handler.client = { post: mockPost }

		// Call createMessage to trigger the endpoint construction
		const stream = handler.createMessage("Test system prompt", [{ role: "user", content: "Test message" }])

		// Consume the stream to ensure the post method is called
		for await (const _ of stream) {
			// Just consume the stream
		}

		// Verify the endpoint used in the post call
		expect(mockPost).toHaveBeenCalledWith("/chat/completions", expect.any(Object), expect.any(Object))
	})
})
```
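
Taken together, the two assertions pin down the endpoint construction; in condensed form (paths relative to `customBaseUrl`):

```ts
// useModelInPath: true, customPathPrefix: "/api/v1/chat/", model "gpt-3.5-turbo"
//   -> POST /api/v1/chat/gpt-3.5-turbo
// useModelInPath: false (or unset)
//   -> POST /chat/completions
```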

src/api/providers/custom-openai.ts

Lines changed: 227 additions & 0 deletions
```ts
// src/api/providers/custom-openai.ts
import { Anthropic } from "@anthropic-ai/sdk"
import axios, { AxiosInstance, AxiosRequestConfig } from "axios" // Use axios for custom requests

import {
	ApiHandlerOptions,
	ModelInfo,
	openAiModelInfoSaneDefaults, // Use sane defaults initially
} from "../../shared/api"
import { SingleCompletionHandler } from "../index"
import { convertToOpenAiMessages } from "../transform/openai-format" // Reuse message formatting
import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { XmlMatcher } from "../../utils/xml-matcher" // For potential reasoning tags

// Define specific options for the custom provider
export interface CustomOpenAiHandlerOptions extends ApiHandlerOptions {
	customBaseUrl?: string
	customApiKey?: string
	customAuthHeaderName?: string // e.g., 'X-API-Key'
	customAuthHeaderPrefix?: string // e.g., 'Bearer ' or ''
	// URL path options
	useModelInPath?: boolean // Whether to include model in URL path (e.g., /api/v1/chat/model-name)
	customPathPrefix?: string // Custom path prefix (e.g., /api/v1/chat/)
	// Potentially add other OpenAI-compatible options if needed later
	modelTemperature?: number | null // Allow null to match schema
	includeMaxTokens?: boolean
	openAiStreamingEnabled?: boolean // Reuse existing streaming flag?
	openAiModelId?: string // Reuse model ID field
	openAiCustomModelInfo?: ModelInfo | null // Allow null to match schema
}

// Default headers - maybe keep these?
export const defaultHeaders = {
	"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
	"X-Title": "Roo Code",
}

export class CustomOpenAiHandler extends BaseProvider implements SingleCompletionHandler {
	protected options: CustomOpenAiHandlerOptions
	private client: AxiosInstance // Use an axios instance

	constructor(options: CustomOpenAiHandlerOptions) {
		super()
		this.options = options

		const baseURL = this.options.customBaseUrl
		if (!baseURL) {
			throw new Error("Custom OpenAI provider requires 'customBaseUrl' to be set.")
		}
		if (!this.options.customApiKey) {
			console.warn("Custom OpenAI provider initialized without 'customApiKey'.")
		}

		// Prepare authentication header
		const authHeaderName = this.options.customAuthHeaderName || "Authorization" // Default to Authorization
		const authHeaderPrefix =
			this.options.customAuthHeaderPrefix !== undefined ? this.options.customAuthHeaderPrefix : "Bearer " // Default to Bearer prefix
		const apiKey = this.options.customApiKey || "not-provided"
		const authHeaderValue = `${authHeaderPrefix}${apiKey}`.trim() // Handle empty prefix

		this.client = axios.create({
			baseURL,
			headers: {
				...defaultHeaders, // Include default Roo headers
				[authHeaderName]: authHeaderValue, // Add the custom auth header
				"Content-Type": "application/json",
			},
		})
	}

	// --- Implementation using axios ---

	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		const modelInfo = this.getModel().info
		const modelId = this.options.openAiModelId ?? "custom-model" // Get model ID from options
		const streamingEnabled = this.options.openAiStreamingEnabled ?? true // Default to streaming

		// Convert messages to OpenAI format
		// Need to import OpenAI types for this
		const systemMessage: { role: "system"; content: string } = {
			role: "system",
			content: systemPrompt,
		}
		const convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]

		// Construct the common payload parts
		const payload: Record<string, any> = {
			model: modelId,
			messages: convertedMessages,
			temperature: this.options.modelTemperature ?? 0, // Default temperature
			stream: streamingEnabled,
		}

		if (streamingEnabled && modelInfo.supportsUsageStream) {
			payload.stream_options = { include_usage: true }
		}

		if (this.options.includeMaxTokens && modelInfo.maxTokens) {
			payload.max_tokens = modelInfo.maxTokens
		}

		// Determine the endpoint based on configuration
		let endpoint = "/chat/completions" // Default OpenAI-compatible endpoint

		// If useModelInPath is true, construct the endpoint with the model in the path
		if (this.options.useModelInPath && modelId) {
			const pathPrefix = this.options.customPathPrefix || "/api/v1/chat/"
			endpoint = `${pathPrefix}${modelId}`
		}

		try {
			if (streamingEnabled) {
				const response = await this.client.post(endpoint, payload, {
					responseType: "stream",
				})

				const stream = response.data as NodeJS.ReadableStream
				let buffer = ""
				let lastUsage: any = null
				const matcher = new XmlMatcher(
					"think",
					(chunk) => ({ type: chunk.matched ? "reasoning" : "text", text: chunk.data }) as const,
				)

				for await (const chunk of stream) {
					buffer += chunk.toString()

					// Process buffer line by line (SSE format)
					let EOL
					while ((EOL = buffer.indexOf("\n")) >= 0) {
						const line = buffer.substring(0, EOL).trim()
						buffer = buffer.substring(EOL + 1)

						if (line.startsWith("data:")) {
							const data = line.substring(5).trim()
							if (data === "[DONE]") {
								break // Stream finished
							}
							try {
								const parsed = JSON.parse(data)
								const delta = parsed.choices?.[0]?.delta ?? {}

								if (delta.content) {
									for (const contentChunk of matcher.update(delta.content)) {
										yield contentChunk
									}
								}
								// Handle potential reasoning content if supported by the custom model
								if ("reasoning_content" in delta && delta.reasoning_content) {
									yield {
										type: "reasoning",
										text: (delta.reasoning_content as string | undefined) || "",
									}
								}

								if (parsed.usage) {
									lastUsage = parsed.usage
								}
							} catch (e) {
								console.error("Error parsing stream data:", e, "Data:", data)
							}
						}
					}
				}
				// Yield any remaining text from the matcher
				for (const contentChunk of matcher.final()) {
					yield contentChunk
				}

				if (lastUsage) {
					yield this.processUsageMetrics(lastUsage, modelInfo)
				}
			} else {
				// Non-streaming case
				const response = await this.client.post(endpoint, payload)
				const completion = response.data

				yield {
					type: "text",
					text: completion.choices?.[0]?.message?.content || "",
				}
				if (completion.usage) {
					yield this.processUsageMetrics(completion.usage, modelInfo)
				}
			}
		} catch (error: any) {
			console.error("Custom OpenAI API request failed:", error)
			let errorMessage = "Custom OpenAI API request failed."
			if (axios.isAxiosError(error) && error.response) {
				errorMessage += ` Status: ${error.response.status}. Data: ${JSON.stringify(error.response.data)}`
			} else if (error instanceof Error) {
				errorMessage += ` Error: ${error.message}`
			}
			// Yield an error chunk or throw? For now, yield text.
			yield { type: "text", text: `[ERROR: ${errorMessage}]` }
			// Consider throwing an error instead if that's preferred for handling failures
			// throw new Error(errorMessage);
		}
	}

	override getModel(): { id: string; info: ModelInfo } {
		// Reuse existing fields if they make sense for custom providers
		return {
			id: this.options.openAiModelId ?? "custom-model", // Default or configured ID
			info: this.options.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults,
		}
	}

	async completePrompt(prompt: string): Promise<string> {
		// TODO: Implement non-streaming completion if needed (optional for Roo?)
		console.log("Prompt:", prompt)
		return "[Placeholder: CustomOpenAiHandler.completePrompt not implemented]"
	}

	// --- Helper methods (potentially reuse/adapt from OpenAiHandler) ---
	protected processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk {
		// Adapt if usage stats format differs
		return {
			type: "usage",
			inputTokens: usage?.prompt_tokens || 0,
			outputTokens: usage?.completion_tokens || 0,
		}
	}
}

// TODO: Add function to fetch models if the custom endpoint supports a /models route
// export async function getCustomOpenAiModels(...) { ... }
```
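
A hedged sketch of driving the handler end to end, assuming the Roo types behave as imported above (URL, key, and model id are placeholders; the SSE frames in the trailing comment are illustrative of what the streaming branch parses):

```ts
import { CustomOpenAiHandler } from "./custom-openai"

const handler = new CustomOpenAiHandler({
	customBaseUrl: "https://llm.internal.example.com", // placeholder
	customApiKey: "sk-local-123", // placeholder
	customAuthHeaderName: "X-API-Key",
	customAuthHeaderPrefix: "", // header value is the bare key
	useModelInPath: true,
	customPathPrefix: "/api/v1/chat/",
	openAiModelId: "my-model", // placeholder
})

// createMessage is an async generator yielding text, reasoning, and usage chunks.
for await (const chunk of handler.createMessage("You are helpful.", [{ role: "user", content: "Hello" }])) {
	if (chunk.type === "text" || chunk.type === "reasoning") {
		process.stdout.write(chunk.text)
	} else if (chunk.type === "usage") {
		console.log(`\n${chunk.inputTokens} in / ${chunk.outputTokens} out`)
	}
}

// SSE frames the streaming branch consumes, one `data:` line per event:
//   data: {"choices":[{"delta":{"content":"Hel"}}]}
//   data: {"choices":[{"delta":{"content":"lo"}}]}
//   data: {"usage":{"prompt_tokens":10,"completion_tokens":2}}
//   data: [DONE]
```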

src/exports/api.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -88,7 +88,7 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 		images,
 		newTab,
 	}: {
-		configuration: RooCodeSettings
+		configuration?: RooCodeSettings // Make configuration optional
 		text?: string
 		images?: string[]
 		newTab?: boolean
```
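
Since `configuration` is now optional, callers can omit it and fall back to the currently stored settings. A sketch (the enclosing method is not shown in this hunk; `startNewTask` is assumed here for illustration, with `api` an instance of the `API` class and `settings` a `RooCodeSettings` value):

```ts
// Relies on whatever settings are already active.
await api.startNewTask({ text: "Summarize the open file" })

// Or passes an explicit configuration for this task.
await api.startNewTask({ configuration: settings, text: "Summarize the open file" })
```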
