
Commit bd001ab

committed
feat: add SambaNova provider integration
- Add SambaNova types with supported models
- Create SambaNova handler extending BaseOpenAiCompatibleProvider
- Add SambaNova UI component for settings
- Update provider configurations and exports
- Add translations for all supported locales
- Add comprehensive tests for SambaNova provider

Closes #6077
1 parent 5629199 commit bd001ab

File tree

29 files changed: +432 -0 lines changed


packages/types/src/provider-settings.ts

Lines changed: 7 additions & 0 deletions
@@ -32,6 +32,7 @@ export const providerNames = [
 	"groq",
 	"chutes",
 	"litellm",
+	"sambanova",
 ] as const
 
 export const providerNamesSchema = z.enum(providerNames)
@@ -229,6 +230,10 @@ const litellmSchema = baseProviderSettingsSchema.extend({
 	litellmModelId: z.string().optional(),
 })
 
+const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
+	sambaNovaApiKey: z.string().optional(),
+})
+
 const defaultSchema = z.object({
 	apiProvider: z.undefined(),
 })
@@ -258,6 +263,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	groqSchema.merge(z.object({ apiProvider: z.literal("groq") })),
 	chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
 	litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
+	sambaNovaSchema.merge(z.object({ apiProvider: z.literal("sambanova") })),
 	defaultSchema,
 ])
 
@@ -287,6 +293,7 @@ export const providerSettingsSchema = z.object({
 	...groqSchema.shape,
 	...chutesSchema.shape,
 	...litellmSchema.shape,
+	...sambaNovaSchema.shape,
 	...codebaseIndexProviderSchema.shape,
 })
 
packages/types/src/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ export * from "./ollama.js"
 export * from "./openai.js"
 export * from "./openrouter.js"
 export * from "./requesty.js"
+export * from "./sambanova.js"
 export * from "./unbound.js"
 export * from "./vertex.js"
 export * from "./vscode-llm.js"
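With the barrel export in place, consumers can pull the new model table straight from the package, as the test file later in this commit does. A small sketch:

import { sambaNovaModels, type SambaNovaModelId } from "@roo-code/types"

const ids = Object.keys(sambaNovaModels) as SambaNovaModelId[] // 13 supported models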
packages/types/src/providers/sambanova.ts

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+import type { ModelInfo } from "../model.js"
+
+// https://docs.sambanova.ai/cloud/docs/get-started/supported-models
+export type SambaNovaModelId =
+	| "Meta-Llama-3.1-8B-Instruct"
+	| "Meta-Llama-3.1-70B-Instruct"
+	| "Meta-Llama-3.1-405B-Instruct"
+	| "Meta-Llama-3.2-1B-Instruct"
+	| "Meta-Llama-3.2-3B-Instruct"
+	| "Meta-Llama-3.3-70B-Instruct"
+	| "Llama-3.2-11B-Vision-Instruct"
+	| "Llama-3.2-90B-Vision-Instruct"
+	| "QwQ-32B-Preview"
+	| "Qwen2.5-72B-Instruct"
+	| "Qwen2.5-Coder-32B-Instruct"
+	| "deepseek-r1"
+	| "deepseek-r1-distill-llama-70b"
+
+export const sambaNovaDefaultModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
+
+export const sambaNovaModels = {
+	"Meta-Llama-3.1-8B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.1,
+		outputPrice: 0.2,
+		description: "Meta Llama 3.1 8B Instruct model with 128K context window.",
+	},
+	"Meta-Llama-3.1-70B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.64,
+		outputPrice: 0.8,
+		description: "Meta Llama 3.1 70B Instruct model with 128K context window.",
+	},
+	"Meta-Llama-3.1-405B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 3.0,
+		outputPrice: 15.0,
+		description: "Meta Llama 3.1 405B Instruct model with 128K context window.",
+	},
+	"Meta-Llama-3.2-1B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.04,
+		outputPrice: 0.04,
+		description: "Meta Llama 3.2 1B Instruct model with 128K context window.",
+	},
+	"Meta-Llama-3.2-3B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.06,
+		outputPrice: 0.06,
+		description: "Meta Llama 3.2 3B Instruct model with 128K context window.",
+	},
+	"Meta-Llama-3.3-70B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.64,
+		outputPrice: 0.8,
+		description: "Meta Llama 3.3 70B Instruct model with 128K context window.",
+	},
+	"Llama-3.2-11B-Vision-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.18,
+		outputPrice: 0.2,
+		description: "Meta Llama 3.2 11B Vision Instruct model with image support.",
+	},
+	"Llama-3.2-90B-Vision-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.9,
+		outputPrice: 1.1,
+		description: "Meta Llama 3.2 90B Vision Instruct model with image support.",
+	},
+	"QwQ-32B-Preview": {
+		maxTokens: 32768,
+		contextWindow: 32768,
+		supportsImages: false,
+		supportsPromptCache: false,
+		supportsReasoningBudget: true,
+		inputPrice: 0.15,
+		outputPrice: 0.15,
+		description: "Alibaba QwQ 32B Preview reasoning model.",
+	},
+	"Qwen2.5-72B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.59,
+		outputPrice: 0.79,
+		description: "Alibaba Qwen 2.5 72B Instruct model with 128K context window.",
+	},
+	"Qwen2.5-Coder-32B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.29,
+		outputPrice: 0.39,
+		description: "Alibaba Qwen 2.5 Coder 32B Instruct model optimized for coding tasks.",
+	},
+	"deepseek-r1": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		supportsReasoningBudget: true,
+		inputPrice: 0.55,
+		outputPrice: 2.19,
+		description: "DeepSeek R1 reasoning model with 128K context window.",
+	},
+	"deepseek-r1-distill-llama-70b": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.27,
+		outputPrice: 1.08,
+		description: "DeepSeek R1 distilled Llama 70B model with 128K context window.",
+	},
+} as const satisfies Record<string, ModelInfo>
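Since every entry carries per-token pricing, a rough cost estimate falls out directly. A hypothetical helper (not in the commit), assuming inputPrice and outputPrice are USD per million tokens, consistent with the other provider tables in this package:

import { sambaNovaModels, type SambaNovaModelId } from "@roo-code/types"

// Estimate the USD cost of one request against a SambaNova model.
function estimateCostUsd(modelId: SambaNovaModelId, inputTokens: number, outputTokens: number): number {
	const info = sambaNovaModels[modelId]
	return (inputTokens * info.inputPrice + outputTokens * info.outputPrice) / 1_000_000
}

// 10K prompt tokens + 2K completion tokens on the default model:
estimateCostUsd("Meta-Llama-3.3-70B-Instruct", 10_000, 2_000) // => 0.008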

src/api/index.ts

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,7 @@ import {
 	ChutesHandler,
 	LiteLLMHandler,
 	ClaudeCodeHandler,
+	SambaNovaHandler,
 } from "./providers"
 
 export interface SingleCompletionHandler {
@@ -112,6 +113,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new ChutesHandler(options)
 		case "litellm":
 			return new LiteLLMHandler(options)
+		case "sambanova":
+			return new SambaNovaHandler(options)
 		default:
 			apiProvider satisfies "gemini-cli" | undefined
 			return new AnthropicHandler(options)
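End to end, the new case means a "sambanova" configuration now resolves to its own handler instead of falling through to the Anthropic default. A sketch (field names taken from this commit's schema; the surrounding ProviderSettings type may carry more fields, and the import path depends on the caller's location):

import { buildApiHandler } from "./api"

const handler = buildApiHandler({
	apiProvider: "sambanova",
	apiModelId: "Meta-Llama-3.3-70B-Instruct",
	sambaNovaApiKey: process.env.SAMBANOVA_API_KEY,
})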
src/api/providers/__tests__/sambanova.spec.ts

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { type SambaNovaModelId, sambaNovaModels } from "@roo-code/types"
+
+import { SambaNovaHandler } from "../sambanova"
+
+// Mock OpenAI
+vi.mock("openai", () => {
+	const mockCreate = vi.fn()
+	return {
+		default: vi.fn(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		})),
+	}
+})
+
+describe("SambaNovaHandler", () => {
+	let handler: SambaNovaHandler
+	let mockCreate: any
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockCreate = (OpenAI as unknown as any)().chat.completions.create
+		handler = new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" })
+	})
+
+	it("should use the correct SambaNova base URL", () => {
+		new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" })
+		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.sambanova.ai/v1" }))
+	})
+
+	it("should use the provided API key", () => {
+		const sambaNovaApiKey = "test-sambanova-api-key"
+		new SambaNovaHandler({ sambaNovaApiKey })
+		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: sambaNovaApiKey }))
+	})
+
+	it("should throw an error if API key is not provided", () => {
+		expect(() => new SambaNovaHandler({} as any)).toThrow("API key is required")
+	})
+
+	it("should use the specified model when provided", () => {
+		const testModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
+		const handlerWithModel = new SambaNovaHandler({
+			apiModelId: testModelId,
+			sambaNovaApiKey: "test-sambanova-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.id).toBe(testModelId)
+		expect(model.info).toEqual(sambaNovaModels[testModelId])
+	})
+
+	it("should use the default model when no model is specified", () => {
+		const model = handler.getModel()
+		expect(model.id).toBe("Meta-Llama-3.3-70B-Instruct")
+		expect(model.info).toEqual(sambaNovaModels["Meta-Llama-3.3-70B-Instruct"])
+	})
+
+	describe("createMessage", () => {
+		it("should create a streaming chat completion with correct parameters", async () => {
+			const systemPrompt = "You are a helpful assistant"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello",
+				},
+			]
+
+			mockCreate.mockImplementation(() => {
+				const chunks = [
+					{
+						choices: [{ delta: { content: "Hi there!" } }],
+					},
+					{
+						choices: [{ delta: {} }],
+						usage: { prompt_tokens: 10, completion_tokens: 5 },
+					},
+				]
+
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						for (const chunk of chunks) {
+							yield chunk
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const results = []
+			for await (const chunk of stream) {
+				results.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "Meta-Llama-3.3-70B-Instruct",
+					max_tokens: 8192,
+					temperature: 0.7,
+					messages: [
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "Hello" },
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+				}),
+			)
+
+			expect(results).toEqual([
+				{ type: "text", text: "Hi there!" },
+				{ type: "usage", inputTokens: 10, outputTokens: 5 },
+			])
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should complete a prompt successfully", async () => {
+			const prompt = "Test prompt"
+			const expectedResponse = "Test response"
+
+			mockCreate.mockResolvedValue({
+				choices: [{ message: { content: expectedResponse } }],
+			})
+
+			const result = await handler.completePrompt(prompt)
+
+			expect(mockCreate).toHaveBeenCalledWith({
+				model: "Meta-Llama-3.3-70B-Instruct",
+				messages: [{ role: "user", content: prompt }],
+			})
+			expect(result).toBe(expectedResponse)
+		})
+
+		it("should handle errors properly", async () => {
+			const prompt = "Test prompt"
+			const errorMessage = "API Error"
+
+			mockCreate.mockRejectedValue(new Error(errorMessage))
+
+			await expect(handler.completePrompt(prompt)).rejects.toThrow(`SambaNova completion error: ${errorMessage}`)
+		})
+	})
+
+	describe("model selection", () => {
+		it.each(Object.keys(sambaNovaModels) as SambaNovaModelId[])("should correctly handle model %s", (modelId) => {
+			const modelInfo = sambaNovaModels[modelId]
+			const handlerWithModel = new SambaNovaHandler({
+				apiModelId: modelId,
+				sambaNovaApiKey: "test-sambanova-api-key",
+			})
+
+			const model = handlerWithModel.getModel()
+			expect(model.id).toBe(modelId)
+			expect(model.info).toEqual(modelInfo)
+		})
+	})
+})
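The chunk contract the suite pins down ("text" chunks followed by a "usage" chunk) is what callers consume in production. A sketch, assuming the handler is constructed as in the tests above:

const stream = handler.createMessage("You are a helpful assistant", [{ role: "user", content: "Hello" }])

for await (const chunk of stream) {
	// Text deltas arrive first; the final chunk reports token usage.
	if (chunk.type === "text") process.stdout.write(chunk.text)
	else if (chunk.type === "usage") console.log(`\n${chunk.inputTokens} in / ${chunk.outputTokens} out`)
}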

src/api/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ export { OpenAiNativeHandler } from "./openai-native"
 export { OpenAiHandler } from "./openai"
 export { OpenRouterHandler } from "./openrouter"
 export { RequestyHandler } from "./requesty"
+export { SambaNovaHandler } from "./sambanova"
 export { UnboundHandler } from "./unbound"
 export { VertexHandler } from "./vertex"
 export { VsCodeLmHandler } from "./vscode-lm"

src/api/providers/sambanova.ts

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import { type SambaNovaModelId, sambaNovaDefaultModelId, sambaNovaModels } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+
+import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+
+export class SambaNovaHandler extends BaseOpenAiCompatibleProvider<SambaNovaModelId> {
+	constructor(options: ApiHandlerOptions) {
+		super({
+			...options,
+			providerName: "SambaNova",
+			baseURL: "https://api.sambanova.ai/v1",
+			apiKey: options.sambaNovaApiKey,
+			defaultProviderModelId: sambaNovaDefaultModelId,
+			providerModels: sambaNovaModels,
+			defaultTemperature: 0.7,
+		})
+	}
+}
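Because all transport logic lives in BaseOpenAiCompatibleProvider, the subclass is pure configuration. A sketch of direct instantiation (completePrompt is the non-streaming path the tests above exercise; using it at top level assumes an async context):

import { SambaNovaHandler } from "./sambanova"

// Falls back to sambaNovaDefaultModelId when no apiModelId is given.
const handler = new SambaNovaHandler({ sambaNovaApiKey: process.env.SAMBANOVA_API_KEY! })

const text = await handler.completePrompt("Say hello") // Meta-Llama-3.3-70B-Instruct by default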
