4 changes: 4 additions & 0 deletions evals/apps/web/src/app/runs/new/new-run.tsx
@@ -176,6 +176,7 @@ export function NewRun() {
ollamaModelId,
lmStudioModelId,
openAiModelId,
nebiusModelId,
} = providerSettings

switch (apiProvider) {
@@ -210,6 +211,9 @@
case "lmstudio":
setValue("model", lmStudioModelId ?? "")
break
case "nebius":
setValue("model", nebiusModelId ?? "")
break
default:
throw new Error(`Unsupported API provider: ${apiProvider}`)
}
15 changes: 15 additions & 0 deletions evals/packages/types/src/roo-code.ts
@@ -478,6 +478,12 @@ const litellmSchema = z.object({
litellmModelId: z.string().optional(),
})

const nebiusSchema = z.object({
nebiusBaseUrl: z.string().optional(),
nebiusApiKey: z.string().optional(),
nebiusModelId: z.string().optional(),
})

const defaultSchema = z.object({
apiProvider: z.undefined(),
})
@@ -589,6 +595,11 @@ export const providerSettingsSchemaDiscriminated = z
apiProvider: z.literal("litellm"),
}),
),
nebiusSchema.merge(
z.object({
apiProvider: z.literal("nebius"),
}),
),
defaultSchema,
])
.and(genericProviderSettingsSchema)
@@ -616,6 +627,7 @@ export const providerSettingsSchema = z.object({
...groqSchema.shape,
...chutesSchema.shape,
...litellmSchema.shape,
...nebiusSchema.shape,
...genericProviderSettingsSchema.shape,
})

@@ -716,6 +728,9 @@ const providerSettingsRecord: ProviderSettingsRecord = {
litellmBaseUrl: undefined,
litellmApiKey: undefined,
litellmModelId: undefined,
nebiusBaseUrl: undefined,
nebiusApiKey: undefined,
nebiusModelId: undefined,
}

export const PROVIDER_SETTINGS_KEYS = Object.keys(providerSettingsRecord) as Keys<ProviderSettings>[]
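For context, a minimal sketch of the new union member in use (the import path is assumed; the field names come from the diff above):

import { providerSettingsSchemaDiscriminated } from "@evals/types" // import path assumed

// Accepted: "nebius" is now a member of the discriminated union,
// and every nebius* field is an optional string.
const settings = providerSettingsSchemaDiscriminated.parse({
	apiProvider: "nebius",
	nebiusApiKey: "test-key",
	nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
})

// Rejected: an unrecognized discriminator still fails validation.
const result = providerSettingsSchemaDiscriminated.safeParse({ apiProvider: "unknown-provider" })
console.log(result.success) // false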
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -24,6 +24,7 @@ import { XAIHandler } from "./providers/xai"
import { GroqHandler } from "./providers/groq"
import { ChutesHandler } from "./providers/chutes"
import { LiteLLMHandler } from "./providers/litellm"
import { NebiusHandler } from "./providers/nebius"

export interface SingleCompletionHandler {
completePrompt(prompt: string): Promise<string>
@@ -104,6 +105,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new ChutesHandler(options)
case "litellm":
return new LiteLLMHandler(options)
case "nebius":
return new NebiusHandler(options)
default:
return new AnthropicHandler(options)
}
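A sketch of the factory in use after this change, assuming the nebius* settings flow through to the handler options unchanged:

import { buildApiHandler } from "./index" // path relative to src/api, assumed

const handler = buildApiHandler({
	apiProvider: "nebius",
	nebiusApiKey: process.env.NEBIUS_API_KEY,
	nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
})
// handler is a NebiusHandler; any provider not matched by the switch
// still falls through to the AnthropicHandler default.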
221 changes: 221 additions & 0 deletions src/api/providers/__tests__/nebius.test.ts
@@ -0,0 +1,221 @@
// npx jest src/api/providers/__tests__/nebius.test.ts

import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { NebiusHandler } from "../nebius"
import { ApiHandlerOptions } from "../../../shared/api"

// Mock dependencies
jest.mock("openai")
jest.mock("delay", () => jest.fn(() => Promise.resolve()))
jest.mock("../fetchers/modelCache", () => ({
getModels: jest.fn().mockImplementation(() => {
return Promise.resolve({
"Qwen/Qwen2.5-32B-Instruct-fast": {
maxTokens: 8192,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.13,
outputPrice: 0.4,
description: "Qwen 2.5 32B Instruct Fast",
},
"deepseek-ai/DeepSeek-R1": {
maxTokens: 32000,
contextWindow: 96000,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.8,
outputPrice: 2.4,
description: "DeepSeek R1",
},
})
}),
}))

describe("NebiusHandler", () => {
const mockOptions: ApiHandlerOptions = {
nebiusApiKey: "test-key",
nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
nebiusBaseUrl: "https://api.studio.nebius.ai/v1",
}

beforeEach(() => jest.clearAllMocks())

it("initializes with correct options", () => {
const handler = new NebiusHandler(mockOptions)
expect(handler).toBeInstanceOf(NebiusHandler)

expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://api.studio.nebius.ai/v1",
apiKey: mockOptions.nebiusApiKey,
})
})

it("uses default base URL when not provided", () => {
const handler = new NebiusHandler({
nebiusApiKey: "test-key",
nebiusModelId: "Qwen/Qwen2.5-32B-Instruct-fast",
})
expect(handler).toBeInstanceOf(NebiusHandler)

expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://api.studio.nebius.ai/v1",
apiKey: "test-key",
})
})

describe("fetchModel", () => {
it("returns correct model info when options are provided", async () => {
const handler = new NebiusHandler(mockOptions)
const result = await handler.fetchModel()

expect(result).toMatchObject({
id: mockOptions.nebiusModelId,
info: {
maxTokens: 8192,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.13,
outputPrice: 0.4,
description: "Qwen 2.5 32B Instruct Fast",
},
})
})

it("returns default model info when options are not provided", async () => {
const handler = new NebiusHandler({})
const result = await handler.fetchModel()
expect(result.id).toBe("Qwen/Qwen2.5-32B-Instruct-fast")
})
})

describe("createMessage", () => {
it("generates correct stream chunks", async () => {
const handler = new NebiusHandler(mockOptions)

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "test response" } }],
}
yield {
choices: [{ delta: {} }],
usage: { prompt_tokens: 10, completion_tokens: 20 },
}
},
}

// Mock OpenAI chat.completions.create
const mockCreate = jest.fn().mockResolvedValue(mockStream)

;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]

const generator = handler.createMessage(systemPrompt, messages)
const chunks = []

for await (const chunk of generator) {
chunks.push(chunk)
}

// Verify stream chunks
expect(chunks).toHaveLength(2) // One text chunk and one usage chunk
expect(chunks[0]).toEqual({ type: "text", text: "test response" })
expect(chunks[1]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })

// Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "Qwen/Qwen2.5-32B-Instruct-fast",
messages: [
{ role: "system", content: "test system prompt" },
{ role: "user", content: "test message" },
],
temperature: 0,
stream: true,
stream_options: { include_usage: true },
}),
)
})

it("handles R1 format for DeepSeek-R1 models", async () => {
const handler = new NebiusHandler({
...mockOptions,
nebiusModelId: "deepseek-ai/DeepSeek-R1",
})

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "test response" } }],
}
},
}

const mockCreate = jest.fn().mockResolvedValue(mockStream)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]

await handler.createMessage(systemPrompt, messages).next()

// Verify R1 format is used - the first message should combine system and user content
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "deepseek-ai/DeepSeek-R1",
messages: expect.arrayContaining([
expect.objectContaining({
role: "user",
content: expect.stringContaining("test system prompt"),
}),
]),
}),
)
})
})

describe("completePrompt", () => {
it("returns correct response", async () => {
const handler = new NebiusHandler(mockOptions)
const mockResponse = { choices: [{ message: { content: "test completion" } }] }

const mockCreate = jest.fn().mockResolvedValue(mockResponse)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

const result = await handler.completePrompt("test prompt")

expect(result).toBe("test completion")

expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.nebiusModelId,
max_tokens: 8192,
temperature: 0,
messages: [{ role: "user", content: "test prompt" }],
})
})

it("handles errors", async () => {
const handler = new NebiusHandler(mockOptions)
const mockError = new Error("API Error")

const mockCreate = jest.fn().mockRejectedValue(mockError)
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
completions: { create: mockCreate },
} as any

await expect(handler.completePrompt("test prompt")).rejects.toThrow("nebius completion error: API Error")
})
})
})
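The handler implementation itself is not part of this section, so the following is a minimal sketch reconstructed from the behavior the tests assert. The getModels call shape, the convertToR1Format helper, the streamed chunk shapes, and the R1 detection heuristic are assumptions, not the PR's actual code:

import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
import { getModels } from "./fetchers/modelCache" // module mocked in the tests above

const DEFAULT_BASE_URL = "https://api.studio.nebius.ai/v1" // default asserted by the tests
const DEFAULT_MODEL_ID = "Qwen/Qwen2.5-32B-Instruct-fast"

export class NebiusHandler {
	private client: OpenAI

	constructor(private options: { nebiusApiKey?: string; nebiusBaseUrl?: string; nebiusModelId?: string }) {
		this.client = new OpenAI({
			baseURL: options.nebiusBaseUrl ?? DEFAULT_BASE_URL,
			apiKey: options.nebiusApiKey,
		})
	}

	async fetchModel() {
		const models = await getModels() // call shape assumed; the mock ignores arguments
		const id = this.options.nebiusModelId ?? DEFAULT_MODEL_ID
		return { id, info: models[id] }
	}

	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]) {
		const { id } = await this.fetchModel()

		// The R1 path folds the system prompt into the first user turn, as the
		// DeepSeek-R1 test asserts; detection by model id is a guessed heuristic,
		// and convertToR1Format is a hypothetical helper standing in for the real one.
		const openAiMessages = id.includes("DeepSeek-R1")
			? convertToR1Format([{ role: "user" as const, content: systemPrompt }, ...messages])
			: [
					{ role: "system" as const, content: systemPrompt },
					// Sketch handles string-content messages only.
					...messages.map((m) => ({ role: m.role, content: m.content as string })),
				]

		const stream = await this.client.chat.completions.create({
			model: id,
			messages: openAiMessages,
			temperature: 0,
			stream: true,
			stream_options: { include_usage: true },
		})

		for await (const chunk of stream) {
			const delta = chunk.choices[0]?.delta
			if (delta?.content) yield { type: "text", text: delta.content }
			if (chunk.usage) {
				yield { type: "usage", inputTokens: chunk.usage.prompt_tokens, outputTokens: chunk.usage.completion_tokens }
			}
		}
	}

	async completePrompt(prompt: string): Promise<string> {
		const { id, info } = await this.fetchModel()
		try {
			const response = await this.client.chat.completions.create({
				model: id,
				max_tokens: info?.maxTokens, // 8192 for the Qwen model in the tests
				temperature: 0,
				messages: [{ role: "user" as const, content: prompt }],
			})
			return response.choices[0]?.message.content ?? ""
		} catch (error) {
			throw new Error(`nebius completion error: ${error instanceof Error ? error.message : String(error)}`)
		}
	}
}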