Skip to content

Commit a6eaea8

Browse files
[evals] enable running evals on the Stagehand API (#894)
1 parent be8497c commit a6eaea8

File tree

3 files changed: 90 additions and 40 deletions.

evals/args.ts

Lines changed: 33 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ const rawArgs = process.argv.slice(2);
1111
const parsedArgs: {
1212
evalName?: string;
1313
env?: string;
14+
api?: string;
1415
trials?: number;
1516
concurrency?: number;
1617
provider?: string;
@@ -22,6 +23,8 @@ const parsedArgs: {
2223
for (const arg of rawArgs) {
2324
if (arg.startsWith("env=")) {
2425
parsedArgs.env = arg.split("=")[1]?.toLowerCase();
26+
} else if (arg.startsWith("api=")) {
27+
parsedArgs.api = arg.split("=")[1]?.toLowerCase();
2528
} else if (arg.startsWith("name=")) {
2629
parsedArgs.evalName = arg.split("=")[1];
2730
} else if (arg.startsWith("trials=")) {
@@ -48,6 +51,12 @@ if (parsedArgs.env === "browserbase") {
4851
process.env.EVAL_ENV = "LOCAL";
4952
}
5053

54+
if (parsedArgs.api === "true") {
55+
process.env.USE_API = "true";
56+
} else if (parsedArgs.api === "false") {
57+
process.env.USE_API = "false";
58+
}
59+
5160
if (parsedArgs.trials !== undefined) {
5261
process.env.EVAL_TRIAL_COUNT = String(parsedArgs.trials);
5362
}
@@ -80,22 +89,21 @@ function buildUsage(detailed = false): string {
8089

8190
const body = dedent`
8291
${chalk.magenta.underline("Keys\n")}
83-
${chalk.cyan("env")} target environment (default ${chalk.dim(
84-
"LOCAL",
85-
)}) [${chalk.yellow("LOCAL")}, ${chalk.yellow("BROWSERBASE")}]
86-
${chalk.cyan("trials")} number of trials (default ${chalk.dim(
87-
"10",
88-
)})
89-
${chalk.cyan(
90-
"concurrency",
91-
)} max parallel sessions (default ${chalk.dim("10")})
92-
${chalk.cyan("provider")} override LLM provider (default ${chalk.dim(
93-
providerDefault,
94-
)}) [${chalk.yellow("OPENAI")}, ${chalk.yellow(
95-
"ANTHROPIC",
96-
)}, ${chalk.yellow("GOOGLE")}, ${chalk.yellow("TOGETHER")}, ${chalk.yellow(
97-
"GROQ",
98-
)}, ${chalk.yellow("CEREBRAS")}]
92+
${chalk.cyan("env".padEnd(12))} ${"target environment".padEnd(24)}
93+
(default ${chalk.dim("LOCAL")}) [${chalk.yellow("BROWSERBASE")}, ${chalk.yellow("LOCAL")}] ${chalk.gray("← LOCAL sets api=false")}
94+
95+
${chalk.cyan("api".padEnd(12))} ${"use the Stagehand API".padEnd(24)}
96+
(default ${chalk.dim("false")}) [${chalk.yellow("true")}, ${chalk.yellow("false")}]
97+
98+
${chalk.cyan("trials".padEnd(12))} ${"number of trials".padEnd(24)}
99+
(default ${chalk.dim("10")})
100+
101+
${chalk.cyan("concurrency".padEnd(12))} ${"max parallel sessions".padEnd(24)}
102+
(default ${chalk.dim("10")})
103+
104+
${chalk.cyan("provider".padEnd(12))} ${"override LLM provider".padEnd(24)}
105+
(default ${chalk.dim(providerDefault)}) [${chalk.yellow("OPENAI")}, ${chalk.yellow("ANTHROPIC")}, ${chalk.yellow("GOOGLE")}, ${chalk.yellow("TOGETHER")}, ${chalk.yellow("GROQ")}, ${chalk.yellow("CEREBRAS")}]
106+
99107
100108
${chalk.magenta.underline("Positional filters\n")}
101109
category <category_name> one of: ${DEFAULT_EVAL_CATEGORIES.map((c) =>
@@ -114,6 +122,13 @@ function buildUsage(detailed = false): string {
114122
${chalk.green("pnpm run evals")} ${chalk.cyan("env=")}${chalk.yellow("BROWSERBASE")} ${chalk.cyan(
115123
"trials=",
116124
)}${chalk.yellow("3")}
125+
126+
127+
${chalk.dim("# Run evals using the Stagehand API")}
128+
129+
${chalk.green("pnpm run evals")} ${chalk.cyan("env=")}${chalk.yellow("BROWSERBASE")} ${chalk.cyan(
130+
"api=",
131+
)}${chalk.yellow("true")}
117132
118133
119134
${chalk.dim(
@@ -144,6 +159,8 @@ function buildUsage(detailed = false): string {
144159
EVAL_MAX_CONCURRENCY overridable via ${chalk.cyan("concurrency=")}
145160
146161
EVAL_PROVIDER overridable via ${chalk.cyan("provider=")}
162+
163+
USE_API overridable via ${chalk.cyan("api=true")}
147164
`;
148165

149166
return `${header}\n\n${synopsis}\n\n${body}\n${envSection}\n`;

evals/index.eval.ts

Lines changed: 47 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@ import OpenAI from "openai";
3535
import { initStagehand } from "./initStagehand";
3636
import { AISdkClient } from "@/examples/external_clients/aisdk";
3737
import { getAISDKLanguageModel } from "@/lib/llm/LLMProvider";
38+
import { loadApiKeyFromEnv } from "@/lib/utils";
39+
import { LogLine } from "@/types/log";
3840

3941
dotenv.config();
4042

@@ -50,6 +52,8 @@ const TRIAL_COUNT = process.env.EVAL_TRIAL_COUNT
5052
? parseInt(process.env.EVAL_TRIAL_COUNT, 10)
5153
: 3;
5254

55+
const USE_API: boolean = (process.env.USE_API ?? "").toLowerCase() === "true";
56+
5357
/**
5458
* generateSummary:
5559
* After all evaluations have finished, aggregate the results into a summary.
@@ -316,32 +320,53 @@ const generateFilteredTestcases = (): Testcase[] => {
316320
}
317321

318322
// Execute the task
319-
let llmClient: LLMClient;
320-
if (input.modelName.includes("/")) {
321-
llmClient = new AISdkClient({
322-
model: wrapAISDKModel(
323-
getAISDKLanguageModel(
324-
input.modelName.split("/")[0],
325-
input.modelName.split("/")[1],
326-
),
327-
),
323+
let taskInput: Awaited<ReturnType<typeof initStagehand>>;
324+
325+
if (USE_API) {
326+
const [provider] = input.modelName.split("/") as [string, string];
327+
328+
const logFn = (line: LogLine): void => logger.log(line);
329+
const apiKey = loadApiKeyFromEnv(provider, logFn);
330+
331+
if (!apiKey) {
332+
throw new StagehandEvalError(
333+
`USE_API=true but no API key found for provider “${provider}”.`,
334+
);
335+
}
336+
337+
taskInput = await initStagehand({
338+
logger,
339+
modelName: input.modelName,
340+
modelClientOptions: { apiKey: apiKey },
328341
});
329342
} else {
330-
llmClient = new CustomOpenAIClient({
331-
modelName: input.modelName as AvailableModel,
332-
client: wrapOpenAI(
333-
new OpenAI({
334-
apiKey: process.env.TOGETHER_AI_API_KEY,
335-
baseURL: "https://api.together.xyz/v1",
336-
}),
337-
),
343+
let llmClient: LLMClient;
344+
if (input.modelName.includes("/")) {
345+
llmClient = new AISdkClient({
346+
model: wrapAISDKModel(
347+
getAISDKLanguageModel(
348+
input.modelName.split("/")[0],
349+
input.modelName.split("/")[1],
350+
),
351+
),
352+
});
353+
} else {
354+
llmClient = new CustomOpenAIClient({
355+
modelName: input.modelName as AvailableModel,
356+
client: wrapOpenAI(
357+
new OpenAI({
358+
apiKey: process.env.TOGETHER_AI_API_KEY,
359+
baseURL: "https://api.together.xyz/v1",
360+
}),
361+
),
362+
});
363+
}
364+
taskInput = await initStagehand({
365+
logger,
366+
llmClient,
367+
modelName: input.modelName,
338368
});
339369
}
340-
const taskInput = await initStagehand({
341-
logger,
342-
llmClient,
343-
modelName: input.modelName,
344-
});
345370
let result;
346371
try {
347372
result = await taskFunction(taskInput);

evals/initStagehand.ts

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,13 +32,13 @@ const StagehandConfig = {
3232
env: env,
3333
apiKey: process.env.BROWSERBASE_API_KEY,
3434
projectId: process.env.BROWSERBASE_PROJECT_ID,
35+
useAPI: process.env.USE_API === "true",
3536
verbose: 2 as const,
3637
debugDom: true,
3738
headless: false,
3839
enableCaching,
3940
domSettleTimeoutMs: 30_000,
4041
disablePino: true,
41-
experimental: true,
4242
browserbaseSessionCreateParams: {
4343
projectId: process.env.BROWSERBASE_PROJECT_ID!,
4444
browserSettings: {
@@ -63,13 +63,15 @@ const StagehandConfig = {
6363
*/
6464
export const initStagehand = async ({
6565
llmClient,
66+
modelClientOptions,
6667
domSettleTimeoutMs,
6768
logger,
6869
configOverrides,
6970
actTimeoutMs,
7071
modelName,
7172
}: {
72-
llmClient: LLMClient;
73+
llmClient?: LLMClient;
74+
modelClientOptions?: { apiKey: string };
7375
domSettleTimeoutMs?: number;
7476
logger: EvalLogger;
7577
configOverrides?: Partial<ConstructorParams>;
@@ -78,9 +80,15 @@ export const initStagehand = async ({
7880
}): Promise<StagehandInitResult> => {
7981
const config = {
8082
...StagehandConfig,
83+
modelClientOptions,
8184
llmClient,
8285
...(domSettleTimeoutMs && { domSettleTimeoutMs }),
8386
actTimeoutMs,
87+
modelName,
88+
experimental:
89+
typeof configOverrides?.experimental === "boolean"
90+
? configOverrides.experimental
91+
: !StagehandConfig.useAPI,
8492
...configOverrides,
8593
logger: logger.log.bind(logger),
8694
};

0 commit comments

Comments
 (0)