diff --git a/docs/configuration/models.mdx b/docs/configuration/models.mdx
index 83d06a974..7eebb8ae3 100644
--- a/docs/configuration/models.mdx
+++ b/docs/configuration/models.mdx
@@ -36,6 +36,7 @@
 OPENAI_API_KEY=your_openai_key_here
 ANTHROPIC_API_KEY=your_anthropic_key_here
 GOOGLE_API_KEY=your_google_key_here
 GROQ_API_KEY=your_groq_key_here
+CEREBRAS_API_KEY=your_cerebras_key_here
 ```
diff --git a/lib/index.ts b/lib/index.ts
index 10eecee98..4dc633f0d 100644
--- a/lib/index.ts
+++ b/lib/index.ts
@@ -1031,4 +1031,5 @@ export * from "../types/stagehandApiErrors";
 export * from "../types/stagehandErrors";
 export * from "./llm/LLMClient";
 export * from "./llm/aisdk";
+export { CerebrasClient } from "./llm/CerebrasClient";
 export { connectToMCPServer };
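With the key in place, a Cerebras-hosted model can be selected through the normal Stagehand constructor. A minimal sketch (not part of this diff), assuming the standard `modelName` / `modelClientOptions` options:

```ts
import { Stagehand } from "@browserbasehq/stagehand";

// Illustrative only: pass the key explicitly, or rely on the
// CEREBRAS_API_KEY environment variable configured above.
const stagehand = new Stagehand({
  env: "LOCAL",
  modelName: "cerebras-llama-3.3-70b",
  modelClientOptions: { apiKey: process.env.CEREBRAS_API_KEY },
});

await stagehand.init();
```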
diff --git a/lib/llm/CerebrasClient.ts b/lib/llm/CerebrasClient.ts
index 4b12c0380..aee89b101 100644
--- a/lib/llm/CerebrasClient.ts
+++ b/lib/llm/CerebrasClient.ts
@@ -1,24 +1,13 @@
-import OpenAI from "openai";
-import type { ClientOptions } from "openai";
-import { zodToJsonSchema } from "zod-to-json-schema";
 import { LogLine } from "../../types/log";
 import { AvailableModel } from "../../types/model";
 import { LLMCache } from "../cache/LLMCache";
-import {
-  ChatMessage,
-  CreateChatCompletionOptions,
-  LLMClient,
-  LLMResponse,
-} from "./LLMClient";
-import { CreateChatCompletionResponseError } from "@/types/stagehandErrors";
+import { OpenAIClient } from "./OpenAIClient";
+import { LLMClient, CreateChatCompletionOptions, LLMResponse } from "./LLMClient";
 
 export class CerebrasClient extends LLMClient {
   public type = "cerebras" as const;
-  private client: OpenAI;
-  private cache: LLMCache | undefined;
-  private enableCaching: boolean;
-  public clientOptions: ClientOptions;
   public hasVision = false;
+  private openaiClient: OpenAIClient;
 
   constructor({
     enableCaching = false,
     cache,
     modelName,
     clientOptions,
     userProvidedInstructions,
   }: {
@@ -31,313 +20,43 @@ export class CerebrasClient extends LLMClient {
     enableCaching?: boolean;
     cache?: LLMCache;
     modelName: AvailableModel;
-    clientOptions?: ClientOptions;
+    clientOptions?: any;
     userProvidedInstructions?: string;
   }) {
     super(modelName, userProvidedInstructions);
-
-    // Create OpenAI client with the base URL set to Cerebras API
-    this.client = new OpenAI({
-      baseURL: "https://api.cerebras.ai/v1",
-      apiKey: clientOptions?.apiKey || process.env.CEREBRAS_API_KEY,
-      ...clientOptions,
-    });
-
-    this.cache = cache;
-    this.enableCaching = enableCaching;
-    this.modelName = modelName;
-    this.clientOptions = clientOptions;
-  }
-
-  async createChatCompletion<T = LLMResponse>({
-    options,
-    retries,
-    logger,
-  }: CreateChatCompletionOptions): Promise<T> {
-    const optionsWithoutImage = { ...options };
-    delete optionsWithoutImage.image;
-
-    logger({
-      category: "cerebras",
-      message: "creating chat completion",
-      level: 2,
-      auxiliary: {
-        options: {
-          value: JSON.stringify(optionsWithoutImage),
-          type: "object",
-        },
-      },
-    });
-
-    // Try to get cached response
-    const cacheOptions = {
-      model: this.modelName.split("cerebras-")[1],
-      messages: options.messages,
-      temperature: options.temperature,
-      response_model: options.response_model,
-      tools: options.tools,
-      retries: retries,
-    };
-
-    if (this.enableCaching) {
-      const cachedResponse = await this.cache.get<T>(
-        cacheOptions,
-        options.requestId,
-      );
-      if (cachedResponse) {
-        logger({
-          category: "llm_cache",
-          message: "LLM cache hit - returning cached response",
-          level: 1,
-          auxiliary: {
-            cachedResponse: {
-              value: JSON.stringify(cachedResponse),
-              type: "object",
-            },
-            requestId: {
-              value: options.requestId,
-              type: "string",
-            },
-            cacheOptions: {
-              value: JSON.stringify(cacheOptions),
-              type: "object",
-            },
-          },
-        });
-        return cachedResponse as T;
-      }
-    }
-
-    // Format messages for Cerebras API (using OpenAI format)
-    const formattedMessages = options.messages.map((msg: ChatMessage) => {
-      const baseMessage = {
-        content:
-          typeof msg.content === "string"
-            ? msg.content
-            : Array.isArray(msg.content) &&
-                msg.content.length > 0 &&
-                "text" in msg.content[0]
-              ? msg.content[0].text
-              : "",
-      };
-
-      // Cerebras only supports system, user, and assistant roles
-      if (msg.role === "system") {
-        return { ...baseMessage, role: "system" as const };
-      } else if (msg.role === "assistant") {
-        return { ...baseMessage, role: "assistant" as const };
-      } else {
-        // Default to user for any other role
-        return { ...baseMessage, role: "user" as const };
-      }
-    });
-
-    // Format tools if provided
-    let tools = options.tools?.map((tool) => ({
-      type: "function" as const,
-      function: {
-        name: tool.name,
-        description: tool.description,
-        parameters: {
-          type: "object",
-          properties: tool.parameters.properties,
-          required: tool.parameters.required,
-        },
-      },
-    }));
-
-    // Add response model as a tool if provided
-    if (options.response_model) {
-      const jsonSchema = zodToJsonSchema(options.response_model.schema) as {
-        properties?: Record<string, unknown>;
-        required?: string[];
-      };
-      const schemaProperties = jsonSchema.properties || {};
-      const schemaRequired = jsonSchema.required || [];
-
-      const responseTool = {
-        type: "function" as const,
-        function: {
-          name: "print_extracted_data",
-          description:
-            "Prints the extracted data based on the provided schema.",
-          parameters: {
-            type: "object",
-            properties: schemaProperties,
-            required: schemaRequired,
-          },
-        },
-      };
-
-      tools = tools ? [...tools, responseTool] : [responseTool];
-    }
-
-    try {
-      // Use OpenAI client with Cerebras API
-      const apiResponse = await this.client.chat.completions.create({
-        model: this.modelName.split("cerebras-")[1],
-        messages: [
-          ...formattedMessages,
-          // Add explicit instruction to return JSON if we have a response model
-          ...(options.response_model
-            ? [
-                {
-                  role: "system" as const,
-                  content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(
-                    options.response_model.schema,
-                  )}`,
-                },
-              ]
-            : []),
-        ],
-        temperature: options.temperature || 0.7,
-        max_tokens: options.maxTokens,
-        tools: tools,
-        tool_choice: options.tool_choice || "auto",
-      });
-
-      // Format the response to match the expected LLMResponse format
-      const response: LLMResponse = {
-        id: apiResponse.id,
-        object: "chat.completion",
-        created: Date.now(),
-        model: this.modelName.split("cerebras-")[1],
-        choices: [
-          {
-            index: 0,
-            message: {
-              role: "assistant",
-              content: apiResponse.choices[0]?.message?.content || null,
-              tool_calls: apiResponse.choices[0]?.message?.tool_calls || [],
-            },
-            finish_reason: apiResponse.choices[0]?.finish_reason || "stop",
-          },
-        ],
-        usage: {
-          prompt_tokens: apiResponse.usage?.prompt_tokens || 0,
-          completion_tokens: apiResponse.usage?.completion_tokens || 0,
-          total_tokens: apiResponse.usage?.total_tokens || 0,
-        },
-      };
-
-      logger({
-        category: "cerebras",
-        message: "response",
-        level: 2,
-        auxiliary: {
-          response: {
-            value: JSON.stringify(response),
-            type: "object",
-          },
-          requestId: {
-            value: options.requestId,
-            type: "string",
-          },
-        },
-      });
-
-      // If we have no response model, just return the entire LLMResponse
-      if (!options.response_model) {
-        if (this.enableCaching) {
-          await this.cache.set(cacheOptions, response, options.requestId);
-        }
-        return response as T;
-      }
-
-      // If we have a response model, parse JSON from tool calls or content
-      const toolCall = response.choices[0]?.message?.tool_calls?.[0];
-      if (toolCall?.function?.arguments) {
-        try {
-          const result = JSON.parse(toolCall.function.arguments);
-          const finalResponse = {
-            data: result,
-            usage: response.usage,
-          };
-          if (this.enableCaching) {
-            await this.cache.set(
-              cacheOptions,
-              finalResponse,
-              options.requestId,
-            );
-          }
-          return finalResponse as T;
-        } catch (e) {
-          logger({
-            category: "cerebras",
-            message: "failed to parse tool call arguments as JSON, retrying",
-            level: 0,
-            auxiliary: {
-              error: {
-                value: e.message,
-                type: "string",
-              },
-            },
-          });
-        }
-      }
-
-      // If we have content but no tool calls, try to parse the content as JSON
-      const content = response.choices[0]?.message?.content;
-      if (content) {
-        try {
-          const jsonMatch = content.match(/\{[\s\S]*\}/);
-          if (jsonMatch) {
-            const result = JSON.parse(jsonMatch[0]);
-            const finalResponse = {
-              data: result,
-              usage: response.usage,
-            };
-            if (this.enableCaching) {
-              await this.cache.set(
-                cacheOptions,
-                finalResponse,
-                options.requestId,
-              );
-            }
-            return finalResponse as T;
-          }
-        } catch (e) {
-          logger({
-            category: "cerebras",
-            message: "failed to parse content as JSON",
-            level: 0,
-            auxiliary: {
-              error: {
-                value: e.message,
-                type: "string",
-              },
-            },
-          });
-        }
-      }
-
-      // If we still haven't found valid JSON and have retries left, try again
-      if (!retries || retries < 5) {
-        return this.createChatCompletion<T>({
-          options,
-          logger,
-          retries: (retries ?? 0) + 1,
-        });
-      }
-
-      throw new CreateChatCompletionResponseError("Invalid response schema");
-    } catch (error) {
-      logger({
-        category: "cerebras",
-        message: "error creating chat completion",
-        level: 0,
-        auxiliary: {
-          error: {
-            value: error.message,
-            type: "string",
-          },
-          requestId: {
-            value: options.requestId,
-            type: "string",
-          },
-        },
-      });
-      throw error;
-    }
-  }
+
+    // Transform the model name to remove the cerebras- prefix
+    const openaiModelName = modelName.startsWith("cerebras-")
+      ? modelName.split("cerebras-")[1]
+      : modelName;
+
+    this.openaiClient = new OpenAIClient({
+      enableCaching,
+      cache,
+      modelName: openaiModelName as AvailableModel,
+      clientOptions: {
+        baseURL: "https://api.cerebras.ai/v1",
+        defaultHeaders: {
+          apikey: clientOptions?.apiKey || process.env.CEREBRAS_API_KEY,
+        },
+        ...clientOptions,
+      },
+      logger: (message: LogLine) => {
+        // Transform log messages to use the cerebras category
+        const transformedMessage = {
+          ...message,
+          category: message.category === "openai" ? "cerebras" : message.category,
+        };
+        // Call the original logger if it exists
+        if (typeof (this as any).logger === "function") {
+          (this as any).logger(transformedMessage);
+        }
+      },
+    });
+  }
+
+  async createChatCompletion<T = LLMResponse>(
+    options: CreateChatCompletionOptions,
+  ): Promise<T> {
+    return this.openaiClient.createChatCompletion<T>(options);
+  }
 }
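The rewrite above turns `CerebrasClient` into a thin wrapper: because Cerebras exposes an OpenAI-compatible API, message formatting, caching, and structured-output handling stay in `OpenAIClient`, and the wrapper only rewrites the model name, base URL, credentials, and log category. Since `lib/index.ts` now exports the class, it can also be constructed directly. A sketch following the constructor shape above; importing from the package root assumes the new export ships in the published entry point:

```ts
import { CerebrasClient } from "@browserbasehq/stagehand";

// Hypothetical direct construction; cache and enableCaching remain optional.
const client = new CerebrasClient({
  modelName: "cerebras-llama-3.3-70b",
  clientOptions: { apiKey: process.env.CEREBRAS_API_KEY },
});
```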
diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts
index f99d6edf8..b98ae4d74 100644
--- a/lib/llm/LLMProvider.ts
+++ b/lib/llm/LLMProvider.ts
@@ -82,6 +82,13 @@ const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
   "claude-3-7-sonnet-latest": "anthropic",
   "cerebras-llama-3.3-70b": "cerebras",
   "cerebras-llama-3.1-8b": "cerebras",
+  "cerebras-gpt-oss-120b": "cerebras",
+  "cerebras-llama-4-maverick-17b-128e-instruct": "cerebras",
+  "cerebras-llama-4-scout-17b-16e-instruct": "cerebras",
+  "cerebras-qwen-3-235b-a22b-instruct-2507": "cerebras",
+  "cerebras-qwen-3-235b-a22b-thinking-2507": "cerebras",
+  "cerebras-qwen-3-32b": "cerebras",
+  "cerebras-qwen-3-coder-480b": "cerebras",
   "groq-llama-3.3-70b-versatile": "groq",
   "groq-llama-3.3-70b-specdec": "groq",
   "moonshotai/kimi-k2-instruct": "groq",
diff --git a/types/model.ts b/types/model.ts
index beb960a17..19817dea8 100644
--- a/types/model.ts
+++ b/types/model.ts
@@ -23,6 +23,13 @@ export const AvailableModelSchema = z.enum([
   "claude-3-7-sonnet-20250219",
   "cerebras-llama-3.3-70b",
   "cerebras-llama-3.1-8b",
+  "cerebras-gpt-oss-120b",
+  "cerebras-llama-4-maverick-17b-128e-instruct",
+  "cerebras-llama-4-scout-17b-16e-instruct",
+  "cerebras-qwen-3-235b-a22b-instruct-2507",
+  "cerebras-qwen-3-235b-a22b-thinking-2507",
+  "cerebras-qwen-3-32b",
+  "cerebras-qwen-3-coder-480b",
   "groq-llama-3.3-70b-versatile",
   "groq-llama-3.3-70b-specdec",
   "gemini-1.5-flash",
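The two hunks above must stay in sync: `AvailableModelSchema` validates the model string and `modelToProviderMap` routes it to a provider, so each of the seven new IDs appears in both. A quick check, assuming `AvailableModelSchema` is re-exported from the package root like the other types:

```ts
import { AvailableModelSchema } from "@browserbasehq/stagehand";

// Any of the seven new IDs should now parse; unknown IDs throw a ZodError.
const model = AvailableModelSchema.parse("cerebras-qwen-3-coder-480b");
console.log(model); // "cerebras-qwen-3-coder-480b"
```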