Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/global-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ export const SECRET_STATE_KEYS = [
"openAiApiKey",
"geminiApiKey",
"openAiNativeApiKey",
"cerebrasApiKey",
"deepSeekApiKey",
"moonshotApiKey",
"mistralApiKey",
Expand Down
7 changes: 7 additions & 0 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export const providerNames = [
"chutes",
"litellm",
"huggingface",
"cerebras",
] as const

export const providerNamesSchema = z.enum(providerNames)
Expand Down Expand Up @@ -241,6 +242,10 @@ const litellmSchema = baseProviderSettingsSchema.extend({
litellmUsePromptCache: z.boolean().optional(),
})

// Provider settings for Cerebras: a model id (inherited from
// apiModelIdProviderModelSchema) plus the provider API key.
const cerebrasSchema = apiModelIdProviderModelSchema.extend({
	cerebrasApiKey: z.string().optional(),
})

// Fallback branch of the discriminated union: a configuration with no
// apiProvider selected.
const defaultSchema = z.object({
	apiProvider: z.undefined(),
})
Expand Down Expand Up @@ -271,6 +276,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
huggingFaceSchema.merge(z.object({ apiProvider: z.literal("huggingface") })),
chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
cerebrasSchema.merge(z.object({ apiProvider: z.literal("cerebras") })),
defaultSchema,
])

Expand Down Expand Up @@ -301,6 +307,7 @@ export const providerSettingsSchema = z.object({
...huggingFaceSchema.shape,
...chutesSchema.shape,
...litellmSchema.shape,
...cerebrasSchema.shape,
...codebaseIndexProviderSchema.shape,
})

Expand Down
46 changes: 46 additions & 0 deletions packages/types/src/providers/cerebras.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import type { ModelInfo } from "../model.js"

// https://inference-docs.cerebras.ai/api-reference/chat-completions
// Union of the model ids declared in `cerebrasModels` below.
export type CerebrasModelId = keyof typeof cerebrasModels

// Model selected when the user has not chosen a specific Cerebras model.
export const cerebrasDefaultModelId: CerebrasModelId = "qwen-3-235b-a22b-instruct-2507"

// Static catalog of Cerebras-hosted models and their capabilities.
// NOTE(review): input/output prices are all 0 — presumably cost tracking is
// handled elsewhere or the provider is free-tier; confirm before relying on
// cost calculations for this provider.
export const cerebrasModels = {
	"llama-3.3-70b": {
		// NOTE(review): maxTokens equals contextWindow for every entry, leaving
		// no headroom for the prompt — confirm this is intentional.
		maxTokens: 64000,
		contextWindow: 64000,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "Smart model with ~2600 tokens/s",
	},
	"qwen-3-32b": {
		maxTokens: 64000,
		contextWindow: 64000,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "SOTA coding performance with ~2500 tokens/s",
	},
	"qwen-3-235b-a22b": {
		maxTokens: 40000,
		contextWindow: 40000,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "SOTA performance with ~1400 tokens/s",
	},
	"qwen-3-235b-a22b-instruct-2507": {
		maxTokens: 64000,
		contextWindow: 64000,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "SOTA performance with ~1400 tokens/s",
		// Only entry that advertises reasoning-effort support.
		supportsReasoningEffort: true,
	},
} as const satisfies Record<string, ModelInfo>
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export * from "./anthropic.js"
export * from "./bedrock.js"
export * from "./cerebras.js"
export * from "./chutes.js"
export * from "./claude-code.js"
export * from "./deepseek.js"
Expand Down
3 changes: 3 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
GlamaHandler,
AnthropicHandler,
AwsBedrockHandler,
CerebrasHandler,
OpenRouterHandler,
VertexHandler,
AnthropicVertexHandler,
Expand Down Expand Up @@ -115,6 +116,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new ChutesHandler(options)
case "litellm":
return new LiteLLMHandler(options)
case "cerebras":
return new CerebrasHandler(options)
default:
apiProvider satisfies "gemini-cli" | undefined
return new AnthropicHandler(options)
Expand Down
152 changes: 152 additions & 0 deletions src/api/providers/__tests__/cerebras.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { CerebrasHandler } from "../cerebras"
import { cerebrasModels, type CerebrasModelId } from "@roo-code/types"

// Mock fetch globally
global.fetch = vi.fn()

