
Commit 76278ed: Merge pull request #616 from websentry-ai/vs/support-unbound

Supports unbound API provider

2 parents: e07d2aa + ea30563

File tree

6 files changed: +426, -0 lines changed

src/api/index.ts

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,7 @@ import { DeepSeekHandler } from "./providers/deepseek"
 import { MistralHandler } from "./providers/mistral"
 import { VsCodeLmHandler } from "./providers/vscode-lm"
 import { ApiStream } from "./transform/stream"
+import { UnboundHandler } from "./providers/unbound"
 
 export interface SingleCompletionHandler {
 	completePrompt(prompt: string): Promise<string>
@@ -53,6 +54,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new VsCodeLmHandler(options)
 		case "mistral":
 			return new MistralHandler(options)
+		case "unbound":
+			return new UnboundHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}
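
To see where the new case lands end to end, here is a minimal, hypothetical dispatch sketch. It assumes ApiConfiguration discriminates on an apiProvider field (implied by the switch above, though that line is not in this hunk) and that the remaining fields, unboundApiKey and apiModelId as used in the tests below, flow through as handler options; the import path and environment variable are placeholders:

import { buildApiHandler } from "../api" // import path assumed

async function demo(): Promise<void> {
	const handler = buildApiHandler({
		apiProvider: "unbound", // routes to the new case in the switch above
		unboundApiKey: process.env.UNBOUND_API_KEY ?? "", // hypothetical env var
		apiModelId: "anthropic/claude-3-5-sonnet-20241022",
	})

	// UnboundHandler also implements SingleCompletionHandler (see unbound.ts
	// below), so a one-shot completion looks like this; the cast is only to
	// keep the sketch compact.
	const text = await (handler as any).completePrompt("Say hello.")
	console.log(text)
}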
src/api/providers/__tests__/unbound.test.ts

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import { UnboundHandler } from "../unbound"
import { ApiHandlerOptions } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
const mockCreate = jest.fn()
const mockWithResponse = jest.fn()

jest.mock("openai", () => {
	return {
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: (...args: any[]) => {
						const stream = {
							[Symbol.asyncIterator]: async function* () {
								yield {
									choices: [
										{
											delta: { content: "Test response" },
											index: 0,
										},
									],
								}
								yield {
									choices: [
										{
											delta: {},
											index: 0,
										},
									],
								}
							},
						}

						const result = mockCreate(...args)
						if (args[0].stream) {
							mockWithResponse.mockReturnValue(
								Promise.resolve({
									data: stream,
									response: { headers: new Map() },
								}),
							)
							result.withResponse = mockWithResponse
						}
						return result
					},
				},
			},
		})),
	}
})

describe("UnboundHandler", () => {
	let handler: UnboundHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
		mockOptions = {
			apiModelId: "anthropic/claude-3-5-sonnet-20241022",
			unboundApiKey: "test-api-key",
		}
		handler = new UnboundHandler(mockOptions)
		mockCreate.mockClear()
		mockWithResponse.mockClear()

		// Default mock implementation for non-streaming responses
		mockCreate.mockResolvedValue({
			id: "test-completion",
			choices: [
				{
					message: { role: "assistant", content: "Test response" },
					finish_reason: "stop",
					index: 0,
				},
			],
		})
	})

	describe("constructor", () => {
		it("should initialize with provided options", () => {
			expect(handler).toBeInstanceOf(UnboundHandler)
			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [
			{
				role: "user",
				content: "Hello!",
			},
		]

		it("should handle streaming responses", async () => {
			const stream = handler.createMessage(systemPrompt, messages)
			const chunks: any[] = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			expect(chunks.length).toBe(1)
			expect(chunks[0]).toEqual({
				type: "text",
				text: "Test response",
			})

			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "claude-3-5-sonnet-20241022",
					messages: expect.any(Array),
					stream: true,
				}),
				expect.objectContaining({
					headers: {
						"X-Unbound-Metadata": expect.stringContaining("roo-code"),
					},
				}),
			)
		})

		it("should handle API errors", async () => {
			mockCreate.mockImplementationOnce(() => {
				throw new Error("API Error")
			})

			const stream = handler.createMessage(systemPrompt, messages)
			const chunks = []

			try {
				for await (const chunk of stream) {
					chunks.push(chunk)
				}
				fail("Expected error to be thrown")
			} catch (error) {
				expect(error).toBeInstanceOf(Error)
				expect(error.message).toBe("API Error")
			}
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully", async () => {
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "claude-3-5-sonnet-20241022",
					messages: [{ role: "user", content: "Test prompt" }],
					temperature: 0,
					max_tokens: 8192,
				}),
			)
		})

		it("should handle API errors", async () => {
			mockCreate.mockRejectedValueOnce(new Error("API Error"))
			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Unbound completion error: API Error")
		})

		it("should handle empty response", async () => {
			mockCreate.mockResolvedValueOnce({
				choices: [{ message: { content: "" } }],
			})
			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("")
		})

		it("should not set max_tokens for non-Anthropic models", async () => {
			mockCreate.mockClear()

			const nonAnthropicOptions = {
				apiModelId: "openai/gpt-4o",
				unboundApiKey: "test-key",
			}
			const nonAnthropicHandler = new UnboundHandler(nonAnthropicOptions)

			await nonAnthropicHandler.completePrompt("Test prompt")
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "gpt-4o",
					messages: [{ role: "user", content: "Test prompt" }],
					temperature: 0,
				}),
			)
			expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens")
		})
	})

	describe("getModel", () => {
		it("should return model info", () => {
			const modelInfo = handler.getModel()
			expect(modelInfo.id).toBe(mockOptions.apiModelId)
			expect(modelInfo.info).toBeDefined()
		})

		it("should return default model when invalid model provided", () => {
			const handlerWithInvalidModel = new UnboundHandler({
				...mockOptions,
				apiModelId: "invalid/model",
			})
			const modelInfo = handlerWithInvalidModel.getModel()
			expect(modelInfo.id).toBe("openai/gpt-4o") // Default model
			expect(modelInfo.info).toBeDefined()
		})
	})
})
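
Taken together, these tests pin down the handler's chunk contract: a streaming call yields text chunks and, when the backend reports usage, usage chunks. The following is a hedged consumer sketch built only on the shapes asserted above ({ type: "text", text } and { type: "usage", inputTokens, outputTokens }); the import path mirrors the test imports and the API key is a placeholder:

import { UnboundHandler } from "../unbound" // path assumed, as in the test imports

async function printCompletion(): Promise<void> {
	const handler = new UnboundHandler({
		apiModelId: "anthropic/claude-3-5-sonnet-20241022",
		unboundApiKey: "test-api-key", // placeholder key
	})

	let inputTokens = 0
	let outputTokens = 0

	// createMessage is an async generator, so the stream is consumed
	// with for-await; chunk types follow the assertions in the tests above.
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hello!" },
	])) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			inputTokens += chunk.inputTokens
			outputTokens += chunk.outputTokens
		}
	}

	console.log(`\n[tokens in: ${inputTokens}, out: ${outputTokens}]`)
}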

src/api/providers/unbound.ts

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandler, SingleCompletionHandler } from "../"
import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"

export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
	private options: ApiHandlerOptions
	private client: OpenAI

	constructor(options: ApiHandlerOptions) {
		this.options = options
		this.client = new OpenAI({
			baseURL: "https://api.getunbound.ai/v1",
			apiKey: this.options.unboundApiKey,
		})
	}

	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		// Convert Anthropic messages to OpenAI format
		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
			{ role: "system", content: systemPrompt },
			...convertToOpenAiMessages(messages),
		]

		// This is specifically for Claude models (some models may support prompt caching automatically without this).
		if (this.getModel().id.startsWith("anthropic/claude-3")) {
			openAiMessages[0] = {
				role: "system",
				content: [
					{
						type: "text",
						text: systemPrompt,
						// @ts-ignore-next-line
						cache_control: { type: "ephemeral" },
					},
				],
			}

			// Add cache_control to the last two user messages.
			// (Note: this works because we only ever add one user message at a time,
			// but if we added multiple we'd need to mark the user message before the last assistant message.)
			const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2)
			lastTwoUserMessages.forEach((msg) => {
				if (typeof msg.content === "string") {
					msg.content = [{ type: "text", text: msg.content }]
				}
				if (Array.isArray(msg.content)) {
					// NOTE: this is fine since env details will always be added at the end,
					// but if they weren't there and the user added an image_url-type message,
					// it would pop a text part before it and then move it to the end.
					let lastTextPart = msg.content.filter((part) => part.type === "text").pop()

					if (!lastTextPart) {
						lastTextPart = { type: "text", text: "..." }
						msg.content.push(lastTextPart)
					}
					// @ts-ignore-next-line
					lastTextPart["cache_control"] = { type: "ephemeral" }
				}
			})
		}

		// Required by Anthropic; other providers default to the maximum tokens allowed.
		let maxTokens: number | undefined

		if (this.getModel().id.startsWith("anthropic/")) {
			maxTokens = 8_192
		}

		const { data: completion, response } = await this.client.chat.completions
			.create(
				{
					model: this.getModel().id.split("/")[1],
					max_tokens: maxTokens,
					temperature: 0,
					messages: openAiMessages,
					stream: true,
				},
				{
					headers: {
						"X-Unbound-Metadata": JSON.stringify({
							labels: [
								{
									key: "app",
									value: "roo-code",
								},
							],
						}),
					},
				},
			)
			.withResponse()

		for await (const chunk of completion) {
			const delta = chunk.choices[0]?.delta
			const usage = chunk.usage

			if (delta?.content) {
				yield {
					type: "text",
					text: delta.content,
				}
			}

			if (usage) {
				yield {
					type: "usage",
					inputTokens: usage?.prompt_tokens || 0,
					outputTokens: usage?.completion_tokens || 0,
				}
			}
		}
	}

	getModel(): { id: UnboundModelId; info: ModelInfo } {
		const modelId = this.options.apiModelId
		if (modelId && modelId in unboundModels) {
			const id = modelId as UnboundModelId
			return { id, info: unboundModels[id] }
		}
		return {
			id: unboundDefaultModelId,
			info: unboundModels[unboundDefaultModelId],
		}
	}

	async completePrompt(prompt: string): Promise<string> {
		try {
			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
				model: this.getModel().id.split("/")[1],
				messages: [{ role: "user", content: prompt }],
				temperature: 0,
			}

			if (this.getModel().id.startsWith("anthropic/")) {
				requestOptions.max_tokens = 8192
			}

			const response = await this.client.chat.completions.create(requestOptions)
			return response.choices[0]?.message.content || ""
		} catch (error) {
			if (error instanceof Error) {
				throw new Error(`Unbound completion error: ${error.message}`)
			}
			throw error
		}
	}
}
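
One detail worth spelling out is the prompt-caching rewrite in createMessage: for anthropic/claude-3* models, a plain string user message is first converted into a parts array so a cache_control marker can sit on its last text part. Below is a sketch of the before/after shape, mirroring the code above rather than output captured from the provider; cache_control is not part of OpenAI's published chat types, which is why the implementation needs @ts-ignore:

// Before the rewrite (as produced by convertToOpenAiMessages):
const before = { role: "user" as const, content: "Hello!" }

// After the rewrite for an anthropic/claude-3* model:
const after = {
	role: "user" as const,
	content: [
		{
			type: "text" as const,
			text: "Hello!",
			// Non-standard field forwarded to the provider; hence the
			// @ts-ignore in the real implementation.
			cache_control: { type: "ephemeral" as const },
		},
	],
}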
