Skip to content

Commit c20ef5a

Browse files
committed
feat: add openai-compatible provider support for token usage display
- Add 'openai-compatible' as a valid provider in buildApiHandler - Add 'openai-compatible' to provider types and schemas - Update ProfileValidator to handle openai-compatible provider - Add tests for openai-compatible provider functionality Fixes #8543 - Token usage now displays correctly when using OpenAI Compatible API provider
1 parent 5a3f911 commit c20ef5a

File tree

4 files changed

+146
-2
lines changed

4 files changed

+146
-2
lines changed

packages/types/src/provider-settings.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ export const isInternalProvider = (key: string): key is InternalProvider =>
8888
* Custom providers are completely configurable within Roo Code settings.
8989
*/
9090

91-
export const customProviders = ["openai"] as const
91+
export const customProviders = ["openai", "openai-compatible"] as const
9292

9393
export type CustomProvider = (typeof customProviders)[number]
9494

@@ -138,6 +138,7 @@ export const providerNames = [
138138
"vertex",
139139
"xai",
140140
"zai",
141+
"openai-compatible",
141142
] as const
142143

143144
export const providerNamesSchema = z.enum(providerNames)
@@ -424,6 +425,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
424425
bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })),
425426
vertexSchema.merge(z.object({ apiProvider: z.literal("vertex") })),
426427
openAiSchema.merge(z.object({ apiProvider: z.literal("openai") })),
428+
openAiSchema.merge(z.object({ apiProvider: z.literal("openai-compatible") })),
427429
ollamaSchema.merge(z.object({ apiProvider: z.literal("ollama") })),
428430
vsCodeLmSchema.merge(z.object({ apiProvider: z.literal("vscode-lm") })),
429431
lmStudioSchema.merge(z.object({ apiProvider: z.literal("lmstudio") })),
@@ -610,7 +612,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str
610612
*/
611613

612614
export const MODELS_BY_PROVIDER: Record<
613-
Exclude<ProviderName, "fake-ai" | "human-relay" | "gemini-cli" | "openai">,
615+
Exclude<ProviderName, "fake-ai" | "human-relay" | "gemini-cli" | "openai" | "openai-compatible">,
614616
{ id: ProviderName; label: string; models: string[] }
615617
> = {
616618
anthropic: {

src/api/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
106106
? new AnthropicVertexHandler(options)
107107
: new VertexHandler(options)
108108
case "openai":
109+
case "openai-compatible":
109110
return new OpenAiHandler(options)
110111
case "ollama":
111112
return new NativeOllamaHandler(options)
src/api/providers/__tests__/openai-compatible.spec.ts (file header missing in extraction; path inferred from the imports `../../index` and `../openai`)

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest"
2+
import { buildApiHandler } from "../../index"
3+
import { OpenAiHandler } from "../openai"
4+
5+
vi.mock("openai", () => {
6+
const mockCreate = vi.fn()
7+
return {
8+
default: vi.fn().mockImplementation(() => ({
9+
chat: {
10+
completions: {
11+
create: mockCreate,
12+
},
13+
},
14+
})),
15+
OpenAI: vi.fn().mockImplementation(() => ({
16+
chat: {
17+
completions: {
18+
create: mockCreate,
19+
},
20+
},
21+
})),
22+
AzureOpenAI: vi.fn().mockImplementation(() => ({
23+
chat: {
24+
completions: {
25+
create: mockCreate,
26+
},
27+
},
28+
})),
29+
}
30+
})
31+
32+
describe("OpenAI Compatible Provider", () => {
33+
beforeEach(() => {
34+
vi.clearAllMocks()
35+
})
36+
37+
it("should create OpenAiHandler when apiProvider is 'openai-compatible'", () => {
38+
const handler = buildApiHandler({
39+
apiProvider: "openai-compatible",
40+
openAiApiKey: "test-key",
41+
openAiBaseUrl: "https://api.example.com/v1",
42+
openAiModelId: "test-model",
43+
})
44+
45+
expect(handler).toBeInstanceOf(OpenAiHandler)
46+
})
47+
48+
it("should handle token usage correctly for openai-compatible provider", async () => {
49+
const mockStream = {
50+
async *[Symbol.asyncIterator]() {
51+
yield {
52+
choices: [{ delta: { content: "Hello" } }],
53+
}
54+
yield {
55+
choices: [{ delta: { content: " world" } }],
56+
}
57+
yield {
58+
choices: [{ delta: {} }],
59+
usage: {
60+
prompt_tokens: 10,
61+
completion_tokens: 5,
62+
total_tokens: 15,
63+
},
64+
}
65+
},
66+
}
67+
68+
const OpenAI = (await import("openai")).default
69+
const mockCreate = vi.fn().mockResolvedValue(mockStream)
70+
;(OpenAI as any).mockImplementation(() => ({
71+
chat: {
72+
completions: {
73+
create: mockCreate,
74+
},
75+
},
76+
}))
77+
78+
const handler = buildApiHandler({
79+
apiProvider: "openai-compatible",
80+
openAiApiKey: "test-key",
81+
openAiBaseUrl: "https://api.example.com/v1",
82+
openAiModelId: "test-model",
83+
})
84+
85+
const messages = [{ role: "user" as const, content: "Test message" }]
86+
const stream = handler.createMessage("System prompt", messages)
87+
88+
const chunks = []
89+
for await (const chunk of stream) {
90+
chunks.push(chunk)
91+
}
92+
93+
// Check that we got text chunks
94+
const textChunks = chunks.filter((c) => c.type === "text")
95+
expect(textChunks).toHaveLength(2)
96+
expect(textChunks[0].text).toBe("Hello")
97+
expect(textChunks[1].text).toBe(" world")
98+
99+
// Check that we got usage data
100+
const usageChunk = chunks.find((c) => c.type === "usage")
101+
expect(usageChunk).toBeDefined()
102+
expect(usageChunk).toEqual({
103+
type: "usage",
104+
inputTokens: 10,
105+
outputTokens: 5,
106+
})
107+
})
108+
109+
it("should use the same configuration as openai provider", () => {
110+
const config = {
111+
openAiApiKey: "test-key",
112+
openAiBaseUrl: "https://api.example.com/v1",
113+
openAiModelId: "test-model",
114+
openAiCustomModelInfo: {
115+
maxTokens: 4096,
116+
contextWindow: 8192,
117+
supportsPromptCache: false,
118+
inputPrice: 0.001,
119+
outputPrice: 0.002,
120+
},
121+
}
122+
123+
const openaiHandler = buildApiHandler({
124+
apiProvider: "openai",
125+
...config,
126+
})
127+
128+
const openaiCompatibleHandler = buildApiHandler({
129+
apiProvider: "openai-compatible",
130+
...config,
131+
})
132+
133+
// Both should be instances of OpenAiHandler
134+
expect(openaiHandler).toBeInstanceOf(OpenAiHandler)
135+
expect(openaiCompatibleHandler).toBeInstanceOf(OpenAiHandler)
136+
137+
// Both should have the same model configuration
138+
expect(openaiHandler.getModel()).toEqual(openaiCompatibleHandler.getModel())
139+
})
140+
})

src/shared/ProfileValidator.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ export class ProfileValidator {
5656
private static getModelIdFromProfile(profile: ProviderSettings): string | undefined {
5757
switch (profile.apiProvider) {
5858
case "openai":
59+
case "openai-compatible":
5960
return profile.openAiModelId
6061
case "anthropic":
6162
case "openai-native":

0 commit comments

Comments (0)