Skip to content

Commit fd2b43f

Browse files
arafatkatze authored and daniel-lxs committed
Adding nebius to roocode
1 parent 61a381a commit fd2b43f

File tree

20 files changed

+440
-5
lines changed

20 files changed

+440
-5
lines changed

evals/apps/web/src/app/runs/new/new-run.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ export function NewRun() {
176176
ollamaModelId,
177177
lmStudioModelId,
178178
openAiModelId,
179+
nebiusModelId,
179180
} = providerSettings
180181

181182
switch (apiProvider) {
@@ -210,6 +211,9 @@ export function NewRun() {
210211
case "lmstudio":
211212
setValue("model", lmStudioModelId ?? "")
212213
break
214+
case "nebius":
215+
setValue("model", nebiusModelId ?? "")
216+
break
213217
default:
214218
throw new Error(`Unsupported API provider: ${apiProvider}`)
215219
}

evals/packages/types/src/roo-code.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,12 @@ const litellmSchema = z.object({
478478
litellmModelId: z.string().optional(),
479479
})
480480

481+
const nebiusSchema = z.object({
482+
nebiusBaseUrl: z.string().optional(),
483+
nebiusApiKey: z.string().optional(),
484+
nebiusModelId: z.string().optional(),
485+
})
486+
481487
const defaultSchema = z.object({
482488
apiProvider: z.undefined(),
483489
})
@@ -589,6 +595,11 @@ export const providerSettingsSchemaDiscriminated = z
589595
apiProvider: z.literal("litellm"),
590596
}),
591597
),
598+
nebiusSchema.merge(
599+
z.object({
600+
apiProvider: z.literal("nebius"),
601+
}),
602+
),
592603
defaultSchema,
593604
])
594605
.and(genericProviderSettingsSchema)
@@ -616,6 +627,7 @@ export const providerSettingsSchema = z.object({
616627
...groqSchema.shape,
617628
...chutesSchema.shape,
618629
...litellmSchema.shape,
630+
...nebiusSchema.shape,
619631
...genericProviderSettingsSchema.shape,
620632
})
621633

@@ -716,6 +728,9 @@ const providerSettingsRecord: ProviderSettingsRecord = {
716728
litellmBaseUrl: undefined,
717729
litellmApiKey: undefined,
718730
litellmModelId: undefined,
731+
nebiusBaseUrl: undefined,
732+
nebiusApiKey: undefined,
733+
nebiusModelId: undefined,
719734
}
720735

721736
export const PROVIDER_SETTINGS_KEYS = Object.keys(providerSettingsRecord) as Keys<ProviderSettings>[]

src/api/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import { XAIHandler } from "./providers/xai"
2424
import { GroqHandler } from "./providers/groq"
2525
import { ChutesHandler } from "./providers/chutes"
2626
import { LiteLLMHandler } from "./providers/litellm"
27+
import { NebiusHandler } from "./providers/nebius"
2728

2829
export interface SingleCompletionHandler {
2930
completePrompt(prompt: string): Promise<string>
@@ -104,6 +105,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
104105
return new ChutesHandler(options)
105106
case "litellm":
106107
return new LiteLLMHandler(options)
108+
case "nebius":
109+
return new NebiusHandler(options)
107110
default:
108111
return new AnthropicHandler(options)
109112
}

src/api/providers/fetchers/modelCache.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@ import NodeCache from "node-cache"
55

66
import { ContextProxy } from "../../../core/config/ContextProxy"
77
import { getCacheDirectoryPath } from "../../../utils/storage"
8-
import { RouterName, ModelRecord } from "../../../shared/api"
8+
import { RouterName, ModelRecord, GetModelsOptions } from "../../../shared/api"
99
import { fileExistsAtPath } from "../../../utils/fs"
1010

1111
import { getOpenRouterModels } from "./openrouter"
1212
import { getRequestyModels } from "./requesty"
1313
import { getGlamaModels } from "./glama"
1414
import { getUnboundModels } from "./unbound"
1515
import { getLiteLLMModels } from "./litellm"
16-
import { GetModelsOptions } from "../../../shared/api"
16+
import { getNebiusModels } from "./nebius"
17+
1718
const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
1819

