From fd2b43f7ae1bb2411e7327025b200f6b6605b97d Mon Sep 17 00:00:00 2001 From: arafatkatze Date: Thu, 22 May 2025 05:04:42 +0400 Subject: [PATCH 1/4] Adding nebius to roocode --- evals/apps/web/src/app/runs/new/new-run.tsx | 4 + evals/packages/types/src/roo-code.ts | 15 +++ src/api/index.ts | 3 + src/api/providers/fetchers/modelCache.ts | 9 +- src/api/providers/fetchers/nebius.ts | 58 ++++++++++ src/api/providers/nebius.ts | 83 ++++++++++++++ .../webview/__tests__/ClineProvider.test.ts | 3 + .../__tests__/webviewMessageHandler.test.ts | 3 + src/core/webview/webviewMessageHandler.ts | 10 ++ src/exports/roo-code.d.ts | 18 ++++ src/exports/types.ts | 17 +++ src/schemas/index.ts | 17 ++- src/shared/api.ts | 102 +++++++++++++++++- .../src/components/settings/ApiOptions.tsx | 14 +++ .../src/components/settings/ModelPicker.tsx | 8 +- .../src/components/settings/constants.ts | 1 + .../components/settings/providers/Nebius.tsx | 65 +++++++++++ .../components/settings/providers/index.ts | 1 + .../components/ui/hooks/useSelectedModel.ts | 6 ++ webview-ui/src/utils/validate.ts | 8 ++ 20 files changed, 440 insertions(+), 5 deletions(-) create mode 100644 src/api/providers/fetchers/nebius.ts create mode 100644 src/api/providers/nebius.ts create mode 100644 webview-ui/src/components/settings/providers/Nebius.tsx diff --git a/evals/apps/web/src/app/runs/new/new-run.tsx b/evals/apps/web/src/app/runs/new/new-run.tsx index 47fe8a89c4..a35b8f3cf5 100644 --- a/evals/apps/web/src/app/runs/new/new-run.tsx +++ b/evals/apps/web/src/app/runs/new/new-run.tsx @@ -176,6 +176,7 @@ export function NewRun() { ollamaModelId, lmStudioModelId, openAiModelId, + nebiusModelId, } = providerSettings switch (apiProvider) { @@ -210,6 +211,9 @@ export function NewRun() { case "lmstudio": setValue("model", lmStudioModelId ?? "") break + case "nebius": + setValue("model", nebiusModelId ?? 
"") + break default: throw new Error(`Unsupported API provider: ${apiProvider}`) } diff --git a/evals/packages/types/src/roo-code.ts b/evals/packages/types/src/roo-code.ts index b397d37b64..3a9f988594 100644 --- a/evals/packages/types/src/roo-code.ts +++ b/evals/packages/types/src/roo-code.ts @@ -478,6 +478,12 @@ const litellmSchema = z.object({ litellmModelId: z.string().optional(), }) +const nebiusSchema = z.object({ + nebiusBaseUrl: z.string().optional(), + nebiusApiKey: z.string().optional(), + nebiusModelId: z.string().optional(), +}) + const defaultSchema = z.object({ apiProvider: z.undefined(), }) @@ -589,6 +595,11 @@ export const providerSettingsSchemaDiscriminated = z apiProvider: z.literal("litellm"), }), ), + nebiusSchema.merge( + z.object({ + apiProvider: z.literal("nebius"), + }), + ), defaultSchema, ]) .and(genericProviderSettingsSchema) @@ -616,6 +627,7 @@ export const providerSettingsSchema = z.object({ ...groqSchema.shape, ...chutesSchema.shape, ...litellmSchema.shape, + ...nebiusSchema.shape, ...genericProviderSettingsSchema.shape, }) @@ -716,6 +728,9 @@ const providerSettingsRecord: ProviderSettingsRecord = { litellmBaseUrl: undefined, litellmApiKey: undefined, litellmModelId: undefined, + nebiusBaseUrl: undefined, + nebiusApiKey: undefined, + nebiusModelId: undefined, } export const PROVIDER_SETTINGS_KEYS = Object.keys(providerSettingsRecord) as Keys[] diff --git a/src/api/index.ts b/src/api/index.ts index f831e58e8d..bc1e7b3b93 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -24,6 +24,7 @@ import { XAIHandler } from "./providers/xai" import { GroqHandler } from "./providers/groq" import { ChutesHandler } from "./providers/chutes" import { LiteLLMHandler } from "./providers/litellm" +import { NebiusHandler } from "./providers/nebius" export interface SingleCompletionHandler { completePrompt(prompt: string): Promise @@ -104,6 +105,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler { return new ChutesHandler(options) case "litellm": return new LiteLLMHandler(options) + case "nebius": + return new NebiusHandler(options) default: return new AnthropicHandler(options) } diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 12d636bc46..84db198130 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -5,7 +5,7 @@ import NodeCache from "node-cache" import { ContextProxy } from "../../../core/config/ContextProxy" import { getCacheDirectoryPath } from "../../../utils/storage" -import { RouterName, ModelRecord } from "../../../shared/api" +import { RouterName, ModelRecord, GetModelsOptions } from "../../../shared/api" import { fileExistsAtPath } from "../../../utils/fs" import { getOpenRouterModels } from "./openrouter" @@ -13,7 +13,8 @@ import { getRequestyModels } from "./requesty" import { getGlamaModels } from "./glama" import { getUnboundModels } from "./unbound" import { getLiteLLMModels } from "./litellm" -import { GetModelsOptions } from "../../../shared/api" +import { getNebiusModels } from "./nebius" + const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) async function writeModels(router: RouterName, data: ModelRecord) { @@ -68,6 +69,10 @@ export const getModels = async (options: GetModelsOptions): Promise // Type safety ensures apiKey and baseUrl are always provided for litellm models = await getLiteLLMModels(options.apiKey, options.baseUrl) break + case "nebius": + // Type safety ensures apiKey and baseUrl are 
always provided for nebius
+			models = await getNebiusModels(options.apiKey, options.baseUrl)
+			break
 		default: {
 			// Ensures router is exhaustively checked if RouterName is a strict union
 			const exhaustiveCheck: never = provider
diff --git a/src/api/providers/fetchers/nebius.ts b/src/api/providers/fetchers/nebius.ts
new file mode 100644
index 0000000000..f68cd5d75e
--- /dev/null
+++ b/src/api/providers/fetchers/nebius.ts
@@ -0,0 +1,58 @@
+import axios from "axios"
+import { OPEN_ROUTER_COMPUTER_USE_MODELS, ModelRecord } from "../../../shared/api"
+
+/**
+ * Fetches available models from a Nebius server
+ *
+ * @param apiKey The API key for the Nebius server
+ * @param baseUrl The base URL of the Nebius server
+ * @returns A promise that resolves to a record of model IDs to model info
+ */
+export async function getNebiusModels(apiKey: string, baseUrl: string): Promise<ModelRecord> {
+	try {
+		const headers: Record<string, string> = {
+			"Content-Type": "application/json",
+		}
+
+		if (apiKey) {
+			headers["Authorization"] = `Bearer ${apiKey}`
+		}
+
+		const response = await axios.get(`${baseUrl}/v1/model/info`, { headers })
+		const models: ModelRecord = {}
+
+		const computerModels = Array.from(OPEN_ROUTER_COMPUTER_USE_MODELS)
+
+		// Process the model info from the response
+		if (response.data && response.data.data && Array.isArray(response.data.data)) {
+			for (const model of response.data.data) {
+				const modelName = model.model_name
+				const modelInfo = model.model_info
+				const nebiusModelName = model?.nebius_params?.model as string | undefined
+
+				if (!modelName || !modelInfo || !nebiusModelName) continue
+
+				models[modelName] = {
+					maxTokens: modelInfo.max_tokens || 8192,
+					contextWindow: modelInfo.max_input_tokens || 200000,
+					supportsImages: Boolean(modelInfo.supports_vision),
+					// nebius_params.model may have a prefix like openrouter/
+					supportsComputerUse: computerModels.some((computer_model) =>
+						nebiusModelName.endsWith(computer_model),
+					),
+					supportsPromptCache: Boolean(modelInfo.supports_prompt_caching),
+					inputPrice: modelInfo.input_cost_per_token ? modelInfo.input_cost_per_token * 1000000 : undefined,
+					outputPrice: modelInfo.output_cost_per_token
+						? modelInfo.output_cost_per_token * 1000000
+						: undefined,
+					description: `${modelName} via Nebius proxy`,
+				}
+			}
+		}
+
+		return models
+	} catch (error) {
+		console.error("Error fetching Nebius models:", error)
+		return {}
+	}
+}
diff --git a/src/api/providers/nebius.ts b/src/api/providers/nebius.ts
new file mode 100644
index 0000000000..0b31973c1d
--- /dev/null
+++ b/src/api/providers/nebius.ts
@@ -0,0 +1,83 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
+import { convertToR1Format } from "../transform/r1-format"
+// Removed unused imports: ApiHandler, nebiusModels, ModelInfo, NebiusModelId
+
+import { SingleCompletionHandler } from "../index"
+import { RouterProvider } from "./router-provider"
+
+import { ApiHandlerOptions, nebiusDefaultModelId, nebiusDefaultModelInfo } from "../../shared/api"
+
+export class NebiusHandler extends RouterProvider implements SingleCompletionHandler {
+	constructor(options: ApiHandlerOptions) {
+		super({
+			options,
+			name: "nebius",
+			baseURL: "https://api.studio.nebius.ai/v1",
+			apiKey: options.nebiusApiKey || "dummy-key",
+			modelId: options.nebiusModelId,
+			defaultModelId: nebiusDefaultModelId,
+			defaultModelInfo: nebiusDefaultModelInfo,
+		})
+	}
+
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const model = this.getModel()
+
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = model.id.includes("DeepSeek-R1")
+			? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
+			: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)]
+
+		const stream = await this.client.chat.completions.create({
+			model: model.id,
+			messages: openAiMessages,
+			temperature: 0,
+			stream: true,
+			stream_options: { include_usage: true },
+		})
+
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
+				}
+			}
+		}
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		const { id: modelId, info } = await this.fetchModel()
+
+		try {
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: modelId,
+				messages: [{ role: "user", content: prompt }],
+			}
+
+			if (this.supportsTemperature(modelId)) {
+				requestOptions.temperature = this.options.modelTemperature ??
0 + } + + requestOptions.max_tokens = info.maxTokens + + const response = await this.client.chat.completions.create(requestOptions) + return response.choices[0]?.message.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`nebius completion error: ${error.message}`) + } + throw error + } + } +} diff --git a/src/core/webview/__tests__/ClineProvider.test.ts b/src/core/webview/__tests__/ClineProvider.test.ts index 72a40e7044..3acbe101bf 100644 --- a/src/core/webview/__tests__/ClineProvider.test.ts +++ b/src/core/webview/__tests__/ClineProvider.test.ts @@ -2253,6 +2253,7 @@ describe("ClineProvider - Router Models", () => { glama: mockModels, unbound: mockModels, litellm: mockModels, + nebius: {}, }, }) }) @@ -2294,6 +2295,7 @@ describe("ClineProvider - Router Models", () => { glama: mockModels, unbound: {}, litellm: {}, + nebius: {}, }, }) @@ -2391,6 +2393,7 @@ describe("ClineProvider - Router Models", () => { glama: mockModels, unbound: mockModels, litellm: {}, + nebius: {}, }, }) }) diff --git a/src/core/webview/__tests__/webviewMessageHandler.test.ts b/src/core/webview/__tests__/webviewMessageHandler.test.ts index 7f3bc49654..dd22b4f013 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.test.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.test.ts @@ -70,6 +70,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { glama: mockModels, unbound: mockModels, litellm: mockModels, + nebius: {}, }, }) }) @@ -155,6 +156,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { glama: mockModels, unbound: mockModels, litellm: {}, + nebius: {}, }, }) }) @@ -190,6 +192,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { glama: mockModels, unbound: {}, litellm: {}, + nebius: {}, }, }) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index c8fd3608e4..c429b39f50 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -295,6 +295,7 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We glama: {}, unbound: {}, litellm: {}, + nebius: {}, } const safeGetModels = async (options: GetModelsOptions): Promise => { @@ -325,6 +326,15 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We }) } + const nebiusApiKey = apiConfiguration.nebiusApiKey || message?.values?.nebiusApiKey + const nebiusBaseUrl = apiConfiguration.nebiusBaseUrl || message?.values?.nebiusBaseUrl + if (nebiusApiKey && nebiusBaseUrl) { + modelFetchPromises.push({ + key: "nebius", + options: { provider: "nebius", apiKey: nebiusApiKey, baseUrl: nebiusBaseUrl }, + }) + } + const results = await Promise.allSettled( modelFetchPromises.map(async ({ key, options }) => { const models = await safeGetModels(options) diff --git a/src/exports/roo-code.d.ts b/src/exports/roo-code.d.ts index 904eba8530..0babf59e80 100644 --- a/src/exports/roo-code.d.ts +++ b/src/exports/roo-code.d.ts @@ -30,6 +30,7 @@ type GlobalSettings = { | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined }[] @@ -227,6 +228,7 @@ type ProviderName = | "groq" | "chutes" | "litellm" + | "nebius" type ProviderSettings = { apiProvider?: @@ -252,6 +254,7 @@ type ProviderSettings = { | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined includeMaxTokens?: boolean | undefined @@ -366,6 +369,9 @@ type ProviderSettings = { litellmBaseUrl?: string | undefined litellmApiKey?: string | undefined litellmModelId?: string | undefined + 
nebiusBaseUrl?: string | undefined + nebiusApiKey?: string | undefined + nebiusModelId?: string | undefined codeIndexOpenAiKey?: string | undefined codeIndexQdrantApiKey?: string | undefined } @@ -396,6 +402,7 @@ type ProviderSettingsEntry = { | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined } @@ -663,6 +670,7 @@ type IpcMessage = | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined includeMaxTokens?: boolean | undefined @@ -779,6 +787,9 @@ type IpcMessage = litellmBaseUrl?: string | undefined litellmApiKey?: string | undefined litellmModelId?: string | undefined + nebiusBaseUrl?: string | undefined + nebiusApiKey?: string | undefined + nebiusModelId?: string | undefined codeIndexOpenAiKey?: string | undefined codeIndexQdrantApiKey?: string | undefined currentApiConfigName?: string | undefined @@ -809,6 +820,7 @@ type IpcMessage = | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined }[] @@ -1175,6 +1187,7 @@ type TaskCommand = | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined includeMaxTokens?: boolean | undefined @@ -1291,6 +1304,9 @@ type TaskCommand = litellmBaseUrl?: string | undefined litellmApiKey?: string | undefined litellmModelId?: string | undefined + nebiusBaseUrl?: string | undefined + nebiusApiKey?: string | undefined + nebiusModelId?: string | undefined codeIndexOpenAiKey?: string | undefined codeIndexQdrantApiKey?: string | undefined currentApiConfigName?: string | undefined @@ -1321,6 +1337,7 @@ type TaskCommand = | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined }[] @@ -1686,6 +1703,7 @@ declare const providerNames: readonly [ "groq", "chutes", "litellm", + "nebius", ] /** * RooCodeEvent diff --git a/src/exports/types.ts b/src/exports/types.ts index 6f4989df62..87648c7048 100644 --- a/src/exports/types.ts +++ b/src/exports/types.ts @@ -30,6 +30,7 @@ type GlobalSettings = { | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined }[] @@ -229,6 +230,7 @@ type ProviderName = | "groq" | "chutes" | "litellm" + | "nebius" export type { ProviderName } @@ -256,6 +258,7 @@ type ProviderSettings = { | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined includeMaxTokens?: boolean | undefined @@ -370,6 +373,9 @@ type ProviderSettings = { litellmBaseUrl?: string | undefined litellmApiKey?: string | undefined litellmModelId?: string | undefined + nebiusBaseUrl?: string | undefined + nebiusApiKey?: string | undefined + nebiusModelId?: string | undefined codeIndexOpenAiKey?: string | undefined codeIndexQdrantApiKey?: string | undefined } @@ -402,6 +408,7 @@ type ProviderSettingsEntry = { | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined } @@ -677,6 +684,7 @@ type IpcMessage = | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined includeMaxTokens?: boolean | undefined @@ -793,6 +801,9 @@ type IpcMessage = litellmBaseUrl?: string | undefined litellmApiKey?: string | undefined litellmModelId?: string | undefined + nebiusBaseUrl?: string | undefined + nebiusApiKey?: string | undefined + nebiusModelId?: string | undefined codeIndexOpenAiKey?: string | undefined codeIndexQdrantApiKey?: string | undefined currentApiConfigName?: string | undefined @@ -823,6 +834,7 @@ type IpcMessage = | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined }[] @@ -1191,6 +1203,7 @@ type TaskCommand = | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined includeMaxTokens?: boolean | undefined @@ -1307,6 +1320,9 @@ type TaskCommand = litellmBaseUrl?: string | undefined litellmApiKey?: string | undefined litellmModelId?: string | undefined + 
nebiusBaseUrl?: string | undefined + nebiusApiKey?: string | undefined + nebiusModelId?: string | undefined codeIndexOpenAiKey?: string | undefined codeIndexQdrantApiKey?: string | undefined currentApiConfigName?: string | undefined @@ -1337,6 +1353,7 @@ type TaskCommand = | "groq" | "chutes" | "litellm" + | "nebius" ) | undefined }[] diff --git a/src/schemas/index.ts b/src/schemas/index.ts index 4fb893ae1f..aac28f5408 100644 --- a/src/schemas/index.ts +++ b/src/schemas/index.ts @@ -105,6 +105,7 @@ export const providerNames = [ "groq", "chutes", "litellm", + "nebius", ] as const export const providerNamesSchema = z.enum(providerNames) @@ -609,6 +610,12 @@ const litellmSchema = baseProviderSettingsSchema.extend({ litellmModelId: z.string().optional(), }) +const nebiusSchema = baseProviderSettingsSchema.extend({ + nebiusBaseUrl: z.string().optional(), + nebiusApiKey: z.string().optional(), + nebiusModelId: z.string().optional(), +}) + const defaultSchema = z.object({ apiProvider: z.undefined(), }) @@ -635,6 +642,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv groqSchema.merge(z.object({ apiProvider: z.literal("groq") })), chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })), litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })), + nebiusSchema.merge(z.object({ apiProvider: z.literal("nebius") })), defaultSchema, ]) @@ -661,7 +669,8 @@ export const providerSettingsSchema = z.object({ ...groqSchema.shape, ...chutesSchema.shape, ...litellmSchema.shape, - ...codebaseIndexProviderSchema.shape + ...nebiusSchema.shape, + ...codebaseIndexProviderSchema.shape }) export type ProviderSettings = z.infer @@ -764,6 +773,10 @@ const providerSettingsRecord: ProviderSettingsRecord = { litellmBaseUrl: undefined, litellmApiKey: undefined, litellmModelId: undefined, + // Nebius LLM + nebiusBaseUrl: undefined, + nebiusApiKey: undefined, + nebiusModelId: undefined, } export const PROVIDER_SETTINGS_KEYS = Object.keys(providerSettingsRecord) as Keys[] @@ -973,6 +986,7 @@ export type SecretState = Pick< | "groqApiKey" | "chutesApiKey" | "litellmApiKey" + | "nebiusApiKey" | "codeIndexOpenAiKey" | "codeIndexQdrantApiKey" > @@ -999,6 +1013,7 @@ const secretStateRecord: SecretStateRecord = { groqApiKey: undefined, chutesApiKey: undefined, litellmApiKey: undefined, + nebiusApiKey: undefined, codeIndexOpenAiKey: undefined, codeIndexQdrantApiKey: undefined, } diff --git a/src/shared/api.ts b/src/shared/api.ts index d66aca6721..eb6c2ee436 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -1261,6 +1261,105 @@ export const litellmDefaultModelInfo: ModelInfo = { cacheWritesPrice: 3.75, cacheReadsPrice: 0.3, } + +export const nebiusDefaultModelInfo: ModelInfo = { + maxTokens: 8192, + contextWindow: 200_000, + supportsImages: true, + supportsComputerUse: true, + supportsPromptCache: true, + inputPrice: 3.0, + outputPrice: 15.0, + cacheWritesPrice: 3.75, + cacheReadsPrice: 0.3, +} +// Nebius AI Studio +// https://docs.nebius.com/studio/inference/models +export const nebiusModels = { + "deepseek-ai/DeepSeek-V3": { + maxTokens: 32_000, + contextWindow: 96_000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.5, + outputPrice: 1.5, + }, + "deepseek-ai/DeepSeek-V3-0324-fast": { + maxTokens: 128_000, + contextWindow: 128_000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 2, + outputPrice: 6, + }, + "deepseek-ai/DeepSeek-R1": { + maxTokens: 32_000, + contextWindow: 96_000, + supportsImages: false, + 
supportsPromptCache: false,
+		inputPrice: 0.8,
+		outputPrice: 2.4,
+	},
+	"deepseek-ai/DeepSeek-R1-fast": {
+		maxTokens: 32_000,
+		contextWindow: 96_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 2,
+		outputPrice: 6,
+	},
+	"meta-llama/Llama-3.3-70B-Instruct-fast": {
+		maxTokens: 32_000,
+		contextWindow: 96_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.25,
+		outputPrice: 0.75,
+	},
+	"Qwen/Qwen2.5-32B-Instruct-fast": {
+		maxTokens: 8_192,
+		contextWindow: 32_768,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.13,
+		outputPrice: 0.4,
+	},
+	"Qwen/Qwen2.5-Coder-32B-Instruct-fast": {
+		maxTokens: 128_000,
+		contextWindow: 128_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.1,
+		outputPrice: 0.3,
+	},
+	"Qwen/Qwen3-4B-fast": {
+		maxTokens: 32_000,
+		contextWindow: 41_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.08,
+		outputPrice: 0.24,
+	},
+	"Qwen/Qwen3-30B-A3B-fast": {
+		maxTokens: 32_000,
+		contextWindow: 41_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.3,
+		outputPrice: 0.9,
+	},
+	"Qwen/Qwen3-235B-A22B": {
+		maxTokens: 32_000,
+		contextWindow: 41_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.2,
+		outputPrice: 0.6,
+	},
+} as const satisfies Record<string, ModelInfo>
+export type NebiusModelId = keyof typeof nebiusModels
+export const nebiusDefaultModelId = "Qwen/Qwen2.5-32B-Instruct-fast" satisfies NebiusModelId
+
 // xAI
 // https://docs.x.ai/docs/api-reference
 export type XAIModelId = keyof typeof xaiModels
@@ -1929,7 +2028,7 @@ export const OPEN_ROUTER_REQUIRED_REASONING_BUDGET_MODELS = new Set([
 	"google/gemini-2.5-flash-preview-05-20:thinking",
 ])
 
-const routerNames = ["openrouter", "requesty", "glama", "unbound", "litellm"] as const
+const routerNames = ["openrouter", "requesty", "glama", "unbound", "litellm", "nebius"] as const
 
 export type RouterName = (typeof routerNames)[number]
@@ -2002,3 +2101,4 @@
 	| { provider: "requesty"; apiKey?: string }
 	| { provider: "unbound"; apiKey?: string }
 	| { provider: "litellm"; apiKey: string; baseUrl: string }
+	| { provider: "nebius"; apiKey: string; baseUrl: string }
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 20d10cf459..e45d7cdf50 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -11,6 +11,7 @@ import {
 	glamaDefaultModelId,
 	unboundDefaultModelId,
 	litellmDefaultModelId,
+	nebiusDefaultModelId,
 } from "@roo/shared/api"
 
 import { vscode } from "@src/utils/vscode"
@@ -30,6 +31,7 @@ import {
 	Groq,
 	LMStudio,
 	LiteLLM,
+	Nebius,
 	Mistral,
 	Ollama,
 	OpenAI,
@@ -168,6 +170,8 @@ const ApiOptions = ({
 			apiConfiguration?.lmStudioBaseUrl,
 			apiConfiguration?.litellmBaseUrl,
 			apiConfiguration?.litellmApiKey,
+			apiConfiguration?.nebiusBaseUrl,
+			apiConfiguration?.nebiusApiKey,
 			customHeaders,
 		],
 	)
@@ -225,6 +229,11 @@ const ApiOptions = ({
 					setApiConfigurationField("litellmModelId", litellmDefaultModelId)
 				}
 				break
+			case "nebius":
+				if (!apiConfiguration.nebiusModelId) {
+					setApiConfigurationField("nebiusModelId", nebiusDefaultModelId)
+				}
+				break
 		}
 
 		setApiConfigurationField("apiProvider", value)
 		},
 		[
 			apiConfiguration.unboundModelId,
 			apiConfiguration.requestyModelId,
 			apiConfiguration.litellmModelId,
+			apiConfiguration.nebiusModelId,
 		],
 	)
@@ -393,6 +403,10 @@ const ApiOptions = ({
 				)}
+				{selectedProvider === "nebius" && (
+					<Nebius apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} routerModels={routerModels} />
+				)}
+
 				{selectedProvider === "human-relay" && (
 					<>
diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx
index 4ac7f530a6..457405dd1a 100644
--- a/webview-ui/src/components/settings/ModelPicker.tsx
+++ b/webview-ui/src/components/settings/ModelPicker.tsx
@@ -25,7 +25,13 @@ import { ModelInfoView } from "./ModelInfoView"
 
 type ModelIdKey = keyof Pick<
 	ProviderSettings,
-	"glamaModelId" | "openRouterModelId" | "unboundModelId" | "requestyModelId" | "openAiModelId" | "litellmModelId"
+	| "glamaModelId"
+	| "openRouterModelId"
+	| "unboundModelId"
+	| "requestyModelId"
+	| "openAiModelId"
+	| "litellmModelId"
+	| "nebiusModelId"
 >
 
 interface ModelPickerProps {
diff --git a/webview-ui/src/components/settings/constants.ts b/webview-ui/src/components/settings/constants.ts
index 295088f9de..11599ab16f 100644
--- a/webview-ui/src/components/settings/constants.ts
+++ b/webview-ui/src/components/settings/constants.ts
@@ -49,6 +49,7 @@ export const PROVIDERS = [
 	{ value: "groq", label: "Groq" },
 	{ value: "chutes", label: "Chutes AI" },
 	{ value: "litellm", label: "LiteLLM" },
+	{ value: "nebius", label: "Nebius" },
 ].sort((a, b) => a.label.localeCompare(b.label))
 
 export const VERTEX_REGIONS = [
diff --git a/webview-ui/src/components/settings/providers/Nebius.tsx b/webview-ui/src/components/settings/providers/Nebius.tsx
new file mode 100644
index 0000000000..c24de27ca4
--- /dev/null
+++ b/webview-ui/src/components/settings/providers/Nebius.tsx
@@ -0,0 +1,65 @@
+import { useCallback } from "react"
+import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
+
+import { ProviderSettings, RouterModels, nebiusDefaultModelId } from "@roo/shared/api"
+
+import { useAppTranslation } from "@src/i18n/TranslationContext"
+
+import { inputEventTransform } from "../transforms"
+import { ModelPicker } from "../ModelPicker"
+
+type NebiusProps = {
+	apiConfiguration: ProviderSettings
+	setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
+	routerModels?: RouterModels
+}
+
+export const Nebius = ({ apiConfiguration, setApiConfigurationField, routerModels }: NebiusProps) => {
+	const { t } = useAppTranslation()
+
+	const handleInputChange = useCallback(
+		<K extends keyof ProviderSettings, E>(
+			field: K,
+			transform: (event: E) => ProviderSettings[K] = inputEventTransform,
+		) =>
+			(event: E | Event) => {
+				setApiConfigurationField(field, transform(event as E))
+			},
+		[setApiConfigurationField],
+	)
+
+	return (
+		<>
+
+
+
+
+
+
+
+
+					{t("settings:providers.apiKeyStorageNotice")}
+
+ + + + ) +} diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts index b244fb515c..386aabae60 100644 --- a/webview-ui/src/components/settings/providers/index.ts +++ b/webview-ui/src/components/settings/providers/index.ts @@ -17,3 +17,4 @@ export { Vertex } from "./Vertex" export { VSCodeLM } from "./VSCodeLM" export { XAI } from "./XAI" export { LiteLLM } from "./LiteLLM" +export { Nebius } from "./Nebius" diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index e28b24824b..9007b9b2ab 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -31,6 +31,7 @@ import { glamaDefaultModelId, unboundDefaultModelId, litellmDefaultModelId, + nebiusDefaultModelId, } from "@roo/shared/api" import { useRouterModels } from "./useRouterModels" @@ -120,6 +121,11 @@ function getSelectedModel({ ? { id, info } : { id: litellmDefaultModelId, info: routerModels.litellm[litellmDefaultModelId] } } + case "nebius": { + const id = apiConfiguration.nebiusModelId ?? nebiusDefaultModelId + const info = routerModels.nebius[id] + return info ? { id, info } : { id: nebiusDefaultModelId, info: routerModels.nebius[nebiusDefaultModelId] } + } case "xai": { const id = apiConfiguration.apiModelId ?? xaiDefaultModelId const info = xaiModels[id as keyof typeof xaiModels] diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index 0765fffeda..993e5869a6 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -29,6 +29,11 @@ export function validateApiConfiguration(apiConfiguration: ProviderSettings): st return i18next.t("settings:validation.apiKey") } break + case "nebius": + if (!apiConfiguration.nebiusApiKey) { + return i18next.t("settings:validation.apiKey") + } + break case "anthropic": if (!apiConfiguration.apiKey) { return i18next.t("settings:validation.apiKey") @@ -143,6 +148,9 @@ export function validateModelId(apiConfiguration: ProviderSettings, routerModels case "litellm": modelId = apiConfiguration.litellmModelId break + case "nebius": + modelId = apiConfiguration.nebiusModelId + break } if (!modelId) { From b125511a2576779700148204b14da16c8b27f0c9 Mon Sep 17 00:00:00 2001 From: Ara Date: Thu, 22 May 2025 08:24:23 +0530 Subject: [PATCH 2/4] Update webview-ui/src/components/settings/providers/Nebius.tsx Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- webview-ui/src/components/settings/providers/Nebius.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webview-ui/src/components/settings/providers/Nebius.tsx b/webview-ui/src/components/settings/providers/Nebius.tsx index c24de27ca4..63d142d270 100644 --- a/webview-ui/src/components/settings/providers/Nebius.tsx +++ b/webview-ui/src/components/settings/providers/Nebius.tsx @@ -39,7 +39,7 @@ export const Nebius = ({ apiConfiguration, setApiConfigurationField, routerModel Date: Thu, 29 May 2025 10:57:30 -0500 Subject: [PATCH 3/4] refactor: update placeholder --- webview-ui/src/components/settings/providers/Nebius.tsx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/webview-ui/src/components/settings/providers/Nebius.tsx b/webview-ui/src/components/settings/providers/Nebius.tsx index 63d142d270..cc70cf6d47 100644 --- a/webview-ui/src/components/settings/providers/Nebius.tsx +++ 
b/webview-ui/src/components/settings/providers/Nebius.tsx @@ -33,7 +33,9 @@ export const Nebius = ({ apiConfiguration, setApiConfigurationField, routerModel From 0d38b54c2201a6c65fd5fa180f01a78401cb4165 Mon Sep 17 00:00:00 2001 From: Daniel Riccio Date: Thu, 29 May 2025 11:34:11 -0500 Subject: [PATCH 4/4] fix: address PR review comments for Nebius provider - Fix typo in Nebius component (was using litellmApiKey) - Fix inconsistent default base URL between UI and API handler - Add Zod validation for Nebius API responses - Add comprehensive unit tests for Nebius provider and fetcher - Improve error handling with schema validation --- src/api/providers/__tests__/nebius.test.ts | 221 ++++++++++++++++++ .../fetchers/__tests__/nebius.test.ts | 213 +++++++++++++++++ src/api/providers/fetchers/nebius.ts | 79 +++++-- .../components/settings/providers/Nebius.tsx | 2 - 4 files changed, 488 insertions(+), 27 deletions(-) create mode 100644 src/api/providers/__tests__/nebius.test.ts create mode 100644 src/api/providers/fetchers/__tests__/nebius.test.ts diff --git a/src/api/providers/__tests__/nebius.test.ts b/src/api/providers/__tests__/nebius.test.ts new file mode 100644 index 0000000000..433f37ee25 --- /dev/null +++ b/src/api/providers/__tests__/nebius.test.ts @@ -0,0 +1,221 @@ +// npx jest src/api/providers/__tests__/nebius.test.ts + +import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" + +import { NebiusHandler } from "../nebius" +import { ApiHandlerOptions } from "../../../shared/api" + +// Mock dependencies +jest.mock("openai") +jest.mock("delay", () => jest.fn(() => Promise.resolve())) +jest.mock("../fetchers/modelCache", () => ({ + getModels: jest.fn().mockImplementation(() => { + return Promise.resolve({ + "Qwen/Qwen2.5-32B-Instruct-fast": { + maxTokens: 8192, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.13, + outputPrice: 0.4, + description: "Qwen 2.5 32B Instruct Fast", + }, + "deepseek-ai/DeepSeek-R1": { + maxTokens: 32000, + contextWindow: 96000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.8, + outputPrice: 2.4, + description: "DeepSeek R1", + }, + }) + }), +})) + +describe("NebiusHandler", () => { + const mockOptions: ApiHandlerOptions = { + nebiusApiKey: "test-key", + nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast", + nebiusBaseUrl: "https://api.studio.nebius.ai/v1", + } + + beforeEach(() => jest.clearAllMocks()) + + it("initializes with correct options", () => { + const handler = new NebiusHandler(mockOptions) + expect(handler).toBeInstanceOf(NebiusHandler) + + expect(OpenAI).toHaveBeenCalledWith({ + baseURL: "https://api.studio.nebius.ai/v1", + apiKey: mockOptions.nebiusApiKey, + }) + }) + + it("uses default base URL when not provided", () => { + const handler = new NebiusHandler({ + nebiusApiKey: "test-key", + nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast", + }) + expect(handler).toBeInstanceOf(NebiusHandler) + + expect(OpenAI).toHaveBeenCalledWith({ + baseURL: "https://api.studio.nebius.ai/v1", + apiKey: "test-key", + }) + }) + + describe("fetchModel", () => { + it("returns correct model info when options are provided", async () => { + const handler = new NebiusHandler(mockOptions) + const result = await handler.fetchModel() + + expect(result).toMatchObject({ + id: mockOptions.nebiusModelId, + info: { + maxTokens: 8192, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.13, + outputPrice: 0.4, + description: "Qwen 2.5 32B Instruct Fast", + }, 
+ }) + }) + + it("returns default model info when options are not provided", async () => { + const handler = new NebiusHandler({}) + const result = await handler.fetchModel() + expect(result.id).toBe("Qwen/Qwen2.5-32B-Instruct-fast") + }) + }) + + describe("createMessage", () => { + it("generates correct stream chunks", async () => { + const handler = new NebiusHandler(mockOptions) + + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + choices: [{ delta: { content: "test response" } }], + } + yield { + choices: [{ delta: {} }], + usage: { prompt_tokens: 10, completion_tokens: 20 }, + } + }, + } + + // Mock OpenAI chat.completions.create + const mockCreate = jest.fn().mockResolvedValue(mockStream) + + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any + + const systemPrompt = "test system prompt" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }] + + const generator = handler.createMessage(systemPrompt, messages) + const chunks = [] + + for await (const chunk of generator) { + chunks.push(chunk) + } + + // Verify stream chunks + expect(chunks).toHaveLength(2) // One text chunk and one usage chunk + expect(chunks[0]).toEqual({ type: "text", text: "test response" }) + expect(chunks[1]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 }) + + // Verify OpenAI client was called with correct parameters + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: "Qwen/Qwen2.5-32B-Instruct-fast", + messages: [ + { role: "system", content: "test system prompt" }, + { role: "user", content: "test message" }, + ], + temperature: 0, + stream: true, + stream_options: { include_usage: true }, + }), + ) + }) + + it("handles R1 format for DeepSeek-R1 models", async () => { + const handler = new NebiusHandler({ + ...mockOptions, + nebiusModelId: "deepseek-ai/DeepSeek-R1", + }) + + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + choices: [{ delta: { content: "test response" } }], + } + }, + } + + const mockCreate = jest.fn().mockResolvedValue(mockStream) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any + + const systemPrompt = "test system prompt" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }] + + await handler.createMessage(systemPrompt, messages).next() + + // Verify R1 format is used - the first message should combine system and user content + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: "deepseek-ai/DeepSeek-R1", + messages: expect.arrayContaining([ + expect.objectContaining({ + role: "user", + content: expect.stringContaining("test system prompt"), + }), + ]), + }), + ) + }) + }) + + describe("completePrompt", () => { + it("returns correct response", async () => { + const handler = new NebiusHandler(mockOptions) + const mockResponse = { choices: [{ message: { content: "test completion" } }] } + + const mockCreate = jest.fn().mockResolvedValue(mockResponse) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("test completion") + + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.nebiusModelId, + max_tokens: 8192, + temperature: 0, + messages: [{ role: "user", content: "test prompt" }], + }) + }) + + it("handles errors", async () => { + const 
handler = new NebiusHandler(mockOptions) + const mockError = new Error("API Error") + + const mockCreate = jest.fn().mockRejectedValue(mockError) + ;(OpenAI as jest.MockedClass).prototype.chat = { + completions: { create: mockCreate }, + } as any + + await expect(handler.completePrompt("test prompt")).rejects.toThrow("nebius completion error: API Error") + }) + }) +}) diff --git a/src/api/providers/fetchers/__tests__/nebius.test.ts b/src/api/providers/fetchers/__tests__/nebius.test.ts new file mode 100644 index 0000000000..3384b9c3fb --- /dev/null +++ b/src/api/providers/fetchers/__tests__/nebius.test.ts @@ -0,0 +1,213 @@ +// npx jest src/api/providers/fetchers/__tests__/nebius.test.ts + +import axios from "axios" +import { getNebiusModels } from "../nebius" +import { COMPUTER_USE_MODELS } from "../../../../shared/api" + +jest.mock("axios") + +describe("Nebius API", () => { + describe("getNebiusModels", () => { + const mockApiKey = "test-api-key" + const mockBaseUrl = "https://api.studio.nebius.ai/v1" + + beforeEach(() => { + jest.clearAllMocks() + }) + + it("fetches models and validates schema", async () => { + const mockResponse = { + data: { + data: [ + { + model_name: "Qwen/Qwen2.5-32B-Instruct-fast", + model_info: { + max_tokens: 8192, + max_input_tokens: 32768, + supports_vision: false, + supports_prompt_caching: false, + input_cost_per_token: 0.00000013, + output_cost_per_token: 0.0000004, + }, + nebius_params: { + model: "qwen/qwen2.5-32b-instruct", + }, + }, + { + model_name: "deepseek-ai/DeepSeek-R1", + model_info: { + max_tokens: 32000, + max_input_tokens: 96000, + supports_vision: false, + supports_prompt_caching: false, + input_cost_per_token: 0.0000008, + output_cost_per_token: 0.0000024, + }, + nebius_params: { + model: "deepseek/deepseek-r1", + }, + }, + ], + }, + } + + ;(axios.get as jest.Mock).mockResolvedValue(mockResponse) + + const result = await getNebiusModels(mockApiKey, mockBaseUrl) + + expect(axios.get).toHaveBeenCalledWith(`${mockBaseUrl}/v1/model/info`, { + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${mockApiKey}`, + }, + }) + + expect(result["Qwen/Qwen2.5-32B-Instruct-fast"]).toMatchObject({ + maxTokens: 8192, + contextWindow: 32768, + supportsImages: false, + supportsComputerUse: false, + supportsPromptCache: false, + description: "Qwen/Qwen2.5-32B-Instruct-fast via Nebius proxy", + }) + expect(result["Qwen/Qwen2.5-32B-Instruct-fast"].inputPrice).toBeCloseTo(0.13, 5) + expect(result["Qwen/Qwen2.5-32B-Instruct-fast"].outputPrice).toBeCloseTo(0.4, 5) + + expect(result["deepseek-ai/DeepSeek-R1"]).toMatchObject({ + maxTokens: 32000, + contextWindow: 96000, + supportsImages: false, + supportsComputerUse: false, + supportsPromptCache: false, + description: "deepseek-ai/DeepSeek-R1 via Nebius proxy", + }) + expect(result["deepseek-ai/DeepSeek-R1"].inputPrice).toBeCloseTo(0.8, 5) + expect(result["deepseek-ai/DeepSeek-R1"].outputPrice).toBeCloseTo(2.4, 5) + }) + + it("validates computer use models", async () => { + const mockResponse = { + data: { + data: [ + { + model_name: "anthropic/claude-3.5-sonnet", + model_info: { + max_tokens: 8192, + max_input_tokens: 200000, + supports_vision: true, + supports_prompt_caching: true, + }, + nebius_params: { + model: "anthropic/claude-3.5-sonnet", + }, + }, + { + model_name: "anthropic/claude-3.7-sonnet", + model_info: { + max_tokens: 8192, + max_input_tokens: 200000, + supports_vision: true, + supports_prompt_caching: true, + }, + nebius_params: { + model: "anthropic/claude-3.7-sonnet", + }, + }, 
+ ], + }, + } + + ;(axios.get as jest.Mock).mockResolvedValue(mockResponse) + + const result = await getNebiusModels(mockApiKey, mockBaseUrl) + + // Verify that computer use models are correctly identified + const computerUseModels = Object.entries(result) + .filter(([_, model]) => model.supportsComputerUse) + .map(([id, _]) => id) + + expect(computerUseModels).toContain("anthropic/claude-3.5-sonnet") + expect(computerUseModels).toContain("anthropic/claude-3.7-sonnet") + + // Verify these models are in the COMPUTER_USE_MODELS set + computerUseModels.forEach((modelId) => { + expect(COMPUTER_USE_MODELS.has(modelId)).toBe(true) + }) + }) + + it("handles missing model info gracefully", async () => { + const mockResponse = { + data: { + data: [ + { + model_name: "test-model", + // Missing model_info + }, + { + // Missing model_name + model_info: { + max_tokens: 8192, + }, + }, + { + model_name: "another-model", + model_info: { + max_tokens: 8192, + }, + // Missing nebius_params + }, + ], + }, + } + + ;(axios.get as jest.Mock).mockResolvedValue(mockResponse) + + const result = await getNebiusModels(mockApiKey, mockBaseUrl) + + expect(result).toEqual({}) + }) + + it("handles invalid response format", async () => { + const mockResponse = { + data: { + // Invalid structure - missing 'data' array + models: [], + }, + } + + ;(axios.get as jest.Mock).mockResolvedValue(mockResponse) + + const result = await getNebiusModels(mockApiKey, mockBaseUrl) + + expect(result).toEqual({}) + }) + + it("handles API errors gracefully", async () => { + ;(axios.get as jest.Mock).mockRejectedValue(new Error("Network error")) + + const result = await getNebiusModels(mockApiKey, mockBaseUrl) + + expect(result).toEqual({}) + }) + + it("handles missing API key", async () => { + const mockResponse = { + data: { + data: [], + }, + } + + ;(axios.get as jest.Mock).mockResolvedValue(mockResponse) + + const result = await getNebiusModels("", mockBaseUrl) + + expect(axios.get).toHaveBeenCalledWith(`${mockBaseUrl}/v1/model/info`, { + headers: { + "Content-Type": "application/json", + }, + }) + + expect(result).toEqual({}) + }) + }) +}) diff --git a/src/api/providers/fetchers/nebius.ts b/src/api/providers/fetchers/nebius.ts index f68cd5d75e..c28f7ec796 100644 --- a/src/api/providers/fetchers/nebius.ts +++ b/src/api/providers/fetchers/nebius.ts @@ -1,5 +1,34 @@ import axios from "axios" import { OPEN_ROUTER_COMPUTER_USE_MODELS, ModelRecord } from "../../../shared/api" +import { z } from 'zod' + +/** + * Nebius API response schemas + */ +const nebiusModelInfoSchema = z.object({ + max_tokens: z.number().optional(), + max_input_tokens: z.number().optional(), + supports_vision: z.boolean().optional(), + supports_prompt_caching: z.boolean().optional(), + input_cost_per_token: z.number().optional(), + output_cost_per_token: z.number().optional(), +}) + +const nebiusModelSchema = z.object({ + model_name: z.string(), + model_info: nebiusModelInfoSchema, + nebius_params: z + .object({ + model: z.string().optional(), + }) + .optional(), +}) + +const nebiusModelsResponseSchema = z.object({ + data: z.array(nebiusModelSchema), +}) + +type NebiusModelsResponse = z.infer /** * Fetches available models from a Nebius server @@ -18,35 +47,35 @@ export async function getNebiusModels(apiKey: string, baseUrl: string): Promise< headers["Authorization"] = `Bearer ${apiKey}` } - const response = await axios.get(`${baseUrl}/v1/model/info`, { headers }) - const models: ModelRecord = {} + const response = await axios.get(`${baseUrl}/v1/model/info`, { headers 
}) + const result = nebiusModelsResponseSchema.safeParse(response.data) + + if (!result.success) { + console.error("Nebius models response is invalid", result.error.format()) + return {} + } + const models: ModelRecord = {} const computerModels = Array.from(OPEN_ROUTER_COMPUTER_USE_MODELS) // Process the model info from the response - if (response.data && response.data.data && Array.isArray(response.data.data)) { - for (const model of response.data.data) { - const modelName = model.model_name - const modelInfo = model.model_info - const nebiusModelName = model?.nebius_params?.model as string | undefined - - if (!modelName || !modelInfo || !nebiusModelName) continue - - models[modelName] = { - maxTokens: modelInfo.max_tokens || 8192, - contextWindow: modelInfo.max_input_tokens || 200000, - supportsImages: Boolean(modelInfo.supports_vision), - // nebius_params.model may have a prefix like openrouter/ - supportsComputerUse: computerModels.some((computer_model) => - nebiusModelName.endsWith(computer_model), - ), - supportsPromptCache: Boolean(modelInfo.supports_prompt_caching), - inputPrice: modelInfo.input_cost_per_token ? modelInfo.input_cost_per_token * 1000000 : undefined, - outputPrice: modelInfo.output_cost_per_token - ? modelInfo.output_cost_per_token * 1000000 - : undefined, - description: `${modelName} via Nebius proxy`, - } + for (const model of result.data.data) { + const modelName = model.model_name + const modelInfo = model.model_info + const nebiusModelName = model.nebius_params?.model + + if (!modelName || !modelInfo || !nebiusModelName) continue + + models[modelName] = { + maxTokens: modelInfo.max_tokens || 8192, + contextWindow: modelInfo.max_input_tokens || 200000, + supportsImages: Boolean(modelInfo.supports_vision), + // nebius_params.model may have a prefix like openrouter/ + supportsComputerUse: computerModels.some((computer_model) => nebiusModelName.endsWith(computer_model)), + supportsPromptCache: Boolean(modelInfo.supports_prompt_caching), + inputPrice: modelInfo.input_cost_per_token ? modelInfo.input_cost_per_token * 1000000 : undefined, + outputPrice: modelInfo.output_cost_per_token ? modelInfo.output_cost_per_token * 1000000 : undefined, + description: `${modelName} via Nebius proxy`, } } diff --git a/webview-ui/src/components/settings/providers/Nebius.tsx b/webview-ui/src/components/settings/providers/Nebius.tsx index cc70cf6d47..69fe93e705 100644 --- a/webview-ui/src/components/settings/providers/Nebius.tsx +++ b/webview-ui/src/components/settings/providers/Nebius.tsx @@ -31,8 +31,6 @@ export const Nebius = ({ apiConfiguration, setApiConfigurationField, routerModel return ( <>
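
---
Reviewer note (addendum, not part of the applied diff): for anyone smoke-testing the provider, the sketch below shows how the pieces in this patch are expected to fit together. `buildApiHandler`, the `nebius*` settings keys, and the streamed chunk shapes come from the patch itself; the `NEBIUS_API_KEY` environment variable name and the surrounding script scaffolding are illustrative assumptions only.

	// Hypothetical smoke test for the Nebius provider (TypeScript).
	import { buildApiHandler } from "./src/api"

	async function main() {
		// "nebius" routes to NebiusHandler via the switch in buildApiHandler.
		const handler = buildApiHandler({
			apiProvider: "nebius",
			nebiusApiKey: process.env.NEBIUS_API_KEY, // assumed env var name
			nebiusBaseUrl: "https://api.studio.nebius.ai/v1",
			nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
		})

		// createMessage yields "text" chunks and a final "usage" chunk,
		// matching the expectations encoded in the handler tests above.
		for await (const chunk of handler.createMessage("You are a helpful assistant.", [
			{ role: "user", content: "Say hello." },
		])) {
			if (chunk.type === "text") process.stdout.write(chunk.text)
			if (chunk.type === "usage") console.log(`\n[tokens in/out: ${chunk.inputTokens}/${chunk.outputTokens}]`)
		}
	}

	main().catch(console.error)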