4 changes: 4 additions & 0 deletions evals/apps/web/src/app/runs/new/new-run.tsx
@@ -176,6 +176,7 @@ export function NewRun() {
ollamaModelId,
lmStudioModelId,
openAiModelId,
nebiusModelId,
} = providerSettings

switch (apiProvider) {
@@ -210,6 +211,9 @@
case "lmstudio":
setValue("model", lmStudioModelId ?? "")
break
case "nebius":
setValue("model", nebiusModelId ?? "")
break
default:
throw new Error(`Unsupported API provider: ${apiProvider}`)
}
15 changes: 15 additions & 0 deletions evals/packages/types/src/roo-code.ts
@@ -477,6 +477,12 @@ const litellmSchema = z.object({
litellmModelId: z.string().optional(),
})

const nebiusSchema = z.object({
nebiusBaseUrl: z.string().optional(),
nebiusApiKey: z.string().optional(),
nebiusModelId: z.string().optional(),
})

const defaultSchema = z.object({
apiProvider: z.undefined(),
})
@@ -588,6 +594,11 @@ export const providerSettingsSchemaDiscriminated = z
apiProvider: z.literal("litellm"),
}),
),
nebiusSchema.merge(
z.object({
apiProvider: z.literal("nebius"),
}),
),
defaultSchema,
])
.and(genericProviderSettingsSchema)
@@ -615,6 +626,7 @@ export const providerSettingsSchema = z.object({
...groqSchema.shape,
...chutesSchema.shape,
...litellmSchema.shape,
...nebiusSchema.shape,
...genericProviderSettingsSchema.shape,
})

@@ -715,6 +727,9 @@ const providerSettingsRecord: ProviderSettingsRecord = {
litellmBaseUrl: undefined,
litellmApiKey: undefined,
litellmModelId: undefined,
nebiusBaseUrl: undefined,
nebiusApiKey: undefined,
nebiusModelId: undefined,
}

export const PROVIDER_SETTINGS_KEYS = Object.keys(providerSettingsRecord) as Keys<ProviderSettings>[]
13 changes: 13 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -30,6 +30,7 @@ export const providerNames = [
"groq",
"chutes",
"litellm",
"nebius",
] as const

export const providerNamesSchema = z.enum(providerNames)
@@ -200,6 +201,12 @@ const litellmSchema = baseProviderSettingsSchema.extend({
litellmModelId: z.string().optional(),
})

const nebiusSchema = baseProviderSettingsSchema.extend({
nebiusBaseUrl: z.string().optional(),
nebiusApiKey: z.string().optional(),
nebiusModelId: z.string().optional(),
})

const defaultSchema = z.object({
apiProvider: z.undefined(),
})
@@ -226,6 +233,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
groqSchema.merge(z.object({ apiProvider: z.literal("groq") })),
chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
nebiusSchema.merge(z.object({ apiProvider: z.literal("nebius") })),
defaultSchema,
])

@@ -252,6 +260,7 @@ export const providerSettingsSchema = z.object({
...groqSchema.shape,
...chutesSchema.shape,
...litellmSchema.shape,
...nebiusSchema.shape,
...codebaseIndexProviderSchema.shape,
})

@@ -353,4 +362,8 @@ export const PROVIDER_SETTINGS_KEYS = keysOf<ProviderSettings>()([
"litellmBaseUrl",
"litellmApiKey",
"litellmModelId",
// Nebius
"nebiusBaseUrl",
"nebiusApiKey",
"nebiusModelId",
])
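
Taken together, the two schema additions let a nebius config flow through both the discriminated union (for narrowing on apiProvider) and the flat providerSettingsSchema. A minimal validation sketch, assuming the schema is imported from this module (the import path is illustrative):

// Sketch: validating a Nebius config against the discriminated schema.
// Import path is an assumption; use wherever the package exports it.
import { providerSettingsSchemaDiscriminated } from "./provider-settings"

const parsed = providerSettingsSchemaDiscriminated.safeParse({
	apiProvider: "nebius",
	nebiusApiKey: "test-key",
	nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
	// nebiusBaseUrl is optional; the handler supplies a default when omitted.
})

if (parsed.success && parsed.data.apiProvider === "nebius") {
	// The union narrows here, so the nebius* fields are typed on parsed.data.
	console.log(parsed.data.nebiusModelId)
}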
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -27,6 +27,7 @@ import {
GroqHandler,
ChutesHandler,
LiteLLMHandler,
NebiusHandler,
} from "./providers"

export interface SingleCompletionHandler {
@@ -106,6 +107,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new ChutesHandler(options)
case "litellm":
return new LiteLLMHandler(options)
case "nebius":
return new NebiusHandler(options)
default:
return new AnthropicHandler(options)
}
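With the new case wired into buildApiHandler, a nebius configuration resolves to the new handler. A quick usage sketch (imports are illustrative):

// Sketch: routing provider settings to the Nebius handler.
import { buildApiHandler } from "./api"

const handler = buildApiHandler({
	apiProvider: "nebius",
	nebiusApiKey: process.env.NEBIUS_API_KEY,
	nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
})
// Unrecognized providers still fall through to AnthropicHandler, per the default case.
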
227 changes: 227 additions & 0 deletions src/api/providers/__tests__/nebius.test.ts
@@ -0,0 +1,227 @@
// npx jest src/api/providers/__tests__/nebius.test.ts

import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { NebiusHandler } from "../nebius"
import { ApiHandlerOptions } from "../../../shared/api"

// Mock dependencies
jest.mock("openai")
jest.mock("delay", () => jest.fn(() => Promise.resolve()))
jest.mock("../fetchers/modelCache", () => ({
getModels: jest.fn().mockImplementation(() => {
return Promise.resolve({
"Qwen/Qwen2.5-32B-Instruct-fast": {
maxTokens: 8192,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.13,
outputPrice: 0.4,
description: "Qwen 2.5 32B Instruct Fast",
},
"deepseek-ai/DeepSeek-R1": {
maxTokens: 32000,
contextWindow: 96000,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.8,
outputPrice: 2.4,
description: "DeepSeek R1",
},
})
}),
}))

describe("NebiusHandler", () => {
const mockOptions: ApiHandlerOptions = {
nebiusApiKey: "test-key",
nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
nebiusBaseUrl: "https://api.studio.nebius.ai/v1",
}

beforeEach(() => jest.clearAllMocks())

it("initializes with correct options", () => {
const handler = new NebiusHandler(mockOptions)
expect(handler).toBeInstanceOf(NebiusHandler)

expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://api.studio.nebius.ai/v1",
apiKey: mockOptions.nebiusApiKey,
})
})

it("uses default base URL when not provided", () => {
const handler = new NebiusHandler({
nebiusApiKey: "test-key",
nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
})
expect(handler).toBeInstanceOf(NebiusHandler)

expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://api.studio.nebius.ai/v1",
apiKey: "test-key",
})
})

describe("fetchModel", () => {
it("returns correct model info when options are provided", async () => {
const handler = new NebiusHandler(mockOptions)
const result = await handler.fetchModel()

expect(result).toMatchObject({
id: mockOptions.nebiusModelId,
info: {
maxTokens: 8192,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.13,
outputPrice: 0.4,
description: "Qwen 2.5 32B Instruct Fast",
},
})
})

it("returns default model info when options are not provided", async () => {
const handler = new NebiusHandler({})
const result = await handler.fetchModel()
expect(result.id).toBe("Qwen/Qwen2.5-32B-Instruct-fast")
})
})

describe("createMessage", () => {
it("generates correct stream chunks", async () => {
const handler = new NebiusHandler(mockOptions)

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "test response" } }],
}
yield {
choices: [{ delta: {} }],
usage: { prompt_tokens: 10, completion_tokens: 20 },
}
},
}

// Mock OpenAI chat.completions.create
const mockCreate = jest.fn().mockResolvedValue(mockStream)

;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]

const generator = handler.createMessage(systemPrompt, messages)
const chunks = []

for await (const chunk of generator) {
chunks.push(chunk)
}

// Verify stream chunks
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
expect(chunks[0]).toEqual({ type: "text", text: "test response" })
expect(chunks[1]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })

// Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "Qwen/Qwen2.5-32B-Instruct-fast",
messages: [
{ role: "system", content: "test system prompt" },
{ role: "user", content: "test message" },
],
temperature: 0,
stream: true,
stream_options: { include_usage: true },
}),
)
})

it("handles R1 format for DeepSeek-R1 models", async () => {
const handler = new NebiusHandler({
...mockOptions,
nebiusModelId: "deepseek-ai/DeepSeek-R1",
})

// Ensure the model is loaded
await handler.fetchModel()

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "test response" } }],
}
},
}

const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]

await handler.createMessage(systemPrompt, messages).next()

// Verify R1 format is used - the system prompt and first user message should be merged
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "deepseek-ai/DeepSeek-R1",
messages: [
{
role: "user",
content: "test system prompt\ntest message",
},
],
temperature: 0,
stream: true,
stream_options: { include_usage: true },
}),
)
})
})

describe("completePrompt", () => {
it("returns correct response", async () => {
const handler = new NebiusHandler(mockOptions)
const mockResponse = { choices: [{ message: { content: "test completion" } }] }

const mockCreate = jest.fn().mockResolvedValue(mockResponse)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const result = await handler.completePrompt("test prompt")

expect(result).toBe("test completion")

expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.nebiusModelId,
max_tokens: 8192,
temperature: 0,
messages: [{ role: "user", content: "test prompt" }],
})
})

it("handles errors", async () => {
const handler = new NebiusHandler(mockOptions)
const mockError = new Error("API Error")

const mockCreate = jest.fn().mockRejectedValue(mockError)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

await expect(handler.completePrompt("test prompt")).rejects.toThrow("nebius completion error: API Error")
})
})
})
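
The handler implementation itself is outside this hunk, but the tests pin down its observable contract: the client defaults to https://api.studio.nebius.ai/v1, completePrompt wraps failures with a "nebius completion error:" prefix, and createMessage merges the system prompt into the first user message for DeepSeek-R1 models. A rough sketch of the non-streaming path under those constraints (inferred from the tests, not the PR's actual nebius.ts):

// Sketch only: shape inferred from the tests above; names are assumptions.
import OpenAI from "openai"

const NEBIUS_DEFAULT_BASE_URL = "https://api.studio.nebius.ai/v1"

class NebiusHandlerSketch {
	private client: OpenAI

	constructor(private options: { nebiusApiKey?: string; nebiusModelId?: string; nebiusBaseUrl?: string }) {
		// The "uses default base URL" test asserts this fallback.
		this.client = new OpenAI({
			baseURL: options.nebiusBaseUrl ?? NEBIUS_DEFAULT_BASE_URL,
			apiKey: options.nebiusApiKey,
		})
	}

	async completePrompt(prompt: string): Promise<string> {
		try {
			const response = await this.client.chat.completions.create({
				model: this.options.nebiusModelId ?? "Qwen/Qwen2.5-32B-Instruct-fast",
				max_tokens: 8192,
				temperature: 0,
				messages: [{ role: "user", content: prompt }],
			})
			return response.choices[0]?.message?.content ?? ""
		} catch (error) {
			// The error test expects exactly this prefix.
			throw new Error(`nebius completion error: ${error instanceof Error ? error.message : String(error)}`)
		}
	}
}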