1920
async function writeModels(router: RouterName, data: ModelRecord) {
@@ -68,6 +69,10 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
6869
// Type safety ensures apiKey and baseUrl are always provided for litellm
6970
models = await getLiteLLMModels(options.apiKey, options.baseUrl)
7071
break
72+
case "nebius":
73+
// Type safety ensures apiKey and baseUrl are always provided for nebius
74+
models = await getNebiusModels(options.apiKey, options.baseUrl)
75+
break
7176
default: {
7277
// Ensures router is exhaustively checked if RouterName is a strict union
7378
const exhaustiveCheck: never = provider
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import axios from "axios"
2+
import { OPEN_ROUTER_COMPUTER_USE_MODELS, ModelRecord } from "../../../shared/api"
3+
4+
/**
5+
* Fetches available models from a Nebius server
6+
*
7+
* @param apiKey The API key for the Nebius server
8+
* @param baseUrl The base URL of the Nebius server
9+
* @returns A promise that resolves to a record of model IDs to model info
10+
*/
11+
export async function getNebiusModels(apiKey: string, baseUrl: string): Promise<ModelRecord> {
12+
try {
13+
const headers: Record<string, string> = {
14+
"Content-Type": "application/json",
15+
}
16+
17+
if (apiKey) {
18+
headers["Authorization"] = `Bearer ${apiKey}`
19+
}
20+
21+
const response = await axios.get(`${baseUrl}/v1/model/info`, { headers })
22+
const models: ModelRecord = {}
23+
24+
const computerModels = Array.from(OPEN_ROUTER_COMPUTER_USE_MODELS)
25+
26+
// Process the model info from the response
27+
if (response.data && response.data.data && Array.isArray(response.data.data)) {
28+
for (const model of response.data.data) {
29+
const modelName = model.model_name
30+
const modelInfo = model.model_info
31+
const nebiusModelName = model?.nebius_params?.model as string | undefined
32+
33+
if (!modelName || !modelInfo || !nebiusModelName) continue
34+
35+
models[modelName] = {
36+
maxTokens: modelInfo.max_tokens || 8192,
37+
contextWindow: modelInfo.max_input_tokens || 200000,
38+
supportsImages: Boolean(modelInfo.supports_vision),
39+
// nebius_params.model may have a prefix like openrouter/
40+
supportsComputerUse: computerModels.some((computer_model) =>
41+
nebiusModelName.endsWith(computer_model),
42+
),
43+
supportsPromptCache: Boolean(modelInfo.supports_prompt_caching),
44+
inputPrice: modelInfo.input_cost_per_token ? modelInfo.input_cost_per_token * 1000000 : undefined,
45+
outputPrice: modelInfo.output_cost_per_token
46+
? modelInfo.output_cost_per_token * 1000000
47+
: undefined,
48+
description: `${modelName} via Nebius proxy`,
49+
}
50+
}
51+
}
52+
53+
return models
54+
} catch (error) {
55+
console.error("Error fetching Nebius models:", error)
56+
return {}
57+
}
58+
}

src/api/providers/nebius.ts

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import { Anthropic } from "@anthropic-ai/sdk"
2+
import OpenAI from "openai"
3+
import { convertToOpenAiMessages } from "../transform/openai-format"
4+
import { ApiStream } from "../transform/stream"
5+
import { convertToR1Format } from "../transform/r1-format"
6+
// Removed unused imports: ApiHandler, nebiusModels, ModelInfo, NebiusModelId
7+
8+
import { SingleCompletionHandler } from "../index"
9+
import { RouterProvider } from "./router-provider"
10+
11+
import { ApiHandlerOptions, nebiusDefaultModelId, nebiusDefaultModelInfo } from "../../shared/api"
12+
13+
export class NebiusHandler extends RouterProvider implements SingleCompletionHandler {
14+
constructor(options: ApiHandlerOptions) {
15+
super({
16+
options,
17+
name: "nebius",
18+
baseURL: "https://api.studio.nebius.ai/v1",
19+
apiKey: options.nebiusApiKey || "dummy-key",
20+
modelId: options.nebiusModelId,
21+
defaultModelId: nebiusDefaultModelId,
22+
defaultModelInfo: nebiusDefaultModelInfo,
23+
})
24+
}
25+
26+
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
27+
const model = this.getModel()
28+
29+
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = model.id.includes("DeepSeek-R1")
30+
? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
31+
: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)]
32+
33+
const stream = await this.client.chat.completions.create({
34+
model: model.id,
35+
messages: openAiMessages,
36+
temperature: 0,
37+
stream: true,
38+
stream_options: { include_usage: true },
39+
})
40+
for await (const chunk of stream) {
41+
const delta = chunk.choices[0]?.delta
42+
if (delta?.content) {
43+
yield {
44+
type: "text",
45+
text: delta.content,
46+
}
47+
}
48+
49+
if (chunk.usage) {
50+
yield {
51+
type: "usage",
52+
inputTokens: chunk.usage.prompt_tokens || 0,
53+
outputTokens: chunk.usage.completion_tokens || 0,
54+
}
55+
}
56+
}
57+
}
58+
59+
async completePrompt(prompt: string): Promise<string> {
60+
const { id: modelId, info } = await this.fetchModel()
61+
62+
try {
63+
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
64+
model: modelId,
65+
messages: [{ role: "user", content: prompt }],
66+
}
67+
68+
if (this.supportsTemperature(modelId)) {
69+
requestOptions.temperature = this.options.modelTemperature ?? 0
70+
}
71+
72+
requestOptions.max_tokens = info.maxTokens
73+
74+
const response = await this.client.chat.completions.create(requestOptions)
75+
return response.choices[0]?.message.content || ""
76+
} catch (error) {
77+
if (error instanceof Error) {
78+
throw new Error(`nebius completion error: ${error.message}`)
79+
}
80+
throw error
81+
}
82+
}
83+
}

