Commit c20c8a8

snova-jorgeputarn authored and committed

feat: add SambaNova provider integration (RooCodeInc#6188)

1 parent aca2432 commit c20c8a8

File tree

34 files changed: +443 -113 lines changed


.github/ISSUE_TEMPLATE/bug_report.yml

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ body:
         - OpenAI Compatible
         - OpenRouter
         - Requesty
+        - SambaNova
         - Unbound
         - VS Code Language Model API
         - xAI (Grok)

packages/types/src/global-settings.ts

Lines changed: 7 additions & 35 deletions
@@ -1,6 +1,6 @@
 import { z } from "zod"

-import { type Keys, keysOf } from "./type-fu.js"
+import { type Keys } from "./type-fu.js"
 import {
 	type ProviderSettings,
 	PROVIDER_SETTINGS_KEYS,
@@ -161,63 +161,35 @@ export type RooCodeSettings = GlobalSettings & ProviderSettings
 /**
  * SecretState
  */
-
-export type SecretState = Pick<
-	ProviderSettings,
-	| "apiKey"
-	| "glamaApiKey"
-	| "openRouterApiKey"
-	| "awsAccessKey"
-	| "awsSecretKey"
-	| "awsSessionToken"
-	| "openAiApiKey"
-	| "geminiApiKey"
-	| "openAiNativeApiKey"
-	| "deepSeekApiKey"
-	| "mistralApiKey"
-	| "unboundApiKey"
-	| "requestyApiKey"
-	| "xaiApiKey"
-	| "groqApiKey"
-	| "chutesApiKey"
-	| "litellmApiKey"
-	| "modelharborApiKey"
-	| "codeIndexOpenAiKey"
-	| "codeIndexQdrantApiKey"
-	| "codebaseIndexOpenAiCompatibleApiKey"
-	| "codebaseIndexGeminiApiKey"
-	| "codeIndexModelHarborApiKey"
-	| "codebaseIndexMistralApiKey"
-	| "huggingFaceApiKey"
->
-
-export const SECRET_STATE_KEYS = keysOf<SecretState>()([
+export const SECRET_STATE_KEYS = [
 	"apiKey",
 	"glamaApiKey",
 	"openRouterApiKey",
 	"awsAccessKey",
+	"awsApiKey",
 	"awsSecretKey",
 	"awsSessionToken",
 	"openAiApiKey",
 	"geminiApiKey",
 	"openAiNativeApiKey",
 	"deepSeekApiKey",
+	"moonshotApiKey",
 	"mistralApiKey",
 	"unboundApiKey",
 	"requestyApiKey",
 	"xaiApiKey",
 	"groqApiKey",
 	"chutesApiKey",
 	"litellmApiKey",
-	"modelharborApiKey",
 	"codeIndexOpenAiKey",
 	"codeIndexQdrantApiKey",
 	"codebaseIndexOpenAiCompatibleApiKey",
 	"codebaseIndexGeminiApiKey",
-	"codeIndexModelHarborApiKey",
 	"codebaseIndexMistralApiKey",
 	"huggingFaceApiKey",
-])
+	"sambaNovaApiKey",
+] as const satisfies readonly (keyof ProviderSettings)[]
+export type SecretState = Pick<ProviderSettings, (typeof SECRET_STATE_KEYS)[number]>

 export const isSecretStateKey = (key: string): key is Keys<SecretState> =>
 	SECRET_STATE_KEYS.includes(key as Keys<SecretState>)
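
This refactor inverts the relationship between the type and the key list: previously SecretState was written by hand and the runtime array was derived from it with keysOf; now the array is the single source of truth and the type is derived from the array. A minimal standalone sketch of the pattern (the ProviderSettings shape here is a hypothetical stand-in, not the real interface):

interface ProviderSettings {
	apiKey?: string
	sambaNovaApiKey?: string
	apiModelId?: string // not a secret, so it must not appear in the key list
}

// `as const` preserves the literal key names for the derived type below, while
// `satisfies` still verifies that every entry is a real ProviderSettings key.
const SECRET_STATE_KEYS = ["apiKey", "sambaNovaApiKey"] as const satisfies readonly (keyof ProviderSettings)[]

// Derived from the array, so a new secret such as "sambaNovaApiKey" is added in one place.
type SecretState = Pick<ProviderSettings, (typeof SECRET_STATE_KEYS)[number]>

The `satisfies` operator requires TypeScript 4.9 or later.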

packages/types/src/provider-settings.ts

Lines changed: 8 additions & 0 deletions
@@ -35,6 +35,7 @@ export const providerNames = [
 	"litellm",
 	"huggingface",
 	"modelharbor",
+	"sambanova",
 ] as const

 export const providerNamesSchema = z.enum(providerNames)
@@ -248,6 +249,11 @@ const modelharborSchema = baseProviderSettingsSchema.extend({
 	modelharborModelId: z.string().optional(),
 })

+const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
+	sambaNovaApiKey: z.string().optional(),
+})
+
 const defaultSchema = z.object({
 	apiProvider: z.undefined(),
 })
@@ -279,6 +285,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
 	litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
 	modelharborSchema.merge(z.object({ apiProvider: z.literal("modelharbor") })),
+	sambaNovaSchema.merge(z.object({ apiProvider: z.literal("sambanova") })),
 	defaultSchema,
 ])

@@ -310,6 +317,7 @@ export const providerSettingsSchema = z.object({
 	...chutesSchema.shape,
 	...litellmSchema.shape,
 	...modelharborSchema.shape,
+	...sambaNovaSchema.shape,
 	...codebaseIndexProviderSchema.shape,
 })
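
For context, the discriminated union works because every branch is tagged with a literal apiProvider value, so zod validates an input against only the branch whose tag matches. A minimal sketch with hypothetical stand-in schemas (not the full set from this file):

import { z } from "zod"

// Stand-in for the shared base schema used by sambaNovaSchema above.
const apiModelIdProviderModelSchema = z.object({
	apiModelId: z.string().optional(),
})

const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
	sambaNovaApiKey: z.string().optional(),
})

// z.discriminatedUnion dispatches on the literal `apiProvider` field.
const settingsSchema = z.discriminatedUnion("apiProvider", [
	sambaNovaSchema.merge(z.object({ apiProvider: z.literal("sambanova") })),
	z.object({ apiProvider: z.literal("anthropic"), apiKey: z.string().optional() }),
])

// Validates against the SambaNova branch only:
settingsSchema.parse({ apiProvider: "sambanova", sambaNovaApiKey: "sk-example" })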

packages/types/src/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ export * from "./ollama.js"
 export * from "./openai.js"
 export * from "./openrouter.js"
 export * from "./requesty.js"
+export * from "./sambanova.js"
 export * from "./unbound.js"
 export * from "./vertex.js"
 export * from "./vscode-llm.js"
packages/types/src/providers/sambanova.ts

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+import type { ModelInfo } from "../model.js"
+
+// https://docs.sambanova.ai/cloud/docs/get-started/supported-models
+export type SambaNovaModelId =
+	| "Meta-Llama-3.1-8B-Instruct"
+	| "Meta-Llama-3.3-70B-Instruct"
+	| "DeepSeek-R1"
+	| "DeepSeek-V3-0324"
+	| "DeepSeek-R1-Distill-Llama-70B"
+	| "Llama-4-Maverick-17B-128E-Instruct"
+	| "Llama-3.3-Swallow-70B-Instruct-v0.4"
+	| "Qwen3-32B"
+
+export const sambaNovaDefaultModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
+
+export const sambaNovaModels = {
+	"Meta-Llama-3.1-8B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 16384,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.1,
+		outputPrice: 0.2,
+		description: "Meta Llama 3.1 8B Instruct model with 16K context window.",
+	},
+	"Meta-Llama-3.3-70B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.6,
+		outputPrice: 1.2,
+		description: "Meta Llama 3.3 70B Instruct model with 128K context window.",
+	},
+	"DeepSeek-R1": {
+		maxTokens: 8192,
+		contextWindow: 32768,
+		supportsImages: false,
+		supportsPromptCache: false,
+		supportsReasoningBudget: true,
+		inputPrice: 5.0,
+		outputPrice: 7.0,
+		description: "DeepSeek R1 reasoning model with 32K context window.",
+	},
+	"DeepSeek-V3-0324": {
+		maxTokens: 8192,
+		contextWindow: 32768,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 3.0,
+		outputPrice: 4.5,
+		description: "DeepSeek V3 model with 32K context window.",
+	},
+	"DeepSeek-R1-Distill-Llama-70B": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.7,
+		outputPrice: 1.4,
+		description: "DeepSeek R1 distilled Llama 70B model with 128K context window.",
+	},
+	"Llama-4-Maverick-17B-128E-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.63,
+		outputPrice: 1.8,
+		description: "Meta Llama 4 Maverick 17B 128E Instruct model with 128K context window.",
+	},
+	"Llama-3.3-Swallow-70B-Instruct-v0.4": {
+		maxTokens: 8192,
+		contextWindow: 16384,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.6,
+		outputPrice: 1.2,
+		description: "Tokyotech Llama 3.3 Swallow 70B Instruct v0.4 model with 16K context window.",
+	},
+	"Qwen3-32B": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.4,
+		outputPrice: 0.8,
+		description: "Alibaba Qwen 3 32B model with 8K context window.",
+	},
+} as const satisfies Record<string, ModelInfo>
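
Since the table is keyed by model id, resolving the effective model is a plain lookup with a fallback to the default id. The handler's getModel is not shown in this excerpt, but the tests later in the commit assert exactly this fallback behavior; a hypothetical sketch:

import { sambaNovaModels, sambaNovaDefaultModelId, type SambaNovaModelId } from "@roo-code/types"

function resolveSambaNovaModel(apiModelId?: string) {
	// Fall back to the default id when the configured id is missing or unknown.
	const id: SambaNovaModelId =
		apiModelId && apiModelId in sambaNovaModels ? (apiModelId as SambaNovaModelId) : sambaNovaDefaultModelId
	return { id, info: sambaNovaModels[id] }
}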

src/api/index.ts

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@ import {
 	ChutesHandler,
 	LiteLLMHandler,
 	ClaudeCodeHandler,
-	ModelHarborHandler,
+	SambaNovaHandler,
 } from "./providers"

 export interface SingleCompletionHandler {
@@ -116,8 +116,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new ChutesHandler(options)
 		case "litellm":
 			return new LiteLLMHandler(options)
-		case "modelharbor":
-			return new ModelHarborHandler(options)
+		case "sambanova":
+			return new SambaNovaHandler(options)
 		default:
 			apiProvider satisfies "gemini-cli" | undefined
 			return new AnthropicHandler(options)
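
The SambaNovaHandler implementation itself (src/api/providers/sambanova.ts) is not included in this excerpt. Judging from what the test file below asserts, it wraps the OpenAI SDK pointed at SambaNova's OpenAI-compatible endpoint; a hypothetical sketch of its non-streaming path:

import OpenAI from "openai"

class SambaNovaHandlerSketch {
	private client: OpenAI

	constructor(options: { sambaNovaApiKey?: string; apiModelId?: string }) {
		// The tests assert this exact base URL and that the configured key is passed through.
		this.client = new OpenAI({
			baseURL: "https://api.sambanova.ai/v1",
			apiKey: options.sambaNovaApiKey ?? "not-provided",
		})
	}

	async completePrompt(prompt: string): Promise<string> {
		try {
			const response = await this.client.chat.completions.create({
				model: "Meta-Llama-3.3-70B-Instruct", // in practice the configured or default model id
				messages: [{ role: "user", content: prompt }],
			})
			return response.choices[0]?.message.content || ""
		} catch (error) {
			// The error-handling test expects this exact message prefix.
			throw new Error(`SambaNova completion error: ${error instanceof Error ? error.message : String(error)}`)
		}
	}
}
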
src/api/providers/__tests__/sambanova.spec.ts

Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+// npx vitest run src/api/providers/__tests__/sambanova.spec.ts
+
+// Mock vscode first to avoid import errors
+vitest.mock("vscode", () => ({}))
+
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { type SambaNovaModelId, sambaNovaDefaultModelId, sambaNovaModels } from "@roo-code/types"
+
+import { SambaNovaHandler } from "../sambanova"
+
+vitest.mock("openai", () => {
+	const createMock = vitest.fn()
+	return {
+		default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })),
+	}
+})
+
+describe("SambaNovaHandler", () => {
+	let handler: SambaNovaHandler
+	let mockCreate: any
+
+	beforeEach(() => {
+		vitest.clearAllMocks()
+		mockCreate = (OpenAI as unknown as any)().chat.completions.create
+		handler = new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" })
+	})
+
+	it("should use the correct SambaNova base URL", () => {
+		new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" })
+		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.sambanova.ai/v1" }))
+	})
+
+	it("should use the provided API key", () => {
+		const sambaNovaApiKey = "test-sambanova-api-key"
+		new SambaNovaHandler({ sambaNovaApiKey })
+		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: sambaNovaApiKey }))
+	})
+
+	it("should return default model when no model is specified", () => {
+		const model = handler.getModel()
+		expect(model.id).toBe(sambaNovaDefaultModelId)
+		expect(model.info).toEqual(sambaNovaModels[sambaNovaDefaultModelId])
+	})
+
+	it("should return specified model when valid model is provided", () => {
+		const testModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
+		const handlerWithModel = new SambaNovaHandler({
+			apiModelId: testModelId,
+			sambaNovaApiKey: "test-sambanova-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.id).toBe(testModelId)
+		expect(model.info).toEqual(sambaNovaModels[testModelId])
+	})
+
+	it("completePrompt method should return text from SambaNova API", async () => {
+		const expectedResponse = "This is a test response from SambaNova"
+		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
+		const result = await handler.completePrompt("test prompt")
+		expect(result).toBe(expectedResponse)
+	})
+
+	it("should handle errors in completePrompt", async () => {
+		const errorMessage = "SambaNova API error"
+		mockCreate.mockRejectedValueOnce(new Error(errorMessage))
+		await expect(handler.completePrompt("test prompt")).rejects.toThrow(
+			`SambaNova completion error: ${errorMessage}`,
+		)
+	})
+
+	it("createMessage should yield text content from stream", async () => {
+		const testContent = "This is test content from SambaNova stream"
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					next: vitest
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: { choices: [{ delta: { content: testContent } }] },
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			}
+		})
+
+		const stream = handler.createMessage("system prompt", [])
+		const firstChunk = await stream.next()
+
+		expect(firstChunk.done).toBe(false)
+		expect(firstChunk.value).toEqual({ type: "text", text: testContent })
+	})
+
+	it("createMessage should yield usage data from stream", async () => {
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					next: vitest
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } },
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			}
+		})
+
+		const stream = handler.createMessage("system prompt", [])
+		const firstChunk = await stream.next()
+
+		expect(firstChunk.done).toBe(false)
+		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+	})
+
+	it("createMessage should pass correct parameters to SambaNova client", async () => {
+		const modelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
+		const modelInfo = sambaNovaModels[modelId]
+		const handlerWithModel = new SambaNovaHandler({
+			apiModelId: modelId,
+			sambaNovaApiKey: "test-sambanova-api-key",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt for SambaNova"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for SambaNova" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				max_tokens: modelInfo.maxTokens,
+				temperature: 0.7,
+				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				stream: true,
+				stream_options: { include_usage: true },
+			}),
+		)
+	})
+})
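
The two streaming tests pin down the chunk mapping: text deltas become { type: "text" } events, and the usage payload (present because the request sets stream_options: { include_usage: true }) becomes a { type: "usage" } event. A hypothetical sketch of that mapping loop, not the handler's actual code:

import OpenAI from "openai"

type ApiStreamChunk =
	| { type: "text"; text: string }
	| { type: "usage"; inputTokens: number; outputTokens: number }

// Maps OpenAI-style streaming chunks to the events the tests above expect.
async function* mapSambaNovaStream(
	stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>,
): AsyncGenerator<ApiStreamChunk> {
	for await (const chunk of stream) {
		const delta = chunk.choices[0]?.delta
		if (delta?.content) {
			yield { type: "text", text: delta.content }
		}
		if (chunk.usage) {
			yield { type: "usage", inputTokens: chunk.usage.prompt_tokens, outputTokens: chunk.usage.completion_tokens }
		}
	}
}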
