diff --git a/.changeset/shiny-wings-reply.md b/.changeset/shiny-wings-reply.md new file mode 100644 index 000000000..7d45118f0 --- /dev/null +++ b/.changeset/shiny-wings-reply.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": patch +--- + +Add support for callbacks in stagehand agent diff --git a/packages/core/lib/v3/cache/AgentCache.ts b/packages/core/lib/v3/cache/AgentCache.ts index 637374d3e..c603aa4f3 100644 --- a/packages/core/lib/v3/cache/AgentCache.ts +++ b/packages/core/lib/v3/cache/AgentCache.ts @@ -20,7 +20,7 @@ import type { AgentResult, AgentStreamResult, AgentConfig, - AgentExecuteOptions, + AgentExecuteOptionsBase, Logger, } from "../types/public"; import type { Page } from "../understudy/page"; @@ -74,7 +74,7 @@ export class AgentCache { } sanitizeExecuteOptions( - options?: AgentExecuteOptions, + options?: AgentExecuteOptionsBase, ): SanitizedAgentExecuteOptions { if (!options) return {}; const sanitized: SanitizedAgentExecuteOptions = {}; diff --git a/packages/core/lib/v3/handlers/v3AgentHandler.ts b/packages/core/lib/v3/handlers/v3AgentHandler.ts index 72e885f95..6648256bf 100644 --- a/packages/core/lib/v3/handlers/v3AgentHandler.ts +++ b/packages/core/lib/v3/handlers/v3AgentHandler.ts @@ -8,19 +8,27 @@ import { stepCountIs, type LanguageModelUsage, type StepResult, + type GenerateTextOnStepFinishCallback, + type StreamTextOnStepFinishCallback, } from "ai"; import { processMessages } from "../agent/utils/messageProcessing"; import { LLMClient } from "../llm/LLMClient"; import { AgentExecuteOptions, + AgentStreamExecuteOptions, + AgentExecuteOptionsBase, AgentResult, AgentContext, AgentState, AgentStreamResult, + AgentStreamCallbacks, } from "../types/public/agent"; import { V3FunctionName } from "../types/public/methods"; import { mapToolResultToActions } from "../agent/utils/actionMapping"; -import { MissingLLMConfigurationError } from "../types/public/sdkErrors"; +import { + MissingLLMConfigurationError, + StreamingCallbacksInNonStreamingModeError, +} from "../types/public/sdkErrors"; export class V3AgentHandler { private v3: V3; @@ -47,7 +55,7 @@ export class V3AgentHandler { } private async prepareAgent( - instructionOrOptions: string | AgentExecuteOptions, + instructionOrOptions: string | AgentExecuteOptionsBase, ): Promise { try { const options = @@ -102,7 +110,12 @@ export class V3AgentHandler { } } - private createStepHandler(state: AgentState) { + private createStepHandler( + state: AgentState, + userCallback?: + | GenerateTextOnStepFinishCallback + | StreamTextOnStepFinishCallback, + ) { return async (event: StepResult) => { this.logger({ category: "agent", @@ -150,6 +163,10 @@ export class V3AgentHandler { } state.currentPageUrl = (await this.v3.context.awaitActivePage()).url(); } + + if (userCallback) { + await userCallback(event); + } }; } @@ -166,6 +183,23 @@ export class V3AgentHandler { initialPageUrl, } = await this.prepareAgent(instructionOrOptions); + const callbacks = (instructionOrOptions as AgentExecuteOptions).callbacks; + + if (callbacks) { + const streamingOnlyCallbacks = [ + "onChunk", + "onFinish", + "onError", + "onAbort", + ]; + const invalidCallbacks = streamingOnlyCallbacks.filter( + (name) => callbacks[name as keyof typeof callbacks] != null, + ); + if (invalidCallbacks.length > 0) { + throw new StreamingCallbacksInNonStreamingModeError(invalidCallbacks); + } + } + const state: AgentState = { collectedReasoning: [], actions: [], @@ -183,7 +217,8 @@ export class V3AgentHandler { stopWhen: (result) => this.handleStop(result, maxSteps), temperature: 1, toolChoice: "auto", - onStepFinish: this.createStepHandler(state), + prepareStep: callbacks?.prepareStep, + onStepFinish: this.createStepHandler(state, callbacks?.onStepFinish), }); return this.consolidateMetricsAndResult(startTime, state, result); @@ -204,7 +239,7 @@ export class V3AgentHandler { } public async stream( - instructionOrOptions: string | AgentExecuteOptions, + instructionOrOptions: string | AgentStreamExecuteOptions, ): Promise { const { maxSteps, @@ -215,6 +250,9 @@ export class V3AgentHandler { initialPageUrl, } = await this.prepareAgent(instructionOrOptions); + const callbacks = (instructionOrOptions as AgentStreamExecuteOptions) + .callbacks as AgentStreamCallbacks | undefined; + const state: AgentState = { collectedReasoning: [], actions: [], @@ -250,11 +288,19 @@ export class V3AgentHandler { stopWhen: (result) => this.handleStop(result, maxSteps), temperature: 1, toolChoice: "auto", - onStepFinish: this.createStepHandler(state), - onError: ({ error }) => { - handleError(error); + prepareStep: callbacks?.prepareStep, + onStepFinish: this.createStepHandler(state, callbacks?.onStepFinish), + onError: (event) => { + if (callbacks?.onError) { + callbacks.onError(event); + } + handleError(event.error); }, + onChunk: callbacks?.onChunk, onFinish: (event) => { + if (callbacks?.onFinish) { + callbacks.onFinish(event); + } try { const result = this.consolidateMetricsAndResult( startTime, diff --git a/packages/core/lib/v3/tests/agent-callbacks.spec.ts b/packages/core/lib/v3/tests/agent-callbacks.spec.ts new file mode 100644 index 000000000..9c7908872 --- /dev/null +++ b/packages/core/lib/v3/tests/agent-callbacks.spec.ts @@ -0,0 +1,475 @@ +import { test, expect } from "@playwright/test"; +import { V3 } from "../v3"; +import { v3TestConfig } from "./v3.config"; +import type { StepResult, ToolSet } from "ai"; +import { StreamingCallbacksInNonStreamingModeError } from "../types/public/sdkErrors"; + +test.describe("Stagehand agent callbacks behavior", () => { + let v3: V3; + + test.beforeEach(async () => { + v3 = new V3({ + ...v3TestConfig, + experimental: true, // Required for callbacks and streaming + }); + await v3.init(); + }); + + test.afterEach(async () => { + await v3?.close?.().catch(() => {}); + }); + + test.describe("Non-streaming callbacks (stream: false)", () => { + test("onStepFinish callback is called for each step", async () => { + test.setTimeout(60000); + + const stepFinishEvents: StepResult[] = []; + + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + await agent.execute({ + instruction: + "What is the title of this page? Use close tool with taskComplete: true immediately after answering.", + maxSteps: 5, + callbacks: { + onStepFinish: async (event) => { + stepFinishEvents.push(event); + }, + }, + }); + + // Should have at least one step finish event + expect(stepFinishEvents.length).toBeGreaterThan(0); + + // Each event should have expected properties + for (const event of stepFinishEvents) { + expect(event).toHaveProperty("finishReason"); + expect(event).toHaveProperty("text"); + } + }); + + test("prepareStep callback is called before each step", async () => { + test.setTimeout(60000); + + let prepareStepCallCount = 0; + + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + await agent.execute({ + instruction: "Use close tool with taskComplete: true immediately.", + maxSteps: 3, + callbacks: { + prepareStep: async (stepContext) => { + prepareStepCallCount++; + return stepContext; + }, + }, + }); + + // prepareStep should have been called at least once + expect(prepareStepCallCount).toBeGreaterThan(0); + }); + + test("callbacks receive tool call information", async () => { + test.setTimeout(60000); + + const toolCalls: Array<{ toolName: string; input: unknown }> = []; + + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + await agent.execute({ + instruction: "Use the close tool with taskComplete: true immediately.", + maxSteps: 3, + callbacks: { + onStepFinish: async (event) => { + if (event.toolCalls) { + for (const tc of event.toolCalls) { + toolCalls.push({ + toolName: tc.toolName, + input: tc.input, + }); + } + } + }, + }, + }); + + // Should have captured at least the close tool call + expect(toolCalls.length).toBeGreaterThan(0); + expect(toolCalls.some((tc) => tc.toolName === "close")).toBe(true); + }); + }); + + test.describe("Streaming callbacks (stream: true)", () => { + test("onStepFinish callback is called for each step in stream mode", async () => { + test.setTimeout(60000); + + const stepFinishEvents: StepResult[] = []; + + const agent = v3.agent({ + stream: true, + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + const streamResult = await agent.execute({ + instruction: + "What is this page? Use close tool with taskComplete: true after answering.", + maxSteps: 5, + callbacks: { + onStepFinish: async (event) => { + stepFinishEvents.push(event); + }, + }, + }); + + // Consume the stream + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _ of streamResult.textStream) { + // Just consume + } + + // Wait for result to complete + await streamResult.result; + + // Should have at least one step finish event + expect(stepFinishEvents.length).toBeGreaterThan(0); + }); + + test("onChunk callback is called for each chunk", async () => { + test.setTimeout(60000); + + let chunkCount = 0; + + const agent = v3.agent({ + stream: true, + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + const streamResult = await agent.execute({ + instruction: + "Say hello and then use close tool with taskComplete: true", + maxSteps: 3, + callbacks: { + onChunk: async () => { + chunkCount++; + }, + }, + }); + + // Consume the stream + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _ of streamResult.textStream) { + // Just consume + } + + await streamResult.result; + + // Should have received chunks + expect(chunkCount).toBeGreaterThan(0); + }); + + test("onFinish callback is called when stream completes", async () => { + test.setTimeout(60000); + + let finishCalled = false; + let finishEvent: unknown = null; + + const agent = v3.agent({ + stream: true, + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + const streamResult = await agent.execute({ + instruction: "Use close tool with taskComplete: true immediately.", + maxSteps: 3, + callbacks: { + onFinish: (event) => { + finishCalled = true; + finishEvent = event; + }, + }, + }); + + // Consume the stream + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _ of streamResult.textStream) { + // Just consume + } + + await streamResult.result; + + // onFinish should have been called + expect(finishCalled).toBe(true); + expect(finishEvent).not.toBeNull(); + }); + + test("prepareStep callback works in stream mode", async () => { + test.setTimeout(60000); + + let prepareStepCallCount = 0; + + const agent = v3.agent({ + stream: true, + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + const streamResult = await agent.execute({ + instruction: "Use close tool with taskComplete: true immediately.", + maxSteps: 3, + callbacks: { + prepareStep: async (stepContext) => { + prepareStepCallCount++; + return stepContext; + }, + }, + }); + + // Consume the stream + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _ of streamResult.textStream) { + // Just consume + } + + await streamResult.result; + + // prepareStep should have been called at least once + expect(prepareStepCallCount).toBeGreaterThan(0); + }); + }); + + test.describe("Streaming-only callbacks runtime validation", () => { + test("throws StreamingCallbacksInNonStreamingModeError when onChunk is used", async () => { + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + try { + await agent.execute({ + instruction: "test", + callbacks: { + onChunk: (() => {}) as never, + }, + }); + throw new Error("Expected error to be thrown"); + } catch (error) { + expect(error).toBeInstanceOf(StreamingCallbacksInNonStreamingModeError); + expect( + (error as StreamingCallbacksInNonStreamingModeError).invalidCallbacks, + ).toEqual(["onChunk"]); + } + }); + + test("throws StreamingCallbacksInNonStreamingModeError when onFinish is used", async () => { + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + try { + await agent.execute({ + instruction: "test", + callbacks: { + onFinish: (() => {}) as never, + }, + }); + throw new Error("Expected error to be thrown"); + } catch (error) { + expect(error).toBeInstanceOf(StreamingCallbacksInNonStreamingModeError); + expect( + (error as StreamingCallbacksInNonStreamingModeError).invalidCallbacks, + ).toEqual(["onFinish"]); + } + }); + + test("throws StreamingCallbacksInNonStreamingModeError when onError is used", async () => { + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + try { + await agent.execute({ + instruction: "test", + callbacks: { + onError: (() => {}) as never, + }, + }); + throw new Error("Expected error to be thrown"); + } catch (error) { + expect(error).toBeInstanceOf(StreamingCallbacksInNonStreamingModeError); + expect( + (error as StreamingCallbacksInNonStreamingModeError).invalidCallbacks, + ).toEqual(["onError"]); + } + }); + + test("throws StreamingCallbacksInNonStreamingModeError when onAbort is used", async () => { + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + try { + await agent.execute({ + instruction: "test", + callbacks: { + onAbort: (() => {}) as never, + }, + }); + throw new Error("Expected error to be thrown"); + } catch (error) { + expect(error).toBeInstanceOf(StreamingCallbacksInNonStreamingModeError); + expect( + (error as StreamingCallbacksInNonStreamingModeError).invalidCallbacks, + ).toEqual(["onAbort"]); + } + }); + + test("error includes all invalid callbacks when multiple are used", async () => { + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + try { + await agent.execute({ + instruction: "test", + callbacks: { + onChunk: (() => {}) as never, + onFinish: (() => {}) as never, + }, + }); + throw new Error("Expected error to be thrown"); + } catch (error) { + expect(error).toBeInstanceOf(StreamingCallbacksInNonStreamingModeError); + expect( + (error as StreamingCallbacksInNonStreamingModeError).invalidCallbacks, + ).toEqual(["onChunk", "onFinish"]); + } + }); + }); + + test.describe("Combined callbacks", () => { + test("multiple callbacks can be used together", async () => { + test.setTimeout(60000); + + let prepareStepCount = 0; + let stepFinishCount = 0; + + const agent = v3.agent({ + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + await agent.execute({ + instruction: "Use close tool with taskComplete: true immediately.", + maxSteps: 3, + callbacks: { + prepareStep: async (stepContext) => { + prepareStepCount++; + return stepContext; + }, + onStepFinish: async () => { + stepFinishCount++; + }, + }, + }); + + // Both callbacks should have been called + expect(prepareStepCount).toBeGreaterThan(0); + expect(stepFinishCount).toBeGreaterThan(0); + }); + + test("streaming with multiple callbacks", async () => { + test.setTimeout(60000); + + let prepareStepCount = 0; + let stepFinishCount = 0; + let chunkCount = 0; + let finishCalled = false; + + const agent = v3.agent({ + stream: true, + model: "anthropic/claude-haiku-4-5-20251001", + }); + + const page = v3.context.pages()[0]; + await page.goto("https://example.com"); + + const streamResult = await agent.execute({ + instruction: + "Say hello briefly and then use close tool with taskComplete: true", + maxSteps: 3, + callbacks: { + prepareStep: async (stepContext) => { + prepareStepCount++; + return stepContext; + }, + onStepFinish: async () => { + stepFinishCount++; + }, + onChunk: async () => { + chunkCount++; + }, + onFinish: () => { + finishCalled = true; + }, + }, + }); + + // Consume the stream + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _ of streamResult.textStream) { + // Just consume + } + + await streamResult.result; + + // All callbacks should have been called + expect(prepareStepCount).toBeGreaterThan(0); + expect(stepFinishCount).toBeGreaterThan(0); + expect(chunkCount).toBeGreaterThan(0); + expect(finishCalled).toBe(true); + }); + }); +}); diff --git a/packages/core/lib/v3/types/public/agent.ts b/packages/core/lib/v3/types/public/agent.ts index 3e76799a9..54ed72345 100644 --- a/packages/core/lib/v3/types/public/agent.ts +++ b/packages/core/lib/v3/types/public/agent.ts @@ -1,5 +1,17 @@ import type { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { ToolSet, ModelMessage, wrapLanguageModel, StreamTextResult } from "ai"; +import { + ToolSet, + ModelMessage, + wrapLanguageModel, + StreamTextResult, + StepResult, + PrepareStepFunction, + GenerateTextOnStepFinishCallback, + StreamTextOnStepFinishCallback, + StreamTextOnErrorCallback, + StreamTextOnChunkCallback, + StreamTextOnFinishCallback, +} from "ai"; import { LogLine } from "./logs"; import { ClientOptions } from "./model"; import { Page as PlaywrightPage } from "playwright-core"; @@ -8,7 +20,7 @@ import { Page as PatchrightPage } from "patchright-core"; import { Page } from "../../understudy/page"; export interface AgentContext { - options: AgentExecuteOptions; + options: AgentExecuteOptionsBase; maxSteps: number; systemPrompt: string; allTools: ToolSet; @@ -57,12 +69,169 @@ export type AgentStreamResult = StreamTextResult & { result: Promise; }; -export interface AgentExecuteOptions { +/** + * Base callbacks shared between execute (non-streaming) and streaming modes. + */ +export interface AgentCallbacks { + /** + * Optional function called before each step to modify settings. + * You can change the model, tool choices, active tools, system prompt, + * and input messages for each step. + */ + prepareStep?: PrepareStepFunction; + /** + * Callback called when each step (LLM call) is finished. + * This is called for intermediate steps as well as the final step. + */ + onStepFinish?: + | GenerateTextOnStepFinishCallback + | StreamTextOnStepFinishCallback; +} + +/** + * Error message type for streaming-only callbacks used in non-streaming mode. + * This provides a clear error message when users try to use streaming callbacks without stream: true. + */ +type StreamingCallbackNotAvailable = + "This callback requires 'stream: true' in AgentConfig. Set stream: true to use streaming callbacks like onChunk, onFinish, onError, and onAbort."; + +/** + * Callbacks specific to the non-streaming execute method. + */ +export interface AgentExecuteCallbacks extends AgentCallbacks { + /** + * Callback called when each step (LLM call) is finished. + */ + onStepFinish?: GenerateTextOnStepFinishCallback; + + /** + * NOT AVAILABLE in non-streaming mode. + * This callback requires `stream: true` in AgentConfig. + * + * @example + * ```typescript + * // Enable streaming to use onChunk: + * const agent = stagehand.agent({ stream: true }); + * await agent.execute({ + * instruction: "...", + * callbacks: { onChunk: async (chunk) => console.log(chunk) } + * }); + * ``` + */ + onChunk?: StreamingCallbackNotAvailable; + + /** + * NOT AVAILABLE in non-streaming mode. + * This callback requires `stream: true` in AgentConfig. + * + * @example + * ```typescript + * // Enable streaming to use onFinish: + * const agent = stagehand.agent({ stream: true }); + * await agent.execute({ + * instruction: "...", + * callbacks: { onFinish: (event) => console.log("Done!", event) } + * }); + * ``` + */ + onFinish?: StreamingCallbackNotAvailable; + + /** + * NOT AVAILABLE in non-streaming mode. + * This callback requires `stream: true` in AgentConfig. + * + * @example + * ```typescript + * // Enable streaming to use onError: + * const agent = stagehand.agent({ stream: true }); + * await agent.execute({ + * instruction: "...", + * callbacks: { onError: ({ error }) => console.error(error) } + * }); + * ``` + */ + onError?: StreamingCallbackNotAvailable; + + /** + * NOT AVAILABLE in non-streaming mode. + * This callback requires `stream: true` in AgentConfig. + * + * @example + * ```typescript + * // Enable streaming to use onAbort: + * const agent = stagehand.agent({ stream: true }); + * await agent.execute({ + * instruction: "...", + * callbacks: { onAbort: (event) => console.log("Aborted", event.steps) } + * }); + * ``` + */ + onAbort?: StreamingCallbackNotAvailable; +} + +/** + * Callbacks specific to the streaming mode. + */ +export interface AgentStreamCallbacks extends AgentCallbacks { + /** + * Callback called when each step (LLM call) is finished during streaming. + */ + onStepFinish?: StreamTextOnStepFinishCallback; + /** + * Callback called when an error occurs during streaming. + * Use this to log errors or handle error states. + */ + onError?: StreamTextOnErrorCallback; + /** + * Callback called for each chunk of the stream. + * Stream processing will pause until the callback promise resolves. + */ + onChunk?: StreamTextOnChunkCallback; + /** + * Callback called when the stream finishes. + */ + onFinish?: StreamTextOnFinishCallback; + /** + * Callback called when the stream is aborted. + */ + onAbort?: (event: { + steps: Array>; + }) => PromiseLike | void; +} + +/** + * Base options for agent execution (without callbacks). + */ +export interface AgentExecuteOptionsBase { instruction: string; maxSteps?: number; page?: PlaywrightPage | PuppeteerPage | PatchrightPage | Page; highlightCursor?: boolean; } + +/** + * Options for non-streaming agent execution. + * Only accepts AgentExecuteCallbacks (no streaming-specific callbacks like onChunk, onFinish). + */ +export interface AgentExecuteOptions extends AgentExecuteOptionsBase { + /** + * Callbacks for non-streaming agent execution. + * For streaming callbacks (onChunk, onFinish, onError, onAbort), use stream: true in AgentConfig. + */ + callbacks?: AgentExecuteCallbacks; +} + +/** + * Options for streaming agent execution. + * Accepts AgentStreamCallbacks including onChunk, onFinish, onError, and onAbort. + */ +export interface AgentStreamExecuteOptions extends AgentExecuteOptionsBase { + /** + * Callbacks for streaming agent execution. + * Includes streaming-specific callbacks: onChunk, onFinish, onError, onAbort. + */ + callbacks?: AgentStreamCallbacks; +} export type AgentType = "openai" | "anthropic" | "google" | "microsoft"; export const AVAILABLE_CUA_MODELS = [ @@ -233,16 +402,18 @@ export type AgentConfig = { /** * Agent instance returned when stream: true is set in AgentConfig. * execute() returns a streaming result that can be consumed incrementally. + * Accepts AgentStreamExecuteOptions with streaming-specific callbacks. */ export interface StreamingAgentInstance { execute: ( - instructionOrOptions: string | AgentExecuteOptions, + instructionOrOptions: string | AgentStreamExecuteOptions, ) => Promise; } /** * Agent instance returned when stream is false or not set in AgentConfig. * execute() returns a result after the agent completes. + * Accepts AgentExecuteOptions with non-streaming callbacks only. */ export interface NonStreamingAgentInstance { execute: ( diff --git a/packages/core/lib/v3/types/public/sdkErrors.ts b/packages/core/lib/v3/types/public/sdkErrors.ts index 3131414e1..8840aa30e 100644 --- a/packages/core/lib/v3/types/public/sdkErrors.ts +++ b/packages/core/lib/v3/types/public/sdkErrors.ts @@ -319,3 +319,15 @@ export class ConnectionTimeoutError extends StagehandError { super(`Connection timeout: ${message}`); } } + +export class StreamingCallbacksInNonStreamingModeError extends StagehandError { + public readonly invalidCallbacks: string[]; + + constructor(invalidCallbacks: string[]) { + super( + `Streaming-only callback(s) "${invalidCallbacks.join('", "')}" cannot be used in non-streaming mode. ` + + `Set 'stream: true' in AgentConfig to use these callbacks.`, + ); + this.invalidCallbacks = invalidCallbacks; + } +} diff --git a/packages/core/lib/v3/v3.ts b/packages/core/lib/v3/v3.ts index e364c19ba..4cd17fc94 100644 --- a/packages/core/lib/v3/v3.ts +++ b/packages/core/lib/v3/v3.ts @@ -37,6 +37,7 @@ import { import { AgentConfig, AgentExecuteOptions, + AgentStreamExecuteOptions, AgentResult, AVAILABLE_CUA_MODELS, LogLine, @@ -1499,11 +1500,14 @@ export class V3 { */ private async prepareAgentExecution( options: AgentConfig | undefined, - instructionOrOptions: string | AgentExecuteOptions, + instructionOrOptions: + | string + | AgentExecuteOptions + | AgentStreamExecuteOptions, agentConfigSignature: string, ): Promise<{ handler: V3AgentHandler; - resolvedOptions: AgentExecuteOptions; + resolvedOptions: AgentExecuteOptions | AgentStreamExecuteOptions; instruction: string; cacheContext: AgentCacheContext | null; }> { @@ -1532,7 +1536,7 @@ export class V3 { tools, ); - const resolvedOptions: AgentExecuteOptions = + const resolvedOptions: AgentExecuteOptions | AgentStreamExecuteOptions = typeof instructionOrOptions === "string" ? { instruction: instructionOrOptions } : instructionOrOptions; @@ -1567,7 +1571,7 @@ export class V3 { */ agent(options: AgentConfig & { stream: true }): { execute: ( - instructionOrOptions: string | AgentExecuteOptions, + instructionOrOptions: string | AgentStreamExecuteOptions, ) => Promise; }; agent(options?: AgentConfig & { stream?: false }): { @@ -1577,7 +1581,10 @@ export class V3 { }; agent(options?: AgentConfig): { execute: ( - instructionOrOptions: string | AgentExecuteOptions, + instructionOrOptions: + | string + | AgentExecuteOptions + | AgentStreamExecuteOptions, ) => Promise; } { this.logger({ @@ -1728,9 +1735,20 @@ export class V3 { return { execute: async ( - instructionOrOptions: string | AgentExecuteOptions, + instructionOrOptions: + | string + | AgentExecuteOptions + | AgentStreamExecuteOptions, ): Promise => withInstanceLogContext(this.instanceId, async () => { + if ( + typeof instructionOrOptions === "object" && + instructionOrOptions.callbacks && + !this.experimental + ) { + throw new ExperimentalNotConfiguredError("Agent callbacks"); + } + // Streaming mode if (isStreaming) { if (!this.experimental) { @@ -1751,7 +1769,9 @@ export class V3 { } } - const streamResult = await handler.stream(instructionOrOptions); + const streamResult = await handler.stream( + instructionOrOptions as string | AgentStreamExecuteOptions, + ); if (cacheContext) { return this.agentCache.wrapStreamForCaching( @@ -1793,11 +1813,13 @@ export class V3 { const page = await this.ctx!.awaitActivePage(); result = await this.apiClient.agentExecute( options ?? {}, - resolvedOptions, + resolvedOptions as AgentExecuteOptions, page.mainFrameId(), ); } else { - result = await handler.execute(instructionOrOptions); + result = await handler.execute( + instructionOrOptions as string | AgentExecuteOptions, + ); } if (recording) { agentSteps = this.endAgentReplayRecording(); diff --git a/packages/core/tests/public-api/public-error-types.test.ts b/packages/core/tests/public-api/public-error-types.test.ts index 6534dce39..b13aa35b9 100644 --- a/packages/core/tests/public-api/public-error-types.test.ts +++ b/packages/core/tests/public-api/public-error-types.test.ts @@ -44,6 +44,8 @@ export const publicErrorTypes = { StagehandShadowSegmentEmptyError: Stagehand.StagehandShadowSegmentEmptyError, StagehandShadowSegmentNotFoundError: Stagehand.StagehandShadowSegmentNotFoundError, + StreamingCallbacksInNonStreamingModeError: + Stagehand.StreamingCallbacksInNonStreamingModeError, TimeoutError: Stagehand.TimeoutError, UnsupportedAISDKModelProviderError: Stagehand.UnsupportedAISDKModelProviderError, diff --git a/packages/core/tests/public-api/public-types.test.ts b/packages/core/tests/public-api/public-types.test.ts index 80a91a8d2..28b3f7bc9 100644 --- a/packages/core/tests/public-api/public-types.test.ts +++ b/packages/core/tests/public-api/public-types.test.ts @@ -47,6 +47,11 @@ type ExpectedExportedTypes = { AgentProviderType: Stagehand.AgentProviderType; AgentModelConfig: Stagehand.AgentModelConfig; AgentConfig: Stagehand.AgentConfig; + AgentCallbacks: Stagehand.AgentCallbacks; + AgentExecuteCallbacks: Stagehand.AgentExecuteCallbacks; + AgentStreamCallbacks: Stagehand.AgentStreamCallbacks; + AgentExecuteOptionsBase: Stagehand.AgentExecuteOptionsBase; + AgentStreamExecuteOptions: Stagehand.AgentStreamExecuteOptions; // Types from logs.ts LogLevel: Stagehand.LogLevel; LogLine: Stagehand.LogLine; @@ -178,6 +183,7 @@ describe("Stagehand public API types", () => { maxSteps?: number; page?: Stagehand.AnyPage; highlightCursor?: boolean; + callbacks?: Stagehand.AgentExecuteCallbacks; }; it("matches expected type shape", () => { @@ -185,6 +191,20 @@ describe("Stagehand public API types", () => { }); }); + describe("AgentStreamExecuteOptions", () => { + type ExpectedAgentStreamExecuteOptions = { + instruction: string; + maxSteps?: number; + page?: Stagehand.AnyPage; + highlightCursor?: boolean; + callbacks?: Stagehand.AgentStreamCallbacks; + }; + + it("matches expected type shape", () => { + expectTypeOf().toEqualTypeOf(); + }); + }); + describe("AgentExecutionOptions", () => { type ExpectedAgentExecutionOptions = { options: T;