src/core/webview/__tests__/ClineProvider.test.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2253,6 +2253,7 @@ describe("ClineProvider - Router Models", () => {
22532253
glama: mockModels,
22542254
unbound: mockModels,
22552255
litellm: mockModels,
2256+
nebius: {},
22562257
},
22572258
})
22582259
})
@@ -2294,6 +2295,7 @@ describe("ClineProvider - Router Models", () => {
22942295
glama: mockModels,
22952296
unbound: {},
22962297
litellm: {},
2298+
nebius: {},
22972299
},
22982300
})
22992301

@@ -2391,6 +2393,7 @@ describe("ClineProvider - Router Models", () => {
23912393
glama: mockModels,
23922394
unbound: mockModels,
23932395
litellm: {},
2396+
nebius: {},
23942397
},
23952398
})
23962399
})

src/core/webview/__tests__/webviewMessageHandler.test.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
7070
glama: mockModels,
7171
unbound: mockModels,
7272
litellm: mockModels,
73+
nebius: {},
7374
},
7475
})
7576
})
@@ -155,6 +156,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
155156
glama: mockModels,
156157
unbound: mockModels,
157158
litellm: {},
159+
nebius: {},
158160
},
159161
})
160162
})
@@ -190,6 +192,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
190192
glama: mockModels,
191193
unbound: {},
192194
litellm: {},
195+
nebius: {},
193196
},
194197
})
195198

src/core/webview/webviewMessageHandler.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We
295295
glama: {},
296296
unbound: {},
297297
litellm: {},
298+
nebius: {},
298299
}
299300

300301
const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
@@ -325,6 +326,15 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We
325326
})
326327
}
327328

329+
const nebiusApiKey = apiConfiguration.nebiusApiKey || message?.values?.nebiusApiKey
330+
const nebiusBaseUrl = apiConfiguration.nebiusBaseUrl || message?.values?.nebiusBaseUrl
331+
if (nebiusApiKey && nebiusBaseUrl) {
332+
modelFetchPromises.push({
333+
key: "nebius",
334+
options: { provider: "nebius", apiKey: nebiusApiKey, baseUrl: nebiusBaseUrl },
335+
})
336+
}
337+
328338
const results = await Promise.allSettled(
329339
modelFetchPromises.map(async ({ key, options }) => {
330340
const models = await safeGetModels(options)

0 commit comments

Comments
 (0)