1 change: 1 addition & 0 deletions docs/configuration/models.mdx
@@ -36,6 +36,7 @@ OPENAI_API_KEY=your_openai_key_here
ANTHROPIC_API_KEY=your_anthropic_key_here
GOOGLE_API_KEY=your_google_key_here
GROQ_API_KEY=your_groq_key_here
CEREBRAS_API_KEY=your_cerebras_key_here
```
</CodeGroup>

1 change: 1 addition & 0 deletions lib/index.ts
@@ -1031,4 +1031,5 @@ export * from "../types/stagehandApiErrors";
export * from "../types/stagehandErrors";
export * from "./llm/LLMClient";
export * from "./llm/aisdk";
export { CerebrasClient } from "./llm/CerebrasClient";
export { connectToMCPServer };
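
With `CerebrasClient` now re-exported from the library entry point, a consumer can construct it directly. The following is a minimal sketch based only on the constructor signature visible in the diff below; the import path and model id are placeholders, not taken from this PR.

```ts
// Minimal usage sketch; import path and model id are placeholder assumptions.
import { CerebrasClient } from "./lib"; // re-exported via lib/index.ts above
import type { AvailableModel } from "./types/model";

const cerebras = new CerebrasClient({
  // Any "cerebras-"-prefixed model id; the prefix is stripped before the
  // request is forwarded (see CerebrasClient.ts below).
  modelName: "cerebras-llama-3.3-70b" as AvailableModel,
  clientOptions: {
    // If omitted, the client falls back to process.env.CEREBRAS_API_KEY.
    apiKey: process.env.CEREBRAS_API_KEY,
  },
});
```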
349 changes: 34 additions & 315 deletions lib/llm/CerebrasClient.ts
@@ -1,24 +1,13 @@
import OpenAI from "openai";
import type { ClientOptions } from "openai";
import { zodToJsonSchema } from "zod-to-json-schema";
import { LogLine } from "../../types/log";
import { AvailableModel } from "../../types/model";
import { LLMCache } from "../cache/LLMCache";
import {
ChatMessage,
CreateChatCompletionOptions,
LLMClient,
LLMResponse,
} from "./LLMClient";
import { CreateChatCompletionResponseError } from "@/types/stagehandErrors";
import { OpenAIClient } from "./OpenAIClient";
import { LLMClient, CreateChatCompletionOptions, LLMResponse } from "./LLMClient";

export class CerebrasClient extends LLMClient {
public type = "cerebras" as const;
private client: OpenAI;
private cache: LLMCache | undefined;
private enableCaching: boolean;
public clientOptions: ClientOptions;
public hasVision = false;
private openaiClient: OpenAIClient;

constructor({
enableCaching = false,
@@ -31,313 +20,43 @@ export class CerebrasClient extends LLMClient {
enableCaching?: boolean;
cache?: LLMCache;
modelName: AvailableModel;
clientOptions?: ClientOptions;
clientOptions?: any;
userProvidedInstructions?: string;
}) {
super(modelName, userProvidedInstructions);

// Create OpenAI client with the base URL set to Cerebras API
this.client = new OpenAI({
baseURL: "https://api.cerebras.ai/v1",
apiKey: clientOptions?.apiKey || process.env.CEREBRAS_API_KEY,
...clientOptions,
});

this.cache = cache;
this.enableCaching = enableCaching;
this.modelName = modelName;
this.clientOptions = clientOptions;
}

async createChatCompletion<T = LLMResponse>({
options,
retries,
logger,
}: CreateChatCompletionOptions): Promise<T> {
const optionsWithoutImage = { ...options };
delete optionsWithoutImage.image;

logger({
category: "cerebras",
message: "creating chat completion",
level: 2,
auxiliary: {
options: {
value: JSON.stringify(optionsWithoutImage),
type: "object",
},
},
});

// Try to get cached response
const cacheOptions = {
model: this.modelName.split("cerebras-")[1],
messages: options.messages,
temperature: options.temperature,
response_model: options.response_model,
tools: options.tools,
retries: retries,
};

if (this.enableCaching) {
const cachedResponse = await this.cache.get<T>(
cacheOptions,
options.requestId,
);
if (cachedResponse) {
logger({
category: "llm_cache",
message: "LLM cache hit - returning cached response",
level: 1,
auxiliary: {
cachedResponse: {
value: JSON.stringify(cachedResponse),
type: "object",
},
requestId: {
value: options.requestId,
type: "string",
},
cacheOptions: {
value: JSON.stringify(cacheOptions),
type: "object",
},
},
});
return cachedResponse as T;
}
}

// Format messages for Cerebras API (using OpenAI format)
const formattedMessages = options.messages.map((msg: ChatMessage) => {
const baseMessage = {
content:
typeof msg.content === "string"
? msg.content
: Array.isArray(msg.content) &&
msg.content.length > 0 &&
"text" in msg.content[0]
? msg.content[0].text
: "",
};

// Cerebras only supports system, user, and assistant roles
if (msg.role === "system") {
return { ...baseMessage, role: "system" as const };
} else if (msg.role === "assistant") {
return { ...baseMessage, role: "assistant" as const };
} else {
// Default to user for any other role
return { ...baseMessage, role: "user" as const };
}
});

// Format tools if provided
let tools = options.tools?.map((tool) => ({
type: "function" as const,
function: {
name: tool.name,
description: tool.description,
parameters: {
type: "object",
properties: tool.parameters.properties,
required: tool.parameters.required,

// Transform model name to remove cerebras- prefix
const openaiModelName = modelName.startsWith("cerebras-")
? modelName.split("cerebras-")[1]
: modelName;

this.openaiClient = new OpenAIClient({
enableCaching,
cache,
modelName: openaiModelName as AvailableModel,
clientOptions: {
baseURL: "https://api.cerebras.ai/v1",
defaultHeaders: {
apikey: clientOptions?.apiKey || process.env.CEREBRAS_API_KEY,
},
...clientOptions,
},
}));

// Add response model as a tool if provided
if (options.response_model) {
const jsonSchema = zodToJsonSchema(options.response_model.schema) as {
properties?: Record<string, unknown>;
required?: string[];
};
const schemaProperties = jsonSchema.properties || {};
const schemaRequired = jsonSchema.required || [];

const responseTool = {
type: "function" as const,
function: {
name: "print_extracted_data",
description:
"Prints the extracted data based on the provided schema.",
parameters: {
type: "object",
properties: schemaProperties,
required: schemaRequired,
},
},
};

tools = tools ? [...tools, responseTool] : [responseTool];
}

try {
// Use OpenAI client with Cerebras API
const apiResponse = await this.client.chat.completions.create({
model: this.modelName.split("cerebras-")[1],
messages: [
...formattedMessages,
// Add explicit instruction to return JSON if we have a response model
...(options.response_model
? [
{
role: "system" as const,
content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(
options.response_model.schema,
)}`,
},
]
: []),
],
temperature: options.temperature || 0.7,
max_tokens: options.maxTokens,
tools: tools,
tool_choice: options.tool_choice || "auto",
});

// Format the response to match the expected LLMResponse format
const response: LLMResponse = {
id: apiResponse.id,
object: "chat.completion",
created: Date.now(),
model: this.modelName.split("cerebras-")[1],
choices: [
{
index: 0,
message: {
role: "assistant",
content: apiResponse.choices[0]?.message?.content || null,
tool_calls: apiResponse.choices[0]?.message?.tool_calls || [],
},
finish_reason: apiResponse.choices[0]?.finish_reason || "stop",
},
],
usage: {
prompt_tokens: apiResponse.usage?.prompt_tokens || 0,
completion_tokens: apiResponse.usage?.completion_tokens || 0,
total_tokens: apiResponse.usage?.total_tokens || 0,
},
};

logger({
category: "cerebras",
message: "response",
level: 2,
auxiliary: {
response: {
value: JSON.stringify(response),
type: "object",
},
requestId: {
value: options.requestId,
type: "string",
},
},
});

// If we have no response model, just return the entire LLMResponse
if (!options.response_model) {
if (this.enableCaching) {
await this.cache.set(cacheOptions, response, options.requestId);
}
return response as T;
}

// If we have a response model, parse JSON from tool calls or content
const toolCall = response.choices[0]?.message?.tool_calls?.[0];
if (toolCall?.function?.arguments) {
try {
const result = JSON.parse(toolCall.function.arguments);
const finalResponse = {
data: result,
usage: response.usage,
};
if (this.enableCaching) {
await this.cache.set(
cacheOptions,
finalResponse,
options.requestId,
);
}
return finalResponse as T;
} catch (e) {
logger({
category: "cerebras",
message: "failed to parse tool call arguments as JSON, retrying",
level: 0,
auxiliary: {
error: {
value: e.message,
type: "string",
},
},
});
logger: (message: LogLine) => {
// Transform log messages to use cerebras category
const transformedMessage = {
...message,
category: message.category === "openai" ? "cerebras" : message.category,
};
// Call the original logger if it exists
if (typeof (this as any).logger === 'function') {
(this as any).logger(transformedMessage);
}
}

// If we have content but no tool calls, try to parse the content as JSON
const content = response.choices[0]?.message?.content;
if (content) {
try {
const jsonMatch = content.match(/\{[\s\S]*\}/);
if (jsonMatch) {
const result = JSON.parse(jsonMatch[0]);
const finalResponse = {
data: result,
usage: response.usage,
};
if (this.enableCaching) {
await this.cache.set(
cacheOptions,
finalResponse,
options.requestId,
);
}
return finalResponse as T;
}
} catch (e) {
logger({
category: "cerebras",
message: "failed to parse content as JSON",
level: 0,
auxiliary: {
error: {
value: e.message,
type: "string",
},
},
});
}
}

// If we still haven't found valid JSON and have retries left, try again
if (!retries || retries < 5) {
return this.createChatCompletion({
options,
logger,
retries: (retries ?? 0) + 1,
});
}
},
});
}

throw new CreateChatCompletionResponseError("Invalid response schema");
} catch (error) {
logger({
category: "cerebras",
message: "error creating chat completion",
level: 0,
auxiliary: {
error: {
value: error.message,
type: "string",
},
requestId: {
value: options.requestId,
type: "string",
},
},
});
throw error;
}
async createChatCompletion<T = LLMResponse>(options: CreateChatCompletionOptions): Promise<T> {
return this.openaiClient.createChatCompletion<T>(options);
}

}
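
Taken together, the refactor reduces `CerebrasClient` to a thin wrapper: the `cerebras-` prefix is stripped from the model name, and every chat completion is forwarded to an `OpenAIClient` pointed at `https://api.cerebras.ai/v1`. A hedged call-site sketch, continuing the construction example above; the option fields are inferred from those referenced in the removed implementation and may not match the real `CreateChatCompletionOptions` type exactly.

```ts
// Hedged sketch; `cerebras` is the instance from the earlier construction example.
// Field names (messages, temperature, requestId, logger, retries) are inferred
// from the deleted implementation above, not from the LLMClient types directly.
const response = await cerebras.createChatCompletion({
  options: {
    messages: [{ role: "user", content: "Summarize this page in one sentence." }],
    temperature: 0.2,
    requestId: "example-request-id",
  },
  logger: (line) => console.log(line.message),
  retries: 0,
});
// The call is a straight pass-through to this.openaiClient.createChatCompletion(...).
```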