1 change: 1 addition & 0 deletions docs/configuration/models.mdx
@@ -36,6 +36,7 @@ OPENAI_API_KEY=your_openai_key_here
ANTHROPIC_API_KEY=your_anthropic_key_here
GOOGLE_API_KEY=your_google_key_here
GROQ_API_KEY=your_groq_key_here
CEREBRAS_API_KEY=your_cerebras_key_here
```
</CodeGroup>

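With the new `CEREBRAS_API_KEY` entry above, wiring a Cerebras-hosted model into Stagehand should look roughly like the sketch below. This is a hedged example, not taken from the PR: the `modelName` value and the use of `modelClientOptions` to pass the key are assumptions based on how the other providers are configured.

```ts
// Minimal sketch (assumptions noted inline); not part of this PR.
import { Stagehand } from "@browserbasehq/stagehand";

const stagehand = new Stagehand({
  env: "LOCAL",
  // Assumed identifier: a model name carrying the "cerebras-" prefix that
  // CerebrasClient strips before delegating to the OpenAI-compatible API.
  modelName: "cerebras-llama-3.3-70b",
  modelClientOptions: {
    apiKey: process.env.CEREBRAS_API_KEY, // matches the env var documented above
  },
});

await stagehand.init();
```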
1 change: 1 addition & 0 deletions lib/index.ts
@@ -1031,4 +1031,5 @@ export * from "../types/stagehandApiErrors";
export * from "../types/stagehandErrors";
export * from "./llm/LLMClient";
export * from "./llm/aisdk";
export { CerebrasClient } from "./llm/CerebrasClient";
export { connectToMCPServer };
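With the export added above, `CerebrasClient` becomes part of the library's public surface. Assuming `lib/index.ts` is what the package publishes as its entry point, a consumer should be able to import it from the package root rather than a deep path:

```ts
// Sketch only, assuming lib/index.ts is the published entry point.
import { CerebrasClient } from "@browserbasehq/stagehand";
```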
349 changes: 34 additions & 315 deletions lib/llm/CerebrasClient.ts
@@ -1,24 +1,13 @@
import OpenAI from "openai";
import type { ClientOptions } from "openai";
import { zodToJsonSchema } from "zod-to-json-schema";
import { LogLine } from "../../types/log";
import { AvailableModel } from "../../types/model";
import { LLMCache } from "../cache/LLMCache";
import {
ChatMessage,
CreateChatCompletionOptions,
LLMClient,
LLMResponse,
} from "./LLMClient";
import { CreateChatCompletionResponseError } from "@/types/stagehandErrors";
import { OpenAIClient } from "./OpenAIClient";
import { LLMClient, CreateChatCompletionOptions, LLMResponse } from "./LLMClient";

export class CerebrasClient extends LLMClient {
public type = "cerebras" as const;
private client: OpenAI;
private cache: LLMCache | undefined;
private enableCaching: boolean;
public clientOptions: ClientOptions;
public hasVision = false;
private openaiClient: OpenAIClient;

constructor({
enableCaching = false,
@@ -31,313 +20,43 @@ export class CerebrasClient extends LLMClient {
enableCaching?: boolean;
cache?: LLMCache;
modelName: AvailableModel;
clientOptions?: ClientOptions;
clientOptions?: any;
userProvidedInstructions?: string;
}) {
super(modelName, userProvidedInstructions);

// Create OpenAI client with the base URL set to Cerebras API
this.client = new OpenAI({
baseURL: "https://api.cerebras.ai/v1",
apiKey: clientOptions?.apiKey || process.env.CEREBRAS_API_KEY,
...clientOptions,
});

this.cache = cache;
this.enableCaching = enableCaching;
this.modelName = modelName;
this.clientOptions = clientOptions;
}

async createChatCompletion<T = LLMResponse>({
options,
retries,
logger,
}: CreateChatCompletionOptions): Promise<T> {
const optionsWithoutImage = { ...options };
delete optionsWithoutImage.image;

logger({
category: "cerebras",
message: "creating chat completion",
level: 2,
auxiliary: {
options: {
value: JSON.stringify(optionsWithoutImage),
type: "object",
},
},
});

// Try to get cached response
const cacheOptions = {
model: this.modelName.split("cerebras-")[1],
messages: options.messages,
temperature: options.temperature,
response_model: options.response_model,
tools: options.tools,
retries: retries,
};

if (this.enableCaching) {
const cachedResponse = await this.cache.get<T>(
cacheOptions,
options.requestId,
);
if (cachedResponse) {
logger({
category: "llm_cache",
message: "LLM cache hit - returning cached response",
level: 1,
auxiliary: {
cachedResponse: {
value: JSON.stringify(cachedResponse),
type: "object",
},
requestId: {
value: options.requestId,
type: "string",
},
cacheOptions: {
value: JSON.stringify(cacheOptions),
type: "object",
},
},
});
return cachedResponse as T;
}
}

// Format messages for Cerebras API (using OpenAI format)
const formattedMessages = options.messages.map((msg: ChatMessage) => {
const baseMessage = {
content:
typeof msg.content === "string"
? msg.content
: Array.isArray(msg.content) &&
msg.content.length > 0 &&
"text" in msg.content[0]
? msg.content[0].text
: "",
};

// Cerebras only supports system, user, and assistant roles
if (msg.role === "system") {
return { ...baseMessage, role: "system" as const };
} else if (msg.role === "assistant") {
return { ...baseMessage, role: "assistant" as const };
} else {
// Default to user for any other role
return { ...baseMessage, role: "user" as const };
}
});

// Format tools if provided
let tools = options.tools?.map((tool) => ({
type: "function" as const,
function: {
name: tool.name,
description: tool.description,
parameters: {
type: "object",
properties: tool.parameters.properties,
required: tool.parameters.required,

// Transform model name to remove cerebras- prefix
const openaiModelName = modelName.startsWith("cerebras-")
? modelName.split("cerebras-")[1]
: modelName;

this.openaiClient = new OpenAIClient({
enableCaching,
cache,
modelName: openaiModelName as AvailableModel,
clientOptions: {
baseURL: "https://api.cerebras.ai/v1",
defaultHeaders: {
apikey: clientOptions?.apiKey || process.env.CEREBRAS_API_KEY,
},
...clientOptions,
},
}));

// Add response model as a tool if provided
if (options.response_model) {
const jsonSchema = zodToJsonSchema(options.response_model.schema) as {
properties?: Record<string, unknown>;
required?: string[];
};
const schemaProperties = jsonSchema.properties || {};
const schemaRequired = jsonSchema.required || [];

const responseTool = {
type: "function" as const,
function: {
name: "print_extracted_data",
description:
"Prints the extracted data based on the provided schema.",
parameters: {
type: "object",
properties: schemaProperties,
required: schemaRequired,
},
},
};

tools = tools ? [...tools, responseTool] : [responseTool];
}

try {
// Use OpenAI client with Cerebras API
const apiResponse = await this.client.chat.completions.create({
model: this.modelName.split("cerebras-")[1],
messages: [
...formattedMessages,
// Add explicit instruction to return JSON if we have a response model
...(options.response_model
? [
{
role: "system" as const,
content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(
options.response_model.schema,
)}`,
},
]
: []),
],
temperature: options.temperature || 0.7,
max_tokens: options.maxTokens,
tools: tools,
tool_choice: options.tool_choice || "auto",
});

// Format the response to match the expected LLMResponse format
const response: LLMResponse = {
id: apiResponse.id,
object: "chat.completion",
created: Date.now(),
model: this.modelName.split("cerebras-")[1],
choices: [
{
index: 0,
message: {
role: "assistant",
content: apiResponse.choices[0]?.message?.content || null,
tool_calls: apiResponse.choices[0]?.message?.tool_calls || [],
},
finish_reason: apiResponse.choices[0]?.finish_reason || "stop",
},
],
usage: {
prompt_tokens: apiResponse.usage?.prompt_tokens || 0,
completion_tokens: apiResponse.usage?.completion_tokens || 0,
total_tokens: apiResponse.usage?.total_tokens || 0,
},
};

logger({
category: "cerebras",
message: "response",
level: 2,
auxiliary: {
response: {
value: JSON.stringify(response),
type: "object",
},
requestId: {
value: options.requestId,
type: "string",
},
},
});

// If we have no response model, just return the entire LLMResponse
if (!options.response_model) {
if (this.enableCaching) {
await this.cache.set(cacheOptions, response, options.requestId);
}
return response as T;
}

// If we have a response model, parse JSON from tool calls or content
const toolCall = response.choices[0]?.message?.tool_calls?.[0];
if (toolCall?.function?.arguments) {
try {
const result = JSON.parse(toolCall.function.arguments);
const finalResponse = {
data: result,
usage: response.usage,
};
if (this.enableCaching) {
await this.cache.set(
cacheOptions,
finalResponse,
options.requestId,
);
}
return finalResponse as T;
} catch (e) {
logger({
category: "cerebras",
message: "failed to parse tool call arguments as JSON, retrying",
level: 0,
auxiliary: {
error: {
value: e.message,
type: "string",
},
},
});
logger: (message: LogLine) => {
// Transform log messages to use cerebras category
const transformedMessage = {
...message,
category: message.category === "openai" ? "cerebras" : message.category,
};
// Call the original logger if it exists
if (typeof (this as any).logger === 'function') {
(this as any).logger(transformedMessage);
}
}

// If we have content but no tool calls, try to parse the content as JSON
const content = response.choices[0]?.message?.content;
if (content) {
try {
const jsonMatch = content.match(/\{[\s\S]*\}/);
if (jsonMatch) {
const result = JSON.parse(jsonMatch[0]);
const finalResponse = {
data: result,
usage: response.usage,
};
if (this.enableCaching) {
await this.cache.set(
cacheOptions,
finalResponse,
options.requestId,
);
}
return finalResponse as T;
}
} catch (e) {
logger({
category: "cerebras",
message: "failed to parse content as JSON",
level: 0,
auxiliary: {
error: {
value: e.message,
type: "string",
},
},
});
}
}

// If we still haven't found valid JSON and have retries left, try again
if (!retries || retries < 5) {
return this.createChatCompletion({
options,
logger,
retries: (retries ?? 0) + 1,
});
}
},
});
}

throw new CreateChatCompletionResponseError("Invalid response schema");
} catch (error) {
logger({
category: "cerebras",
message: "error creating chat completion",
level: 0,
auxiliary: {
error: {
value: error.message,
type: "string",
},
requestId: {
value: options.requestId,
type: "string",
},
},
});
throw error;
}
async createChatCompletion<T = LLMResponse>(options: CreateChatCompletionOptions): Promise<T> {
return this.openaiClient.createChatCompletion<T>(options);
}

}
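The refactor reduces `CerebrasClient` to a thin wrapper: the constructor strips the `cerebras-` prefix from the model name, points an `OpenAIClient` at `https://api.cerebras.ai/v1`, rewrites `openai` log categories to `cerebras`, and `createChatCompletion` simply delegates. Below is a hedged usage sketch against the new signature; the option field names follow the shapes visible in the diff above, while the model id, request id, and logger are illustrative assumptions.

```ts
// Illustrative call through the wrapper; not part of the PR.
import { z } from "zod";
import { CerebrasClient } from "./lib/llm/CerebrasClient";
import { AvailableModel } from "./types/model";

const client = new CerebrasClient({
  // Hypothetical model id; cast only to satisfy the AvailableModel union in this sketch.
  modelName: "cerebras-llama-3.3-70b" as AvailableModel,
  clientOptions: { apiKey: process.env.CEREBRAS_API_KEY },
});

const result = await client.createChatCompletion({
  options: {
    messages: [{ role: "user", content: "Extract the title from: <h1>Hello</h1>" }],
    // response_model mirrors the shape used in the diff above: a name plus a
    // zod schema describing the JSON the model should return.
    response_model: { name: "extraction", schema: z.object({ title: z.string() }) },
    requestId: "example-request",
  },
  logger: (line) => console.log(line.message),
  retries: 0,
});
```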