import { Anthropic } from "@anthropic-ai/sdk"
import Cerebras from "@cerebras/cerebras_cloud_sdk"
import { withRetry } from "../retry"
import { ApiHandlerOptions, ModelInfo, CerebrasModelId, cerebrasDefaultModelId, cerebrasModels } from "@shared/api"
import { ApiHandler } from "../index"
import { ApiStream } from "@api/transform/stream"

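// Streams chat completions from Cerebras's OpenAI-style inference API.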
export class CerebrasHandler implements ApiHandler {
	private options: ApiHandlerOptions
	private client: Cerebras

	constructor(options: ApiHandlerOptions) {
		this.options = options

		// Clean and validate the API key
		const cleanApiKey = this.options.cerebrasApiKey?.trim()

		if (!cleanApiKey) {
			throw new Error("Cerebras API key is required")
		}

		this.client = new Cerebras({
			apiKey: cleanApiKey,
			timeout: 30000, // 30 second timeout
		})
	}

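	// Streams the completion, yielding "reasoning" and "text" chunks plus a final "usage" record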
	@withRetry()
	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		// Cerebras takes OpenAI-style chat messages; seed the list with the system prompt
		const cerebrasMessages: Array<{
			role: "system" | "user" | "assistant"
			content: string
		}> = [{ role: "system", content: systemPrompt }]

		// Flatten Anthropic content blocks into plain strings
		for (const message of messages) {
			if (message.role === "user") {
				const content = Array.isArray(message.content)
					? message.content
							.map((block) => {
								if (block.type === "text") {
									return block.text
								} else if (block.type === "image") {
									return "[Image content not supported in Cerebras]"
								}
								return ""
							})
							.join("\n")
					: message.content
				cerebrasMessages.push({ role: "user", content })
			} else if (message.role === "assistant") {
				const content = Array.isArray(message.content)
					? message.content
							.map((block) => {
								if (block.type === "text") {
									return block.text
								}
								return ""
							})
							.join("\n")
					: message.content || ""
				cerebrasMessages.push({ role: "assistant", content })
			}
		}

		const stream = await this.client.chat.completions.create({
			model: this.getModel().id,
			messages: cerebrasMessages,
			temperature: 0,
			stream: true,
		})

		// Handle streaming response
		let reasoning: string | null = null // Track reasoning content for models that support thinking
		const modelId = this.getModel().id
		const isReasoningModel = modelId.includes("qwen") || modelId.includes("deepseek-r1-distill")

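		// Note: tag handling below is chunk-granular; a chunk that mixes plain
		// text with <think> markup is treated entirely as reasoning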
		for await (const chunk of stream as any) {
			// The SDK's chunk typing doesn't expose OpenAI-style delta fields, hence the cast above
			if (chunk.choices?.[0]?.delta?.content) {
				const content = chunk.choices[0].delta.content

				// Handle reasoning models (Qwen and DeepSeek R1 Distill) that use <think> tags
				if (isReasoningModel) {
					// Check if we're entering or continuing reasoning mode
					if (reasoning || content.includes("<think>")) {
						reasoning = (reasoning || "") + content

						// Strip the think tags so only displayable text is yielded
						const cleanContent = content.replace(/<think>/g, "").replace(/<\/think>/g, "")

						// Only yield reasoning content if something remains after cleaning
						if (cleanContent.trim()) {
							yield {
								type: "reasoning",
								reasoning: cleanContent,
							}
						}

						// The closing tag marks the end of the reasoning block
						if (reasoning.includes("</think>")) {
							reasoning = null
						}
					} else {
						// Regular content outside of thinking tags
						yield {
							type: "text",
							text: content,
						}
					}
				} else {
					// Non-reasoning models just yield text content
					yield {
						type: "text",
						text: content,
					}
				}
			}

			// Usage information is typically only present on the final chunk
			if (chunk.usage) {
				const totalCost = this.calculateCost({
					inputTokens: chunk.usage.prompt_tokens || 0,
					outputTokens: chunk.usage.completion_tokens || 0,
				})

				yield {
					type: "usage",
					inputTokens: chunk.usage.prompt_tokens || 0,
					outputTokens: chunk.usage.completion_tokens || 0,
					cacheReadTokens: 0, // prompt caching is not reported by this API
					cacheWriteTokens: 0,
					totalCost,
				}
			}
		}
	}

	getModel(): { id: string; info: ModelInfo } {
		const modelId = this.options.apiModelId
		if (modelId && modelId in cerebrasModels) {
			const id = modelId as CerebrasModelId
			return { id, info: cerebrasModels[id] }
		}
		return {
			id: cerebrasDefaultModelId,
			info: cerebrasModels[cerebrasDefaultModelId],
		}
	}

	private calculateCost({ inputTokens, outputTokens }: { inputTokens: number; outputTokens: number }): number {
		const model = this.getModel()
		// ModelInfo prices are quoted per million tokens
		const inputPrice = model.info.inputPrice || 0
		const outputPrice = model.info.outputPrice || 0

		const inputCost = (inputPrice / 1_000_000) * inputTokens
		const outputCost = (outputPrice / 1_000_000) * outputTokens

		return inputCost + outputCost
	}
}
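
// Usage sketch (illustrative; not part of the original commit). Assumes an
// ApiHandlerOptions value carrying cerebrasApiKey and apiModelId, as the
// constructor and getModel() above expect. The model id shown is hypothetical;
// any key of cerebrasModels works.
async function demoCerebrasStream(): Promise<void> {
	const handler = new CerebrasHandler({
		cerebrasApiKey: process.env.CEREBRAS_API_KEY,
		apiModelId: "llama3.1-8b", // hypothetical id; substitute a key of cerebrasModels
	} as ApiHandlerOptions)

	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hello!" },
	])) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		}
	}
}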