Commit 8dc4ba8

Adding nebius to roocode
1 parent b8aa4b4 commit 8dc4ba8

18 files changed (+433, -10 lines)

evals/apps/web/src/app/runs/new/new-run.tsx

Lines changed: 4 additions & 0 deletions
@@ -173,6 +173,7 @@ export function NewRun() {
 		ollamaModelId,
 		lmStudioModelId,
 		openAiModelId,
+		nebiusModelId,
 	} = providerSettings
 
 	switch (apiProvider) {
@@ -207,6 +208,9 @@ export function NewRun() {
 		case "lmstudio":
 			setValue("model", lmStudioModelId ?? "")
 			break
+		case "nebius":
+			setValue("model", nebiusModelId ?? "")
+			break
 		default:
 			throw new Error(`Unsupported API provider: ${apiProvider}`)
 	}

evals/packages/types/src/roo-code.ts

Lines changed: 15 additions & 0 deletions
@@ -476,6 +476,12 @@ const litellmSchema = z.object({
 	litellmModelId: z.string().optional(),
 })
 
+const nebiusSchema = z.object({
+	nebiusBaseUrl: z.string().optional(),
+	nebiusApiKey: z.string().optional(),
+	nebiusModelId: z.string().optional(),
+})
+
 const defaultSchema = z.object({
 	apiProvider: z.undefined(),
 })
@@ -587,6 +593,11 @@ export const providerSettingsSchemaDiscriminated = z
 			apiProvider: z.literal("litellm"),
 		}),
 	),
+	nebiusSchema.merge(
+		z.object({
+			apiProvider: z.literal("nebius"),
+		}),
+	),
 	defaultSchema,
 ])
 	.and(genericProviderSettingsSchema)
@@ -614,6 +625,7 @@ export const providerSettingsSchema = z.object({
 	...groqSchema.shape,
 	...chutesSchema.shape,
 	...litellmSchema.shape,
+	...nebiusSchema.shape,
 	...genericProviderSettingsSchema.shape,
})
@@ -714,6 +726,9 @@ const providerSettingsRecord: ProviderSettingsRecord = {
 	litellmBaseUrl: undefined,
 	litellmApiKey: undefined,
 	litellmModelId: undefined,
+	nebiusBaseUrl: undefined,
+	nebiusApiKey: undefined,
+	nebiusModelId: undefined,
 }
 
 export const PROVIDER_SETTINGS_KEYS = Object.keys(providerSettingsRecord) as Keys<ProviderSettings>[]
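
Since the discriminated union keys on apiProvider, a settings object with apiProvider set to "nebius" is validated against nebiusSchema merged with the generic provider settings. A minimal sketch of what that accepts, assuming the discriminated schema is imported from this module; all values are placeholders:

import { providerSettingsSchemaDiscriminated } from "./roo-code"

const result = providerSettingsSchemaDiscriminated.safeParse({
	apiProvider: "nebius",
	nebiusBaseUrl: "https://nebius.example.com", // placeholder
	nebiusApiKey: "example-key", // placeholder
	nebiusModelId: "deepseek-ai/DeepSeek-R1", // illustrative id
})

// The three nebius fields are optional strings, so this should parse as long
// as genericProviderSettingsSchema adds no further required fields.
console.log(result.success)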

src/api/index.ts

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@ import { XAIHandler } from "./providers/xai"
 import { GroqHandler } from "./providers/groq"
 import { ChutesHandler } from "./providers/chutes"
 import { LiteLLMHandler } from "./providers/litellm"
+import { NebiusHandler } from "./providers/nebius"
 
 export interface SingleCompletionHandler {
 	completePrompt(prompt: string): Promise<string>
@@ -97,6 +98,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new ChutesHandler(options)
 		case "litellm":
 			return new LiteLLMHandler(options)
+		case "nebius":
+			return new NebiusHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}
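
Provider selection stays a single switch on apiProvider. A hedged usage sketch of the new branch; option values are placeholders and the import path depends on the caller's location:

import { buildApiHandler } from "../api"

// apiProvider: "nebius" now routes to NebiusHandler; unrecognized providers
// still fall through to the AnthropicHandler default branch.
const handler = buildApiHandler({
	apiProvider: "nebius",
	nebiusApiKey: "example-key", // placeholder
	nebiusModelId: "deepseek-ai/DeepSeek-R1", // illustrative
})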

src/api/providers/fetchers/modelCache.ts

Lines changed: 8 additions & 1 deletion
@@ -13,7 +13,7 @@ import { getRequestyModels } from "./requesty"
 import { getGlamaModels } from "./glama"
 import { getUnboundModels } from "./unbound"
 import { getLiteLLMModels } from "./litellm"
-
+import { getNebiusModels } from "./nebius"
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 
 async function writeModels(router: RouterName, data: ModelRecord) {
@@ -75,6 +75,13 @@ export const getModels = async (
 				models = {}
 			}
 			break
+		case "nebius":
+			if (apiKey && baseUrl) {
+				models = await getNebiusModels(apiKey, baseUrl)
+			} else {
+				models = {}
+			}
+			break
 	}
 
 	if (Object.keys(models).length > 0) {
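
The nebius case only fetches when both the key and the base URL are present; with either missing it resolves to an empty record rather than throwing. A small usage sketch (credentials are placeholders):

import { getModels } from "./modelCache"

async function loadNebiusModels() {
	// Returns {} if either argument were missing, per the guard above.
	return getModels("nebius", "example-key", "https://nebius.example.com")
}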

src/api/providers/fetchers/nebius.ts

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+import axios from "axios"
+import { COMPUTER_USE_MODELS, ModelRecord } from "../../../shared/api"
+
+/**
+ * Fetches available models from a Nebius server
+ *
+ * @param apiKey The API key for the Nebius server
+ * @param baseUrl The base URL of the Nebius server
+ * @returns A promise that resolves to a record of model IDs to model info
+ */
+export async function getNebiusModels(apiKey: string, baseUrl: string): Promise<ModelRecord> {
+	try {
+		const headers: Record<string, string> = {
+			"Content-Type": "application/json",
+		}
+
+		if (apiKey) {
+			headers["Authorization"] = `Bearer ${apiKey}`
+		}
+
+		const response = await axios.get(`${baseUrl}/v1/model/info`, { headers })
+		const models: ModelRecord = {}
+
+		const computerModels = Array.from(COMPUTER_USE_MODELS)
+
+		// Process the model info from the response
+		if (response.data && response.data.data && Array.isArray(response.data.data)) {
+			for (const model of response.data.data) {
+				const modelName = model.model_name
+				const modelInfo = model.model_info
+				const nebiusModelName = model?.nebius_params?.model as string | undefined
+
+				if (!modelName || !modelInfo || !nebiusModelName) continue
+
+				models[modelName] = {
+					maxTokens: modelInfo.max_tokens || 8192,
+					contextWindow: modelInfo.max_input_tokens || 200000,
+					supportsImages: Boolean(modelInfo.supports_vision),
+					// nebius_params.model may have a prefix like openrouter/
+					supportsComputerUse: computerModels.some((computer_model) =>
+						nebiusModelName.endsWith(computer_model),
+					),
+					supportsPromptCache: Boolean(modelInfo.supports_prompt_caching),
+					inputPrice: modelInfo.input_cost_per_token ? modelInfo.input_cost_per_token * 1000000 : undefined,
+					outputPrice: modelInfo.output_cost_per_token
+						? modelInfo.output_cost_per_token * 1000000
+						: undefined,
+					description: `${modelName} via Nebius proxy`,
+				}
+			}
+		}
+
+		return models
+	} catch (error) {
+		console.error("Error fetching Nebius models:", error)
+		return {}
+	}
+}
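
Note the price conversion: per-token costs are multiplied by 1,000,000 to yield per-million-token prices, so an input_cost_per_token of 0.0000005 surfaces as an inputPrice of 0.5. A sketch of the /v1/model/info payload shape the parser expects; every field value here is invented for illustration:

// Hypothetical response body for getNebiusModels; entries missing model_name,
// model_info, or nebius_params.model are skipped.
const samplePayload = {
	data: [
		{
			model_name: "deepseek-r1",
			nebius_params: { model: "deepseek-ai/DeepSeek-R1" },
			model_info: {
				max_tokens: 8192,
				max_input_tokens: 128000,
				supports_vision: false,
				supports_prompt_caching: false,
				input_cost_per_token: 0.0000005, // becomes inputPrice 0.5
				output_cost_per_token: 0.0000015, // becomes outputPrice 1.5
			},
		},
	],
}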

src/api/providers/nebius.ts

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
+import { convertToR1Format } from "../transform/r1-format"
+// Removed unused imports: ApiHandler, nebiusModels, ModelInfo, NebiusModelId
+
+import { SingleCompletionHandler } from "../index"
+import { RouterProvider } from "./router-provider"
+
+import { ApiHandlerOptions, nebiusDefaultModelId, nebiusDefaultModelInfo } from "../../shared/api"
+
+export class NebiusHandler extends RouterProvider implements SingleCompletionHandler {
+	constructor(options: ApiHandlerOptions) {
+		super({
+			options,
+			name: "nebius",
+			baseURL: "https://api.studio.nebius.ai/v1",
+			apiKey: options.nebiusApiKey || "dummy-key",
+			modelId: options.nebiusModelId,
+			defaultModelId: nebiusDefaultModelId,
+			defaultModelInfo: nebiusDefaultModelInfo,
+		})
+	}
+
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const model = this.getModel()
+
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = model.id.includes("DeepSeek-R1")
+			? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
+			: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)]
+
+		const stream = await this.client.chat.completions.create({
+			model: model.id,
+			messages: openAiMessages,
+			temperature: 0,
+			stream: true,
+			stream_options: { include_usage: true },
+		})
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
+				}
+			}
+		}
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		const { id: modelId, info } = await this.fetchModel()
+
+		try {
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: modelId,
+				messages: [{ role: "user", content: prompt }],
+			}
+
+			if (this.supportsTemperature(modelId)) {
+				requestOptions.temperature = this.options.modelTemperature ?? 0
+			}
+
+			requestOptions.max_tokens = info.maxTokens
+
+			const response = await this.client.chat.completions.create(requestOptions)
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`nebius completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
+}
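
createMessage is an async generator that interleaves text chunks with a usage chunk whenever the stream reports token counts. A hedged consumption sketch; the constructor options are placeholders, and the R1 message format is triggered here because the model id contains "DeepSeek-R1":

import { NebiusHandler } from "./nebius"

async function demo() {
	const handler = new NebiusHandler({
		nebiusApiKey: "example-key", // placeholder
		nebiusModelId: "deepseek-ai/DeepSeek-R1", // illustrative
	})

	for await (const chunk of handler.createMessage("You are terse.", [
		{ role: "user", content: "Hello" },
	])) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log("tokens:", chunk.inputTokens, chunk.outputTokens)
	}
}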

src/core/webview/webviewMessageHandler.ts

Lines changed: 10 additions & 7 deletions
@@ -288,13 +288,15 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We
 		case "requestRouterModels":
 			const { apiConfiguration } = await provider.getState()
 
-			const [openRouterModels, requestyModels, glamaModels, unboundModels, litellmModels] = await Promise.all([
-				getModels("openrouter", apiConfiguration.openRouterApiKey),
-				getModels("requesty", apiConfiguration.requestyApiKey),
-				getModels("glama", apiConfiguration.glamaApiKey),
-				getModels("unbound", apiConfiguration.unboundApiKey),
-				getModels("litellm", apiConfiguration.litellmApiKey, apiConfiguration.litellmBaseUrl),
-			])
+			const [openRouterModels, requestyModels, glamaModels, unboundModels, litellmModels, nebiusModels] =
+				await Promise.all([
+					getModels("openrouter", apiConfiguration.openRouterApiKey),
+					getModels("requesty", apiConfiguration.requestyApiKey),
+					getModels("glama", apiConfiguration.glamaApiKey),
+					getModels("unbound", apiConfiguration.unboundApiKey),
+					getModels("litellm", apiConfiguration.litellmApiKey, apiConfiguration.litellmBaseUrl),
+					getModels("nebius", apiConfiguration.nebiusApiKey, apiConfiguration.nebiusBaseUrl),
+				])
 
 			provider.postMessageToWebview({
 				type: "routerModels",
@@ -304,6 +306,7 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We
 					glama: glamaModels,
 					unbound: unboundModels,
 					litellm: litellmModels,
+					nebius: nebiusModels,
 				},
 			})
 			break

src/exports/roo-code.d.ts

Lines changed: 18 additions & 0 deletions
@@ -30,6 +30,7 @@ type GlobalSettings = {
 			| "groq"
 			| "chutes"
 			| "litellm"
+			| "nebius"
 		)
 		| undefined
 	}[]
@@ -198,6 +199,7 @@ type ProviderName =
 	| "groq"
 	| "chutes"
 	| "litellm"
+	| "nebius"
 
 type ProviderSettings = {
 	apiProvider?:
@@ -223,6 +225,7 @@ type ProviderSettings = {
 		| "groq"
 		| "chutes"
 		| "litellm"
+		| "nebius"
 	)
 	| undefined
 	includeMaxTokens?: boolean | undefined
@@ -334,6 +337,9 @@ type ProviderSettings = {
 	litellmBaseUrl?: string | undefined
 	litellmApiKey?: string | undefined
 	litellmModelId?: string | undefined
+	nebiusBaseUrl?: string | undefined
+	nebiusApiKey?: string | undefined
+	nebiusModelId?: string | undefined
 }
 
 type ProviderSettingsEntry = {
@@ -362,6 +368,7 @@ type ProviderSettingsEntry = {
 	| "groq"
 	| "chutes"
 	| "litellm"
+	| "nebius"
 	)
 	| undefined
 }
@@ -626,6 +633,7 @@ type IpcMessage =
 	| "groq"
 	| "chutes"
 	| "litellm"
+	| "nebius"
 	)
 	| undefined
 	includeMaxTokens?: boolean | undefined
@@ -737,6 +745,9 @@ type IpcMessage =
 	litellmBaseUrl?: string | undefined
 	litellmApiKey?: string | undefined
 	litellmModelId?: string | undefined
+	nebiusBaseUrl?: string | undefined
+	nebiusApiKey?: string | undefined
+	nebiusModelId?: string | undefined
 	currentApiConfigName?: string | undefined
 	listApiConfigMeta?:
 		| {
@@ -765,6 +776,7 @@ type IpcMessage =
 			| "groq"
 			| "chutes"
 			| "litellm"
+			| "nebius"
 		)
 		| undefined
 	}[]
@@ -1101,6 +1113,7 @@ type TaskCommand =
 	| "groq"
 	| "chutes"
 	| "litellm"
+	| "nebius"
 	)
 	| undefined
 	includeMaxTokens?: boolean | undefined
@@ -1212,6 +1225,9 @@ type TaskCommand =
 	litellmBaseUrl?: string | undefined
 	litellmApiKey?: string | undefined
 	litellmModelId?: string | undefined
+	nebiusBaseUrl?: string | undefined
+	nebiusApiKey?: string | undefined
+	nebiusModelId?: string | undefined
 	currentApiConfigName?: string | undefined
 	listApiConfigMeta?:
 		| {
@@ -1240,6 +1256,7 @@ type TaskCommand =
 	| "groq"
 	| "chutes"
 	| "litellm"
+	| "nebius"
 	)
 	| undefined
 }[]
@@ -1574,6 +1591,7 @@ declare const providerNames: readonly [
 	"groq",
 	"chutes",
 	"litellm",
+	"nebius",
 ]
 /**
  * RooCodeEvent