// Unit tests for CerebrasHandler: construction, model resolution, request
// shape, error propagation, temperature clamping, and prompt completion.
describe("CerebrasHandler", () => {
	let handler: CerebrasHandler
	const mockOptions = {
		cerebrasApiKey: "test-api-key",
		apiModelId: "llama-3.3-70b" as CerebrasModelId,
	}

	beforeEach(() => {
		vi.clearAllMocks()
		handler = new CerebrasHandler(mockOptions)
	})

	describe("constructor", () => {
		it("should throw error when API key is missing", () => {
			expect(() => new CerebrasHandler({ cerebrasApiKey: "" })).toThrow("Cerebras API key is required")
		})

		it("should initialize with valid API key", () => {
			expect(() => new CerebrasHandler(mockOptions)).not.toThrow()
		})
	})

	describe("getModel", () => {
		it("should return correct model info", () => {
			const { id, info } = handler.getModel()
			expect(id).toBe("llama-3.3-70b")
			expect(info).toEqual(cerebrasModels["llama-3.3-70b"])
		})

		it("should fallback to default model when apiModelId is not provided", () => {
			const handlerWithoutModel = new CerebrasHandler({ cerebrasApiKey: "test" })
			const { id } = handlerWithoutModel.getModel()
			expect(id).toBe("qwen-3-235b-a22b-instruct-2507") // cerebrasDefaultModelId
		})
	})

	describe("message conversion", () => {
		// Declared with it.todo so the runner reports these as pending work
		// instead of green-lighting empty test bodies (which pass vacuously).
		// Would test the stripThinkingTokens regex behavior.
		it.todo("should strip thinking tokens from assistant messages")
		// Would test flattenMessageContent across strings, arrays, image objects.
		it.todo("should flatten complex message content to strings")
		// Would test convertToCerebrasMessages role/content structure.
		it.todo("should convert OpenAI messages to Cerebras format")
	})

	describe("createMessage", () => {
		it("should make correct API request", async () => {
			// Mock successful API response
			const mockResponse = {
				ok: true,
				body: {
					getReader: () => ({
						read: vi.fn().mockResolvedValueOnce({ done: true, value: new Uint8Array() }),
						releaseLock: vi.fn(),
					}),
				},
			}
			vi.mocked(fetch).mockResolvedValueOnce(mockResponse as any)

			// createMessage returns an async generator; its body (and therefore
			// the fetch call) does not execute until the stream is consumed, so
			// pull one chunk before asserting on fetch.
			const generator = handler.createMessage("System prompt", [])
			await generator.next()

			expect(fetch).toHaveBeenCalledWith(
				"https://api.cerebras.ai/v1/chat/completions",
				expect.objectContaining({
					method: "POST",
					headers: expect.objectContaining({
						"Content-Type": "application/json",
						Authorization: "Bearer test-api-key",
						"User-Agent": "roo-cline/1.0.0",
					}),
				}),
			)
		})

		it("should handle API errors properly", async () => {
			const mockErrorResponse = {
				ok: false,
				status: 400,
				text: () => Promise.resolve('{"error": "Bad Request"}'),
			}
			vi.mocked(fetch).mockResolvedValueOnce(mockErrorResponse as any)

			// The rejection surfaces on first consumption of the generator.
			const generator = handler.createMessage("System prompt", [])
			await expect(generator.next()).rejects.toThrow("Cerebras API Error: 400")
		})

		// Pending: mock a ReadableStream with multiple data chunks and verify
		// thinking-token extraction and usage tracking.
		it.todo("should parse streaming responses correctly")

		it("should handle temperature clamping", async () => {
			const handlerWithTemp = new CerebrasHandler({
				...mockOptions,
				modelTemperature: 2.0, // Above Cerebras max of 1.5
			})

			vi.mocked(fetch).mockResolvedValueOnce({
				ok: true,
				body: { getReader: () => ({ read: () => Promise.resolve({ done: true }), releaseLock: vi.fn() }) },
			} as any)

			await handlerWithTemp.createMessage("test", []).next()

			const requestBody = JSON.parse(vi.mocked(fetch).mock.calls[0][1]?.body as string)
			expect(requestBody.temperature).toBe(1.5) // Should be clamped
		})
	})

	describe("completePrompt", () => {
		it("should handle non-streaming completion", async () => {
			const mockResponse = {
				ok: true,
				json: () =>
					Promise.resolve({
						choices: [{ message: { content: "Test response" } }],
					}),
			}
			vi.mocked(fetch).mockResolvedValueOnce(mockResponse as any)

			const result = await handler.completePrompt("Test prompt")
			expect(result).toBe("Test response")
		})
	})

	describe("token usage and cost calculation", () => {
		// Pending: assert lastUsage updates and cost is derived from actual usage.
		it.todo("should track token usage properly")
		// Pending: exercise the fallback token-estimation path.
		it.todo("should provide usage estimates when API doesn't return usage")
	})
})
Loading
Loading