diff --git a/.changeset/six-oranges-report.md b/.changeset/six-oranges-report.md new file mode 100644 index 000000000..d60a15641 --- /dev/null +++ b/.changeset/six-oranges-report.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": patch +--- + +Enhanced Stagehand agent with smart model routing, expanded toolset, and robust context management. For more information, reference the [stagehand agent docs](https://docs.stagehand.dev/basics/agent) diff --git a/docs/basics/agent.mdx b/docs/basics/agent.mdx index e1abd828b..b5ed14537 100644 --- a/docs/basics/agent.mdx +++ b/docs/basics/agent.mdx @@ -62,13 +62,41 @@ await agent.execute("apply for a job at Browserbase") Use the agent without specifying a provider to utilize any model or LLM provider: -Non CUA agents are currently only supported in TypeScript +Stagehand agent is currently only supported in TypeScript ```typescript TypeScript +// Basic usage const agent = stagehand.agent(); await agent.execute("apply for a job at Browserbase") ``` +#### Recommended Configuration + +For optimal performance, we recommend using Claude 4 sonnet with Gemini 2.5 Flash as the execution model: + +```typescript TypeScript +const agent = stagehand.agent({ + model: "anthropic/claude-4-20250514", // Reliable reasoning and planning for the agent + executionModel: "google/gemini-2.5-flash", // Fast and reliable execution for stagehand primitives (act, extract, observe) + instructions: "You are a helpful assistant that can use a web browser.", +}); + +// Enable Claude-specific optimizations for best performance +await agent.execute({ + instruction: "apply for a job at Browserbase", + storeActions: false, // Unlocks claude-specific tools + maxSteps: 25 +}); +``` + + +**Why this configuration?** Claude 4 provides excellent reasoning and planning, while Gemini 2.5 Flash offers fast execution for stagehand primitives. Setting `storeActions: false` enables coordinate-based tools for a hybrid approach of stagehand primitives and coordinate-based actions, but removes the ability to turn your agent runs into repeatable deterministic scripts. + + + +All configuration options are optional. The agent works well with default settings, but the above configuration provides the most optimal performance. + + ## MCP Integrations diff --git a/evals/tasks/agent/onlineMind2Web.ts b/evals/tasks/agent/onlineMind2Web.ts index f168b6bcd..27d1e1933 100644 --- a/evals/tasks/agent/onlineMind2Web.ts +++ b/evals/tasks/agent/onlineMind2Web.ts @@ -44,7 +44,8 @@ export const onlineMind2Web: EvalFunction = async ({ } await stagehand.page.goto(params.website, { - timeout: 75_000, + timeout: 120_000, + waitUntil: "commit", }); const provider = diff --git a/lib/agent/contextManager/checkpoints.ts b/lib/agent/contextManager/checkpoints.ts new file mode 100644 index 000000000..755a60845 --- /dev/null +++ b/lib/agent/contextManager/checkpoints.ts @@ -0,0 +1,146 @@ +import { CoreAssistantMessage, CoreMessage } from "ai"; +import { isToolCallPart, messagesToText } from "."; +import type { LLMClient } from "../../llm/LLMClient"; +import { RECENT_MESSAGES_TO_KEEP_IN_SUMMARY } from "./constants"; + +export interface CheckpointPlan { + messagesToCheckpoint: CoreMessage[]; + recentMessages: CoreMessage[]; + checkpointCount: number; +} + +export function planCheckpoint( + prompt: CoreMessage[], + systemMsgIndex: number, + toolCount: number, + recentToolsToKeep: number, + checkpointInterval: number, +): CheckpointPlan | null { + if (toolCount < checkpointInterval) return null; + + const checkpointCount = Math.floor(toolCount / checkpointInterval); + const toolsToKeep = toolCount - checkpointCount * checkpointInterval; + const recentToolsStart = Math.max( + 0, + toolCount - Math.max(recentToolsToKeep, toolsToKeep), + ); + + const messagesToCheckpoint: CoreMessage[] = []; + const recentMessages: CoreMessage[] = []; + let currentToolCount = 0; + + prompt.forEach((msg, idx) => { + if (idx <= systemMsgIndex) return; + const msgToolCount = countToolsInMessage(msg); + if (currentToolCount < recentToolsStart) messagesToCheckpoint.push(msg); + else recentMessages.push(msg); + currentToolCount += msgToolCount; + }); + + if (messagesToCheckpoint.length === 0) return null; + return { messagesToCheckpoint, recentMessages, checkpointCount }; +} + +export async function generateCheckpointSummary( + messages: CoreMessage[], + checkpointCount: number, + llmClient: LLMClient, +): Promise { + const conversationText = messagesToText(messages); + const model = llmClient.getLanguageModel?.(); + if (!model) { + return `[Checkpoint Summary - ${checkpointCount} checkpoints]\n[Summary generation failed: LLM not available]`; + } + + const { text } = await llmClient.generateText({ + model, + messages: [ + { + role: "user", + content: `Create a concise checkpoint summary of this browser automation conversation segment. + +Focus on: +1. What browser actions were performed +2. What was accomplished +3. Current state/context +4. Any errors or issues + +Conversation segment: +${conversationText} + +Provide a brief summary (max 200 words) that preserves essential context for continuing the automation task:`, + }, + ], + maxTokens: 300, + temperature: 0.3, + }); + + return `[Checkpoint Summary - ${checkpointCount} checkpoints]\n${text}`; +} + +export async function summarizeConversation( + prompt: CoreMessage[], + systemMsgIndex: number, + llmClient: LLMClient, +): Promise<{ + summaryMessage: CoreAssistantMessage; + recentMessages: CoreMessage[]; +}> { + const recentMessages = prompt.slice(-RECENT_MESSAGES_TO_KEEP_IN_SUMMARY); + const summary = await generateConversationSummary( + prompt.slice(systemMsgIndex + 1), + llmClient, + ); + const summaryMessage: CoreAssistantMessage = { + role: "assistant", + content: `[Previous Conversation Summary]\n\n${summary}\n\n[End of Summary - Continuing conversation from this point]`, + }; + return { summaryMessage, recentMessages }; +} + +export async function generateConversationSummary( + messages: CoreMessage[], + llmClient: LLMClient, +): Promise { + const conversationText = messagesToText(messages); + const model = llmClient.getLanguageModel?.(); + if (!model) return "[Summary generation failed: LLM not available]"; + + const { text } = await llmClient.generateText({ + model, + messages: [ + { + role: "user", + content: `Analyze this browser automation conversation and create a comprehensive summary that preserves all important context. + +Conversation: +${conversationText} + +Create a summary that: +1. Captures all key browser actions and their outcomes +2. Preserves important technical details +3. Maintains context about what was accomplished +4. Notes the current page/state +5. Includes any pending tasks or issues +6. Summarizes data extracted or forms filled + +Provide a thorough summary that will allow continuation of the automation task:`, + }, + ], + maxTokens: 500, + temperature: 0.3, + }); + + return text; +} + +export function countToolsInMessage(msg: CoreMessage): number { + if (msg.role === "tool") return 1; + if (msg.role === "assistant") { + const assistantMsg = msg; + if (typeof assistantMsg.content !== "string") { + return assistantMsg.content.filter((part) => isToolCallPart(part)).length; + } + } + return 0; +} diff --git a/lib/agent/contextManager/compression.ts b/lib/agent/contextManager/compression.ts new file mode 100644 index 000000000..ac1db9b1c --- /dev/null +++ b/lib/agent/contextManager/compression.ts @@ -0,0 +1,134 @@ +import { CoreMessage, ToolContent } from "ai"; +import { + compressToolResultContent, + isImageContentPart, + isToolResultContentPart, +} from "."; +import { + DEFAULT_TRUNCATE_TEXT_OVER, + SCREENSHOT_TEXT_PLACEHOLDER, + TOOL_RESULT_AGE_MESSAGES_TO_CONSIDER_OLD, + MAX_PREVIOUS_SAME_TOOL_RESULTS_TO_KEEP, +} from "./constants"; +import { LogLevel } from "@/types/log"; + +export function compressToolResults( + prompt: CoreMessage[], + logger?: (message: string, level: LogLevel) => void, +): CoreMessage[] { + const processed = [...prompt]; + const toolPositions = new Map(); + let replacedOldToolResults = 0; + let replacedOldScreenshots = 0; + let replacedOldAriaTrees = 0; + let imagesConvertedToText = 0; + let truncatedLongToolResults = 0; + + prompt.forEach((msg, idx) => { + if (msg.role === "tool") { + const toolMessage = msg; + toolMessage.content.forEach((item) => { + if (isToolResultContentPart(item)) { + const positions = toolPositions.get(item.toolName) || []; + positions.push(idx); + toolPositions.set(item.toolName, positions); + } + }); + } + }); + + const mapped = processed.map((msg, idx) => { + if (msg.role === "tool") { + const toolMessage = msg; + const processedContent: ToolContent = toolMessage.content.map((item) => { + if (isToolResultContentPart(item)) { + const positions = toolPositions.get(item.toolName) || []; + const currentPos = positions.indexOf(idx); + const isOldByAge = + prompt.length - idx > TOOL_RESULT_AGE_MESSAGES_TO_CONSIDER_OLD; + const isOldByCount = + currentPos >= 0 && + positions.length - currentPos > + MAX_PREVIOUS_SAME_TOOL_RESULTS_TO_KEEP; + const isOld = isOldByAge || isOldByCount; + if (isOld) { + if (item.toolName === "screenshot") { + replacedOldToolResults++; + replacedOldScreenshots++; + logger?.( + `[compression] Replaced old screenshot tool-result at message index ${idx} (reason: ${[ + isOldByAge ? "age" : "", + isOldByCount ? "prior-results" : "", + ] + .filter(Boolean) + .join("+")})`, + 2, + ); + return { + type: "tool-result", + toolCallId: item.toolCallId, + toolName: item.toolName, + result: "Screenshot taken", + }; + } else if (item.toolName === "ariaTree") { + replacedOldToolResults++; + replacedOldAriaTrees++; + logger?.( + `[compression] Compressed old ariaTree tool-result at message index ${idx} (reason: ${[ + isOldByAge ? "age" : "", + isOldByCount ? "prior-results" : "", + ] + .filter(Boolean) + .join("+")})`, + 2, + ); + return { + type: "tool-result", + toolCallId: item.toolCallId, + toolName: item.toolName, + result: { + success: true, + content: "Aria tree retrieved (compressed)", + }, + }; + } + } + } + // Convert screenshot image content to text + if (isImageContentPart(item)) { + imagesConvertedToText++; + return { + type: "text", + text: SCREENSHOT_TEXT_PLACEHOLDER, + } as unknown as ToolContent[number]; + } + + if (isToolResultContentPart(item)) { + const compressed = compressToolResultContent(item, { + truncateTextOver: DEFAULT_TRUNCATE_TEXT_OVER, + }); + if (compressed !== item) truncatedLongToolResults++; + return compressed; + } + + return item; + }); + + return { ...toolMessage, content: processedContent }; + } + return msg; + }); + + if ( + replacedOldToolResults > 0 || + imagesConvertedToText > 0 || + truncatedLongToolResults > 0 + ) { + logger?.( + `[compression] Summary: replaced old tool-results=${replacedOldToolResults} (screenshots=${replacedOldScreenshots}, ariaTree=${replacedOldAriaTrees}); images→text=${imagesConvertedToText}; truncated long tool results=${truncatedLongToolResults}`, + 2, + ); + } + + return mapped; +} diff --git a/lib/agent/contextManager/constants.ts b/lib/agent/contextManager/constants.ts new file mode 100644 index 000000000..31d0f37b7 --- /dev/null +++ b/lib/agent/contextManager/constants.ts @@ -0,0 +1,23 @@ +// Compression thresholds +export const TOOL_RESULT_AGE_MESSAGES_TO_CONSIDER_OLD = 7; +export const MAX_PREVIOUS_SAME_TOOL_RESULTS_TO_KEEP = 2; +export const DEFAULT_TRUNCATE_TEXT_OVER = 4000; + +// Token estimation defaults +export const DEFAULT_TOKENS_PER_IMAGE = 2000; +export const DEFAULT_TOKENS_PER_TOOL_CALL = 50; +export const DEFAULT_TOKENS_FOR_UNKNOWN_TOOL_CONTENT = 200; + +// Summaries and previews +export const ARIA_TREE_PREVIEW_CHARS = 100; +export const GENERIC_RESULT_PREVIEW_CHARS = 50; +export const RECENT_MESSAGES_TO_KEEP_IN_SUMMARY = 10; + +// Text placeholders +export const SCREENSHOT_TEXT_PLACEHOLDER = "[screenshot]"; +export const IMAGE_TEXT_PLACEHOLDER = "[image]"; + +// Context manager thresholds +export const CHECKPOINT_INTERVAL = 50; +export const RECENT_TOOLS_TO_KEEP = 10; +export const SUMMARIZATION_THRESHOLD = 120000; diff --git a/lib/agent/contextManager/contentParts.ts b/lib/agent/contextManager/contentParts.ts new file mode 100644 index 000000000..bb9f0a66d --- /dev/null +++ b/lib/agent/contextManager/contentParts.ts @@ -0,0 +1,151 @@ +import type { AgentToolResult } from "../tools"; +import { + DEFAULT_TOKENS_FOR_UNKNOWN_TOOL_CONTENT, + DEFAULT_TOKENS_PER_IMAGE, + DEFAULT_TOKENS_PER_TOOL_CALL, + GENERIC_RESULT_PREVIEW_CHARS, + ARIA_TREE_PREVIEW_CHARS, +} from "./constants"; + +// Helper to convert any result to string consistently +export function getResultAsString(result: unknown): string { + return typeof result === "string" ? result : JSON.stringify(result); +} + +function hasProperty( + obj: unknown, + prop: T, +): obj is Record { + return typeof obj === "object" && obj !== null && prop in obj; +} +export function isImageContentPart( + item: unknown, +): item is { type: "image"; data: string; mimeType: string } { + return ( + typeof item === "object" && + item !== null && + (item as { type?: unknown }).type === "image" + ); +} + +export function isTextContentPart( + item: unknown, +): item is { type: "text"; text: string } { + return ( + typeof item === "object" && + item !== null && + (item as { type?: unknown }).type === "text" && + typeof (item as { text?: unknown }).text === "string" + ); +} + +export function isToolCallPart(part: unknown): part is { + type: "tool-call"; +} { + return hasProperty(part, "type") && part.type === "tool-call"; +} + +export function textLengthTokens(text: string): number { + return Math.ceil(text.length / 4); +} + +export function sumTokensFromTextParts( + parts: Array<{ type: "text"; text: string }>, +): number { + return parts.reduce((acc, p) => acc + textLengthTokens(p.text), 0); +} + +export function isToolResultContentPart( + item: unknown, +): item is AgentToolResult { + return hasProperty(item, "type") && item.type === "tool-result"; +} + +export function estimateTokensForToolContent(item: unknown): number { + if (isImageContentPart(item)) return DEFAULT_TOKENS_PER_IMAGE; + if (!isToolResultContentPart(item)) + return DEFAULT_TOKENS_FOR_UNKNOWN_TOOL_CONTENT; + + const toolResult = item as AgentToolResult; + if (toolResult.toolName === "screenshot") { + // If compression replaced screenshot tool-result with a small string, count it as text + const maybeString = (toolResult as unknown as { result?: unknown }).result; + if (typeof maybeString === "string") { + return textLengthTokens(maybeString); + } + // Otherwise treat as a small tool result rather than an image + return DEFAULT_TOKENS_PER_TOOL_CALL; + } + + if (toolResult.toolName === "ariaTree") { + if (toolResult.result.success && toolResult.result.content) { + return textLengthTokens(toolResult.result.content); + } + return DEFAULT_TOKENS_FOR_UNKNOWN_TOOL_CONTENT; + } + + // For all other tools, estimate based on result content + const resultStr = getResultAsString(toolResult.result); + return textLengthTokens(resultStr); +} + +export function toolResultSummaryLabel(t: AgentToolResult): string { + if (t.toolName === "screenshot") { + return `[screenshot result: Screenshot taken]`; + } + + if (t.toolName === "ariaTree") { + if (t.result.success && t.result.content) { + return `[ariaTree result: ${previewText( + t.result.content, + ARIA_TREE_PREVIEW_CHARS, + )}]`; + } else if (!t.result.success && t.result.error) { + return `[ariaTree error: ${t.result.error}]`; + } + return `[ariaTree result: Aria tree retrieved]`; + } + + // Handle all other tools + const resultStr = getResultAsString(t.result); + if (resultStr.length > 0) { + return `[${t.toolName} result: ${previewText( + resultStr, + GENERIC_RESULT_PREVIEW_CHARS, + )}]`; + } + return `[${t.toolName} result]`; +} + +export function compressToolResultContent( + toolResult: AgentToolResult, + options?: { truncateTextOver?: number }, +): AgentToolResult { + const limit = options?.truncateTextOver ?? 4000; + + if (toolResult.toolName === "screenshot") { + return toolResult; + } + + const resultStr = getResultAsString(toolResult.result); + + if (resultStr.length > limit) { + if (toolResult.toolName === "ariaTree" && toolResult.result.content) { + return { + ...toolResult, + result: { + ...toolResult.result, + content: "Aria tree retrieved - truncated", + }, + } as AgentToolResult; + } + return { ...toolResult, result: "Truncated" } as unknown as AgentToolResult; + } + + return toolResult; +} + +function previewText(text: string, maxChars: number): string { + const preview = text.substring(0, maxChars); + return `${preview}${text.length > maxChars ? "..." : ""}`; +} diff --git a/lib/agent/contextManager/contextManager.ts b/lib/agent/contextManager/contextManager.ts new file mode 100644 index 000000000..9bf10edd3 --- /dev/null +++ b/lib/agent/contextManager/contextManager.ts @@ -0,0 +1,421 @@ +import { + LanguageModelV1CallOptions, + CoreMessage, + CoreAssistantMessage, +} from "ai"; +import { + compressToolResults, + countTools, + estimateTokens, + generateCheckpointSummary, + planCheckpoint, + summarizeConversation, +} from "."; + +import { + CHECKPOINT_INTERVAL, + RECENT_TOOLS_TO_KEEP, + SUMMARIZATION_THRESHOLD, +} from "./constants"; + +import { LLMClient } from "../../llm/LLMClient"; +import { LogLine } from "../../../types/log"; + +type PromptInput = LanguageModelV1CallOptions["prompt"]; + +interface ProcessedState { + processedPrompt: PromptInput; + lastProcessedIndex: number; + checkpointCount: number; + totalToolCount: number; + compressionLevel: number; +} + +interface CacheEntry { + state: ProcessedState; + timestamp: number; +} + +export async function compressMessages( + messages: PromptInput, + sessionId?: string, + logger?: (message: LogLine) => void, +): Promise { + const manager = new ContextManager(logger); + return manager.processMessages(messages, sessionId || "default"); +} + +export class ContextManager { + private cache = new Map(); + private ttl = 3600000; // 1 hour + private logger?: (message: LogLine) => void; + + // Thresholds moved to centralized constants + + constructor(logger?: (message: LogLine) => void) { + this.logger = logger; + } + + async processMessages( + prompt: PromptInput, + sessionId: string, + llmClient?: LLMClient, + ): Promise { + this.cleanup(); + + const cachedEntry = this.cache.get(sessionId); + const previousState = cachedEntry?.state; + + if (!previousState) { + return this.processInitialPrompt(prompt, sessionId); + } + + return this.processIncrementalPrompt( + prompt, + sessionId, + previousState, + llmClient, + ); + } + + private async processInitialPrompt( + prompt: PromptInput, + sessionId: string, + ): Promise { + const promptArray = this.toCoreMessages(prompt); + const toolCount = countTools(promptArray); + const estimatedTokens = estimateTokens(promptArray); + + this.logger?.({ + category: "context", + message: `Initial prompt analysis: ${promptArray.length} messages, ${toolCount} tools, ~${estimatedTokens} tokens`, + level: 2, + }); + + let processedPrompt = [...promptArray]; + let compressionLevel = 0; + + if (toolCount > 7) { + const beforeSize = JSON.stringify(processedPrompt).length; + processedPrompt = compressToolResults(processedPrompt, (message, level) => + this.logger?.({ + category: "context", + message, + level, + }), + ); + const afterSize = JSON.stringify(processedPrompt).length; + compressionLevel = 1; + + this.logger?.({ + category: "context", + message: `Basic compression applied: ${beforeSize} → ${afterSize} chars (${Math.round((1 - afterSize / beforeSize) * 100)}% reduction)`, + level: 2, + }); + } + + const state: ProcessedState = { + processedPrompt: processedPrompt as PromptInput, + lastProcessedIndex: promptArray.length, + checkpointCount: 0, + totalToolCount: toolCount, + compressionLevel, + }; + + this.setCache(sessionId, state); + + return processedPrompt as PromptInput; + } + + private async processIncrementalPrompt( + prompt: PromptInput, + sessionId: string, + previousState: ProcessedState, + llmClient?: LLMClient, + ): Promise { + const promptArray = this.toCoreMessages(prompt); + const previousPromptArray = Array.isArray(previousState.processedPrompt) + ? (previousState.processedPrompt as CoreMessage[]) + : []; + + const newMessages = promptArray.slice(previousState.lastProcessedIndex); + + let processedPrompt = [...previousPromptArray]; + + processedPrompt = processedPrompt.concat(newMessages as CoreMessage[]); + + const totalToolCount = + previousState.totalToolCount + countTools(newMessages as CoreMessage[]); + let estimatedTokensNow = estimateTokens(processedPrompt); + + let compressionLevel = previousState.compressionLevel; + + // Level 1: Basic compression (idempotent; re-apply to cover new tool results) + if (totalToolCount > 7) { + const beforeSize = JSON.stringify(processedPrompt).length; + const beforeTokens = estimatedTokensNow; + processedPrompt = compressToolResults(processedPrompt, (message, level) => + this.logger?.({ + category: "context", + message, + level, + }), + ); + const afterSize = JSON.stringify(processedPrompt).length; + estimatedTokensNow = estimateTokens(processedPrompt); + if (afterSize !== beforeSize || estimatedTokensNow !== beforeTokens) { + if (compressionLevel < 1) { + compressionLevel = 1; + } + const tokenReductionPct = Math.round( + (1 - estimatedTokensNow / Math.max(1, beforeTokens)) * 100, + ); + this.logger?.({ + category: "context", + message: `Basic compression: ${beforeSize} → ${afterSize} chars (${Math.round((1 - afterSize / beforeSize) * 100)}%); tokens ~${beforeTokens} → ~${estimatedTokensNow} (${tokenReductionPct}%)`, + level: 2, + }); + } + } + + if (llmClient && this.shouldCreateCheckpoint(totalToolCount)) { + const beforeCount = processedPrompt.length; + const beforeTokens = estimatedTokensNow; + processedPrompt = await this.createCheckpoint( + processedPrompt, + sessionId, + llmClient, + ); + const afterCount = processedPrompt.length; + estimatedTokensNow = estimateTokens(processedPrompt); + + const tokenReductionPct = Math.round( + (1 - estimatedTokensNow / Math.max(1, beforeTokens)) * 100, + ); + this.logger?.({ + category: "context", + message: `Checkpoint created: ${beforeCount} → ${afterCount} messages (${totalToolCount} tools processed)`, + level: 2, + }); + this.logger?.({ + category: "context", + message: `Checkpoint optimization: tokens ~${beforeTokens} → ~${estimatedTokensNow} (${tokenReductionPct}%)`, + level: 2, + }); + } + + const shouldSummarizeByTokens = + estimatedTokensNow > SUMMARIZATION_THRESHOLD; + + if (llmClient && shouldSummarizeByTokens && compressionLevel < 2) { + const beforeCount = processedPrompt.length; + processedPrompt = await this.summarizeAndTruncateConversation( + processedPrompt, + sessionId, + llmClient, + ); + const afterCount = processedPrompt.length; + compressionLevel = 2; + + this.logger?.({ + category: "context", + message: `FULL SUMMARIZATION: ${beforeCount} → ${afterCount} messages (exceeded ${SUMMARIZATION_THRESHOLD} tokens)`, + level: 2, + }); + } else if (llmClient && compressionLevel < 2) { + this.logger?.({ + category: "context", + message: `Skip summarization: tokens ~${estimatedTokensNow} ≤ threshold ${SUMMARIZATION_THRESHOLD}`, + level: 2, + }); + } + + const newState: ProcessedState = { + processedPrompt: processedPrompt as PromptInput, + lastProcessedIndex: promptArray.length, + checkpointCount: Math.floor(totalToolCount / CHECKPOINT_INTERVAL), + totalToolCount, + compressionLevel, + }; + + this.setCache(sessionId, newState); + + return processedPrompt as PromptInput; + } + + private async createCheckpoint( + prompt: CoreMessage[], + sessionId: string, + llmClient: LLMClient, + ): Promise { + try { + const toolCount = countTools(prompt); + const { index: systemMsgIndex, systemMessage } = + this.getSystemMessageInfo(prompt); + const plan = planCheckpoint( + prompt, + systemMsgIndex, + toolCount, + RECENT_TOOLS_TO_KEEP, + CHECKPOINT_INTERVAL, + ); + if (!plan) return prompt; + + const { messagesToCheckpoint, recentMessages, checkpointCount } = plan; + const checkpointText = await generateCheckpointSummary( + messagesToCheckpoint, + checkpointCount, + llmClient, + ); + + const checkpointMessage: CoreAssistantMessage = { + role: "assistant", + content: checkpointText, + }; + + // Reconstruct messages + const result: CoreMessage[] = []; + if (systemMessage) { + result.push(systemMessage); + } + result.push(checkpointMessage); + result.push(...recentMessages); + + this.logger?.({ + category: "context", + message: `Checkpoint created: ${messagesToCheckpoint.length} messages → 1 checkpoint + ${recentMessages.length} recent messages`, + level: 2, + }); + + return result; + } catch (error) { + this.logger?.({ + category: "context", + message: `Checkpoint creation failed: ${error instanceof Error ? error.message : String(error)}`, + level: 2, + }); + return prompt; + } + } + + private async summarizeAndTruncateConversation( + prompt: CoreMessage[], + sessionId: string, + llmClient: LLMClient, + ): Promise { + try { + // Find system message + const { index: systemMsgIndex, systemMessage } = + this.getSystemMessageInfo(prompt); + + // Pre-summarization metrics + const beforeTokenEstimate = estimateTokens(prompt); + const beforeCharSize = JSON.stringify(prompt).length; + const beforeMessageCount = prompt.length; + + const { summaryMessage, recentMessages } = await summarizeConversation( + prompt, + systemMsgIndex, + llmClient, + ); + + const result: CoreMessage[] = []; + if (systemMessage) { + result.push(systemMessage); + } + result.push(summaryMessage); + + recentMessages.forEach((msg) => { + if (msg.role !== "system") { + result.push(msg); + } + }); + + // Post-summarization metrics + const afterTokenEstimate = estimateTokens(result); + const afterCharSize = JSON.stringify(result).length; + const afterMessageCount = result.length; + + const tokenReductionPct = Math.round( + (1 - afterTokenEstimate / Math.max(1, beforeTokenEstimate)) * 100, + ); + const charReductionPct = Math.round( + (1 - afterCharSize / Math.max(1, beforeCharSize)) * 100, + ); + + this.logger?.({ + category: "context", + message: `Summarization optimization: messages ${beforeMessageCount} → ${afterMessageCount}; tokens ~${beforeTokenEstimate} → ~${afterTokenEstimate} (${tokenReductionPct}%); chars ${beforeCharSize} → ${afterCharSize} (${charReductionPct}%); recent kept ${recentMessages.length}${systemMessage ? " + system" : ""}`, + level: 2, + }); + + this.cache.set(`${sessionId}:summary`, { + state: { + processedPrompt: result as PromptInput, + lastProcessedIndex: prompt.length, + checkpointCount: 0, + totalToolCount: 0, + compressionLevel: 2, + }, + timestamp: Date.now(), + }); + + this.logger?.({ + category: "context", + message: `Full conversation summary created: ${prompt.length} → ${result.length} messages`, + level: 2, + }); + + return result; + } catch (error) { + this.logger?.({ + category: "context", + message: `Conversation summarization failed: ${error instanceof Error ? error.message : String(error)}`, + level: 2, + }); + return prompt; + } + } + + private getSystemMessageInfo(messages: CoreMessage[]): { + index: number; + systemMessage: CoreMessage | null; + } { + const index = messages.findIndex((msg) => msg.role === "system"); + return { + index, + systemMessage: index >= 0 ? messages[index] : null, + }; + } + + private toCoreMessages(prompt: PromptInput): CoreMessage[] { + return Array.isArray(prompt) ? prompt : []; + } + + private shouldCreateCheckpoint(totalToolCount: number): boolean { + return ( + totalToolCount >= CHECKPOINT_INTERVAL && + totalToolCount % CHECKPOINT_INTERVAL === 0 + ); + } + + private setCache(sessionId: string, state: ProcessedState): void { + this.cache.set(sessionId, { + state, + timestamp: Date.now(), + }); + } + + private cleanup() { + const now = Date.now(); + for (const [key, entry] of this.cache.entries()) { + if (now - entry.timestamp > this.ttl) { + this.cache.delete(key); + } + } + } + + clearSession(sessionId: string) { + this.cache.delete(sessionId); + } +} diff --git a/lib/agent/contextManager/index.ts b/lib/agent/contextManager/index.ts new file mode 100644 index 000000000..a79f045bf --- /dev/null +++ b/lib/agent/contextManager/index.ts @@ -0,0 +1,7 @@ +export * from "./contextManager"; +export * from "./compression"; +export * from "./messageText"; +export * from "./metrics"; +export * from "./checkpoints"; +export * from "./contentParts"; +export * from "./constants"; diff --git a/lib/agent/contextManager/messageText.ts b/lib/agent/contextManager/messageText.ts new file mode 100644 index 000000000..33106a039 --- /dev/null +++ b/lib/agent/contextManager/messageText.ts @@ -0,0 +1,49 @@ +import { CoreMessage } from "ai"; +import { + isImageContentPart, + isTextContentPart, + isToolCallPart, + toolResultSummaryLabel, +} from "."; +import { IMAGE_TEXT_PLACEHOLDER } from "./constants"; + +export function messagesToText(messages: CoreMessage[]): string { + return messages + .map((msg) => { + if (msg.role === "user") { + const userMsg = msg; + const content = + typeof userMsg.content === "string" + ? userMsg.content + : userMsg.content + .map((p) => + isTextContentPart(p) ? p.text : IMAGE_TEXT_PLACEHOLDER, + ) + .join(" "); + return `User: ${content}`; + } else if (msg.role === "assistant") { + const assistantMsg = msg; + const content = + typeof assistantMsg.content === "string" + ? assistantMsg.content + : assistantMsg.content + .map((p) => { + if (isTextContentPart(p)) return p.text; + if (isToolCallPart(p)) return `[Called tool: ${p.toolName}]`; + if (isImageContentPart(p)) return "[image]"; + return ""; + }) + .join(" "); + return `Assistant: ${content}`; + } else if (msg.role === "tool") { + const toolMsg = msg; + const toolSummary = toolMsg.content + .map(toolResultSummaryLabel) + .join(" "); + return `Tool: ${toolSummary}`; + } + return ""; + }) + .filter(Boolean) + .join("\n\n"); +} diff --git a/lib/agent/contextManager/metrics.ts b/lib/agent/contextManager/metrics.ts new file mode 100644 index 000000000..4c2c735b3 --- /dev/null +++ b/lib/agent/contextManager/metrics.ts @@ -0,0 +1,66 @@ +import { CoreMessage } from "ai"; +import { + isImageContentPart, + isTextContentPart, + textLengthTokens, + estimateTokensForToolContent, + isToolCallPart, +} from "."; +import { + DEFAULT_TOKENS_PER_IMAGE, + DEFAULT_TOKENS_PER_TOOL_CALL, +} from "./constants"; + +export function countTools(prompt: CoreMessage[]): number { + let count = 0; + prompt.forEach((msg) => { + if (msg.role === "tool") { + const toolMessage = msg; + count += toolMessage.content.length; + } else if (msg.role === "assistant") { + const assistantMessage = msg; + if (typeof assistantMessage.content !== "string") { + assistantMessage.content.forEach((part) => { + if (isToolCallPart(part)) count++; + }); + } + } + }); + return count; +} + +export function estimateTokens(prompt: CoreMessage[]): number { + let tokens = 0; + prompt.forEach((msg) => { + if (msg.role === "user") { + const user = msg; + if (typeof user.content === "string") { + tokens += textLengthTokens(user.content); + } else { + user.content.forEach((part) => { + if (isTextContentPart(part)) tokens += textLengthTokens(part.text); + else if (isImageContentPart(part)) tokens += DEFAULT_TOKENS_PER_IMAGE; + }); + } + } else if (msg.role === "assistant") { + const assistantMessage = msg; + if (typeof assistantMessage.content === "string") { + tokens += textLengthTokens(assistantMessage.content); + } else { + assistantMessage.content.forEach((part) => { + if (isTextContentPart(part)) { + tokens += textLengthTokens(part.text); + } else if (isToolCallPart(part)) { + tokens += DEFAULT_TOKENS_PER_TOOL_CALL; + } + }); + } + } else if (msg.role === "tool") { + const toolMessage = msg; + toolMessage.content.forEach((item) => { + tokens += estimateTokensForToolContent(item); + }); + } + }); + return tokens; +} diff --git a/lib/agent/tools/ariaTree.ts b/lib/agent/tools/ariaTree.ts index e0a28d170..ac1feb2de 100644 --- a/lib/agent/tools/ariaTree.ts +++ b/lib/agent/tools/ariaTree.ts @@ -2,34 +2,85 @@ import { tool } from "ai"; import { z } from "zod/v3"; import { Stagehand } from "../../index"; -export const createAriaTreeTool = (stagehand: Stagehand) => - tool({ +// Schema for models that support optional parameters well +const defaultParametersSchema = z.object({ + chunkNumber: z + .number() + .optional() + .describe( + "The chunk number to retrieve (1-based). If not provided, returns the first chunk.", + ), +}); + +// Schema for GPT-5: make chunkNumber explicitly required (no optional/default) +// GPT-5 requires all properties to be in the 'required' array +const gpt5ParametersSchema = z.object({ + chunkNumber: z + .number() + .describe( + "The chunk number to retrieve (1-based). Use 1 for the first chunk.", + ), +}); + +export const createAriaTreeTool = (stagehand: Stagehand, isGpt5 = false) => { + const parametersSchema = isGpt5 + ? gpt5ParametersSchema + : defaultParametersSchema; + + return tool({ description: - "gets the accessibility (ARIA) tree from the current page. this is useful for understanding the page structure and accessibility features. it should provide full context of what is on the page", - parameters: z.object({}), - execute: async () => { - const { page_text } = await stagehand.page.extract(); - const pageUrl = stagehand.page.url(); - - let content = page_text; - const MAX_CHARACTERS = 70000; - - const estimatedTokens = Math.ceil(content.length / 4); - - if (estimatedTokens > MAX_CHARACTERS) { - const maxCharacters = MAX_CHARACTERS * 4; - content = - content.substring(0, maxCharacters) + - "\n\n[CONTENT TRUNCATED: Exceeded 70,000 token limit]"; - } + "gets the accessibility (ARIA) tree from the current page in chunks. this is useful for understanding the page structure and accessibility features. it provides full context of what is on the page, broken into manageable chunks. if no chunk number is specified, returns the first chunk with metadata about total chunks available.", + parameters: parametersSchema, + execute: async ({ chunkNumber = 1 }) => { + try { + const { page_text } = await stagehand.page.extract(); - return { - content, - pageUrl, - }; - }, - experimental_toToolResultContent: (result) => { - const content = typeof result === "string" ? result : result.content; - return [{ type: "text", text: `Accessibility Tree:\n${content}` }]; + const TOKENS_PER_CHUNK = 60000; + const CHARACTERS_PER_TOKEN = 4; // Rough estimate + const CHARACTERS_PER_CHUNK = TOKENS_PER_CHUNK * CHARACTERS_PER_TOKEN; + + const totalCharacters = page_text.length; + const totalChunks = Math.ceil(totalCharacters / CHARACTERS_PER_CHUNK); + + if (chunkNumber < 1 || chunkNumber > totalChunks) { + return { + success: false, + error: `Invalid chunk number ${chunkNumber}. Available chunks: 1-${totalChunks}`, + }; + } + + // Calculate chunk boundaries + const startIndex = (chunkNumber - 1) * CHARACTERS_PER_CHUNK; + const endIndex = Math.min( + startIndex + CHARACTERS_PER_CHUNK, + totalCharacters, + ); + const chunkContent = page_text.substring(startIndex, endIndex); + + const hasMoreChunks = chunkNumber < totalChunks; + const nextChunkNumber = hasMoreChunks ? chunkNumber + 1 : null; + + let content = `Accessibility Tree - Chunk ${chunkNumber} of ${totalChunks} (characters ${startIndex + 1}-${endIndex} of ${totalCharacters})\n\n${chunkContent}`; + + if (hasMoreChunks) { + content += `\n\n[CHUNK INCOMPLETE: This is chunk ${chunkNumber} of ${totalChunks}. To get the next chunk, call this tool again with chunkNumber: ${nextChunkNumber}]`; + } else { + content += `\n\n[CHUNK COMPLETE: This is the final chunk (${chunkNumber} of ${totalChunks})]`; + } + + return { + success: true, + content, + chunkNumber, + totalChunks, + hasMoreChunks, + }; + } catch { + return { + success: false, + error: `Error getting aria tree, try again`, + }; + } }, }); +}; diff --git a/lib/agent/tools/click.ts b/lib/agent/tools/click.ts new file mode 100644 index 000000000..d56fcb7a8 --- /dev/null +++ b/lib/agent/tools/click.ts @@ -0,0 +1,29 @@ +import { tool } from "ai"; +import { z } from "zod/v3"; +import { Stagehand } from "../../index"; + +export const createClickTool = (stagehand: Stagehand) => + tool({ + description: + "Click on an element using its coordinates ( this is the most reliable way to click on an element, always use this over act, unless the element is not visible in the screenshot, but shown in ariaTree)", + parameters: z.object({ + describe: z + .string() + .describe( + "Describe the element to click on in a short, specific phrase that mentions the element type and a good visual description", + ), + coordinates: z + .array(z.number()) + .describe("The (x, y) coordinates to click on"), + }), + execute: async ({ describe, coordinates }) => { + try { + await stagehand.page.mouse.move(coordinates[0], coordinates[1]); + await stagehand.page.waitForTimeout(50); + await stagehand.page.mouse.click(coordinates[0], coordinates[1]); + return { success: true, describe, coordinates }; + } catch { + return { success: false, error: `Error clicking, try again` }; + } + }, + }); diff --git a/lib/agent/tools/clickAndHold.ts b/lib/agent/tools/clickAndHold.ts new file mode 100644 index 000000000..939deef17 --- /dev/null +++ b/lib/agent/tools/clickAndHold.ts @@ -0,0 +1,29 @@ +import { tool } from "ai"; +import { z } from "zod/v3"; +import { Stagehand } from "../../index"; + +export const createClickAndHoldTool = (stagehand: Stagehand) => + tool({ + description: "Click and hold on an element using its coordinates", + parameters: z.object({ + describe: z + .string() + .describe( + "Describe the element to click on in a short, specific phrase that mentions the element type and a good visual description", + ), + duration: z + .number() + .describe("The duration to hold the element in milliseconds"), + coordinates: z + .array(z.number()) + .describe("The (x, y) coordinates to click on"), + }), + + execute: async ({ describe, coordinates, duration }) => { + await stagehand.page.mouse.move(coordinates[0], coordinates[1]); + await stagehand.page.mouse.down(); + await stagehand.page.waitForTimeout(duration); + await stagehand.page.mouse.up(); + return { success: true, describe }; + }, + }); diff --git a/lib/agent/tools/dragAndDrop.ts b/lib/agent/tools/dragAndDrop.ts new file mode 100644 index 000000000..abaf93831 --- /dev/null +++ b/lib/agent/tools/dragAndDrop.ts @@ -0,0 +1,25 @@ +import { tool } from "ai"; +import { z } from "zod/v3"; +import { Stagehand } from "../../index"; + +export const createDragAndDropTool = (stagehand: Stagehand) => + tool({ + description: + "Drag and drop an element using its coordinates ( this is the most reliable way to drag and drop an element, always use this over act, unless the element is not visible in the screenshot, but shown in ariaTree)", + parameters: z.object({ + describe: z.string().describe("Describe the element to drag and drop"), + startCoordinates: z + .array(z.number()) + .describe("The (x, y) coordinates to start the drag and drop from"), + endCoordinates: z + .array(z.number()) + .describe("The (x, y) coordinates to end the drag and drop at"), + }), + execute: async ({ describe, startCoordinates, endCoordinates }) => { + await stagehand.page.mouse.move(startCoordinates[0], startCoordinates[1]); + await stagehand.page.mouse.down(); + await stagehand.page.mouse.move(endCoordinates[0], endCoordinates[1]); + await stagehand.page.mouse.up(); + return { success: true, describe }; + }, + }); diff --git a/lib/agent/tools/extract.ts b/lib/agent/tools/extract.ts index 4e758e158..ab7b072da 100644 --- a/lib/agent/tools/extract.ts +++ b/lib/agent/tools/extract.ts @@ -65,7 +65,7 @@ export const createExtractTool = ( schema: z .string() .describe( - 'Zod schema as a string (e.g., "z.object({ price: z.number() })")', + 'Zod schema as a string (e.g., "z.object({ price: z.number().optional() })")', ), }), execute: async ({ instruction, schema }) => { @@ -89,7 +89,6 @@ export const createExtractTool = ( return { success: true, data: result, - timestamp: Date.now(), }; } catch (error) { const errorMessage = diff --git a/lib/agent/tools/fillform.ts b/lib/agent/tools/fillform.ts index 487bb5c84..da61b9581 100644 --- a/lib/agent/tools/fillform.ts +++ b/lib/agent/tools/fillform.ts @@ -53,19 +53,52 @@ export const createFillFormTool = ( .map((field) => field.action) .join(", ")}`; - const observeResults = executionModel + let observeResults = executionModel ? await stagehand.page.observe({ instruction, modelName: executionModel, }) : await stagehand.page.observe(instruction); - const completedActions = []; - for (const result of observeResults) { - const action = await stagehand.page.act(result); - completedActions.push(action); + let usedIframe = false; + const hasIframeAction = observeResults?.some( + (r) => r.description === "an iframe", + ); + + if (hasIframeAction) { + const iframeObserveOptions = executionModel + ? { + instruction, + modelName: executionModel, + iframes: true, + } + : { + instruction, + iframes: true, + }; + + const iframeObserveResults = + await stagehand.page.observe(iframeObserveOptions); + + if (!iframeObserveResults || iframeObserveResults.length === 0) { + return { + success: false, + error: "No observable actions found within iframe context", + isIframe: true, + }; + } + + observeResults = iframeObserveResults; + usedIframe = true; } - return { success: true, actions: completedActions }; + for (const observeResult of observeResults) { + await stagehand.page.act(observeResult); + } + return { + success: true, + playwrightArguments: observeResults, + isIframe: usedIframe, + }; }, }); diff --git a/lib/agent/tools/fillformVision.ts b/lib/agent/tools/fillformVision.ts new file mode 100644 index 000000000..1f0acef28 --- /dev/null +++ b/lib/agent/tools/fillformVision.ts @@ -0,0 +1,66 @@ +import { tool } from "ai"; +import { z } from "zod/v3"; +import { Stagehand } from "../../index"; + +export const createFillFormVisionTool = (stagehand: Stagehand) => + tool({ + description: `📝 FORM FILL - SPECIALIZED MULTI-FIELD INPUT TOOL + + CRITICAL: Use this for ANY form with 2+ input fields (text inputs, textareas, etc.) + IMPORTANT: ensure the fields are visible within the current viewport + + WHY THIS TOOL EXISTS: + • Forms are the #1 use case for multi-field input + • Optimized specifically for input/textarea elements + • 4-6x faster than individual typing actions + + Use fillForm: Pure form filling (inputs, textareas only) + MANDATORY USE CASES (always use fillForm for these): + Registration forms: name, email, password fields + Contact forms: name, email, message fields + Checkout forms: address, payment info fields + Profile updates: multiple user data fields + Search filters: multiple criteria inputs + + + `, + parameters: z.object({ + fields: z + .array( + z.object({ + action: z + .string() + .describe( + "Description of the typing action, e.g. 'type foo into the bar field'", + ), + value: z.string().describe("Text to type into the target field"), + coordinates: z + .object({ + x: z.number(), + y: z.number(), + }) + .describe("Coordinates of the target field"), + }), + ) + .min(2, "Provide at least two fields to fill"), + }), + + execute: async ({ fields }) => { + for (const field of fields) { + await stagehand.page.mouse.move( + field.coordinates.x, + field.coordinates.y, + ); + await stagehand.page.mouse.click( + field.coordinates.x, + field.coordinates.y, + ); + await stagehand.page.keyboard.type(field.value); + await stagehand.page.waitForTimeout(100); + } + return { + success: true, + playwrightArguments: fields, + }; + }, + }); diff --git a/lib/agent/tools/goto.ts b/lib/agent/tools/goto.ts index b9fbb1a1e..9a1e21dfc 100644 --- a/lib/agent/tools/goto.ts +++ b/lib/agent/tools/goto.ts @@ -10,7 +10,7 @@ export const createGotoTool = (stagehand: Stagehand) => }), execute: async ({ url }) => { try { - await stagehand.page.goto(url, { waitUntil: "load" }); + await stagehand.page.goto(url, { waitUntil: "commit" }); return { success: true, url }; } catch (error) { return { success: false, error: error.message }; diff --git a/lib/agent/tools/index.ts b/lib/agent/tools/index.ts index d73574b40..d549d9fca 100644 --- a/lib/agent/tools/index.ts +++ b/lib/agent/tools/index.ts @@ -7,13 +7,38 @@ import { createCloseTool } from "./close"; import { createAriaTreeTool } from "./ariaTree"; import { createFillFormTool } from "./fillform"; import { createScrollTool } from "./scroll"; -import { Stagehand } from "../../index"; import { LogLine } from "@/types/log"; +import { thinkTool } from "./think"; +import { createClickTool } from "./click"; +import { createTypeTool } from "./type"; +import { createDragAndDropTool } from "./dragAndDrop"; +import { createSearchTool } from "./search"; +import { createKeysTool } from "./keys"; +import { createClickAndHoldTool } from "./clickAndHold"; +import { Stagehand } from "../../index"; +import { createFillFormVisionTool } from "./fillformVision"; +import type { ToolSet, ToolCallUnion, ToolResultUnion } from "ai"; import { createExtractTool } from "./extract"; - export interface AgentToolOptions { executionModel?: string; logger?: (message: LogLine) => void; + mainModel?: string; + storeActions?: boolean; +} + +function filterToolsByModelName(tools: ToolSet, isClaude: boolean): ToolSet { + const filtered: ToolSet = { ...tools }; + + if (isClaude) { + delete filtered.fillForm; + return filtered; + } + delete filtered.dragAndDrop; + delete filtered.clickAndHold; + delete filtered.click; + delete filtered.type; + delete filtered.fillFormVision; + return filtered; } export function createAgentTools( @@ -21,19 +46,59 @@ export function createAgentTools( options?: AgentToolOptions, ) { const executionModel = options?.executionModel; + const hasExaApiKey = process.env.EXA_API_KEY?.length > 0; + + // Detect model characteristics for tool configuration (defined once here) + const modelName = (options?.mainModel || "").toLowerCase().trim(); + const storeActions = options?.storeActions; + const isGpt5 = modelName.startsWith("gpt-5"); + const isClaude = modelName.startsWith("claude") && storeActions === false; - return { + const all = { act: createActTool(stagehand, executionModel), - ariaTree: createAriaTreeTool(stagehand), + ariaTree: createAriaTreeTool(stagehand, isGpt5), + click: createClickTool(stagehand), + clickAndHold: createClickAndHoldTool(stagehand), + dragAndDrop: createDragAndDropTool(stagehand), + type: createTypeTool(stagehand), close: createCloseTool(), - extract: createExtractTool(stagehand, executionModel, options?.logger), + think: thinkTool, fillForm: createFillFormTool(stagehand, executionModel), + fillFormVision: createFillFormVisionTool(stagehand), goto: createGotoTool(stagehand), navback: createNavBackTool(stagehand), - screenshot: createScreenshotTool(stagehand), - scroll: createScrollTool(stagehand), + screenshot: createScreenshotTool(stagehand, options?.mainModel), + scroll: createScrollTool(stagehand, isClaude), wait: createWaitTool(), - }; + ...(hasExaApiKey ? { search: createSearchTool() } : {}), + keys: createKeysTool(stagehand, isGpt5), + extract: createExtractTool(stagehand), + } satisfies ToolSet; + return filterToolsByModelName(all, isClaude); } export type AgentTools = ReturnType; + +export type AgentToolTypesMap = { + act: ReturnType; + ariaTree: ReturnType; + click: ReturnType; + clickAndHold: ReturnType; + dragAndDrop: ReturnType; + type: ReturnType; + close: ReturnType; + think: typeof thinkTool; + fillForm: ReturnType; + fillFormVision: ReturnType; + goto: ReturnType; + navback: ReturnType; + screenshot: ReturnType; + scroll: ReturnType; + wait: ReturnType; + search: ReturnType; + keys: ReturnType; + extract: ReturnType; +}; + +export type AgentToolCall = ToolCallUnion; +export type AgentToolResult = ToolResultUnion; diff --git a/lib/agent/tools/keys.ts b/lib/agent/tools/keys.ts new file mode 100644 index 000000000..5f3effd29 --- /dev/null +++ b/lib/agent/tools/keys.ts @@ -0,0 +1,137 @@ +import { tool } from "ai"; +import { z } from "zod/v3"; +import { Stagehand } from "../../index"; +import { resolvePlatform, normalizeKeys } from "../utils/cuaKeyMapping"; + +// Schema for models that support optional parameters well +const defaultParametersSchema = z.object({ + method: z + .enum(["press", "down", "up", "type", "insertText"]) + .describe("Keyboard method to use"), + keys: z + .union([z.string(), z.array(z.string())]) + .optional() + .describe( + "Key or combo for press/down/up. Use '+' to combine, e.g. 'mod+a' or ['Control','A'].", + ), + text: z.string().optional().describe("Text for type/insertText methods"), + repeat: z + .number() + .optional() + .describe("Repeat count for press/type. Default 1."), +}); + +// Schema for GPT-5: make all parameters required +// Use empty string "" for unused params (keys for type/insertText, text for press/down/up) +const gpt5ParametersSchema = z.object({ + method: z + .enum(["press", "down", "up", "type", "insertText"]) + .describe("Keyboard method to use"), + keys: z + .union([z.string(), z.array(z.string())]) + .describe( + "Key or combo for press/down/up. Use '+' to combine, e.g. 'mod+a' or ['Control','A']. Use empty string '' for type/insertText methods.", + ), + text: z + .string() + .describe( + "Text for type/insertText methods. Use empty string '' for press/down/up methods.", + ), + repeat: z + .number() + .describe("Repeat count for press/type. Use 1 for single execution."), +}); + +export const createKeysTool = (stagehand: Stagehand, isGpt5 = false) => { + const parametersSchema = isGpt5 + ? gpt5ParametersSchema + : defaultParametersSchema; + + return tool({ + description: + "Send keyboard events: press, down, up, type, or insertText. Supports combinations like mod+a, cmd+c, ctrl+v, etc. 'mod' maps to Command on macOS and Control on Windows/Linux. One really good use case of this tool, is clearing text from an input that is currently focused", + parameters: parametersSchema as z.ZodType<{ + method: "press" | "down" | "up" | "type" | "insertText"; + keys?: string | string[]; + text?: string; + repeat?: number; + }>, + execute: async ({ method, keys, text, repeat }) => { + try { + const userAgent = await stagehand.page.evaluate( + () => navigator.userAgent, + ); + const resolvedPlatform = resolvePlatform("auto", userAgent); + + const times = Math.max(1, repeat ?? 1); + + if (method === "type") { + if (!text || text === "") + return { + success: false, + error: "'text' is required for method 'type'", + }; + for (let i = 0; i < times; i++) { + await stagehand.page.keyboard.type(text, { delay: 100 }); + } + return { success: true, method, text, times }; + } + + if (method === "insertText") { + if (!text || text === "") + throw new Error("'text' is required for method 'insertText'"); + for (let i = 0; i < times; i++) { + await stagehand.page.keyboard.insertText(text); + await stagehand.page.waitForTimeout(100); + } + return { success: true, method, text, times }; + } + + if (!keys || keys === "" || (Array.isArray(keys) && keys.length === 0)) + throw new Error("'keys' is required for methods press/down/up"); + const { combo, tokens } = normalizeKeys(keys, resolvedPlatform); + + if (method === "press") { + for (let i = 0; i < times; i++) { + await stagehand.page.keyboard.press(combo, { delay: 100 }); + } + return { + success: true, + method, + keys: combo, + times, + }; + } + + if (method === "down") { + for (const token of tokens) { + await stagehand.page.keyboard.down(token); + await stagehand.page.waitForTimeout(100); + } + return { + success: true, + method, + keys: tokens, + }; + } + + if (method === "up") { + // Release in reverse order for combos + for (let i = tokens.length - 1; i >= 0; i--) { + await stagehand.page.keyboard.up(tokens[i]); + await stagehand.page.waitForTimeout(100); + } + return { + success: true, + method, + keys: tokens, + }; + } + + throw new Error(`Unsupported method: ${method}`); + } catch (error) { + return { success: false, error: (error as Error).message }; + } + }, + }); +}; diff --git a/lib/agent/tools/navback.ts b/lib/agent/tools/navback.ts index 829b7c0c6..2b6a559db 100644 --- a/lib/agent/tools/navback.ts +++ b/lib/agent/tools/navback.ts @@ -9,7 +9,11 @@ export const createNavBackTool = (stagehand: Stagehand) => reasoning: z.string().describe("Why you're going back"), }), execute: async () => { - await stagehand.page.goBack(); - return { success: true }; + try { + await stagehand.page.goBack(); + return { success: true }; + } catch (error) { + return { success: false, error: error.message }; + } }, }); diff --git a/lib/agent/tools/screenshot.ts b/lib/agent/tools/screenshot.ts index a563290fc..1fa7dc120 100644 --- a/lib/agent/tools/screenshot.ts +++ b/lib/agent/tools/screenshot.ts @@ -2,31 +2,51 @@ import { tool } from "ai"; import { z } from "zod/v3"; import { Stagehand } from "../../index"; -export const createScreenshotTool = (stagehand: Stagehand) => - tool({ +export const createScreenshotTool = ( + stagehand: Stagehand, + modelName?: string, +) => { + // Determine if we should use PNG (for Anthropic models) or JPEG (for others) + const normalized = (modelName || "").toLowerCase().trim(); + const isAnthropic = normalized.startsWith("claude"); + const imageType = isAnthropic ? "png" : "jpeg"; + const mimeType = isAnthropic ? "image/png" : "image/jpeg"; + + return tool({ description: "Takes a screenshot of the current page. Use this tool to learn where you are on the page, or to get context of elements on the page", parameters: z.object({}), execute: async () => { - const screenshotBuffer = await stagehand.page.screenshot({ - fullPage: false, - type: "jpeg", - }); - const pageUrl = stagehand.page.url(); + try { + const screenshotBuffer = await stagehand.page.screenshot({ + fullPage: false, + type: imageType, + }); + const pageUrl = stagehand.page.url(); - return { - base64: screenshotBuffer.toString("base64"), - timestamp: Date.now(), - pageUrl, - }; + return { + base64: screenshotBuffer.toString("base64"), + timestamp: Date.now(), + pageUrl, + }; + } catch { + return { + error: `Error taking screenshot, try again`, + }; + } }, experimental_toToolResultContent: (result) => { + if (result.error) { + return [{ type: "text", text: `Error, try again: ${result.error}` }]; + } + return [ { type: "image", data: result.base64, - mimeType: "image/jpeg", + mimeType, }, ]; }, }); +}; diff --git a/lib/agent/tools/scroll.ts b/lib/agent/tools/scroll.ts index e467208a6..cb49d741a 100644 --- a/lib/agent/tools/scroll.ts +++ b/lib/agent/tools/scroll.ts @@ -2,18 +2,68 @@ import { tool } from "ai"; import { z } from "zod/v3"; import { Stagehand } from "../../index"; -export const createScrollTool = (stagehand: Stagehand) => - tool({ - description: "Scroll the page", - parameters: z.object({ - pixels: z.number().describe("Number of pixels to scroll up or down"), - direction: z.enum(["up", "down"]).describe("Direction to scroll"), - }), - execute: async ({ pixels, direction }) => { +// Schema for Claude CUA models - includes coordinates parameter for precise scrolling +const claudeParametersSchema = z.object({ + percentage: z + .number() + .min(1) + .max(200) + .default(80) + .optional() + .describe("Percentage of viewport height to scroll (1-200%, default: 80%)"), + direction: z.enum(["up", "down"]).describe("Direction to scroll"), + coordinates: z + .array(z.number()) + .describe( + "the (x, y) coordinates to scroll inside of, if not provided, will scroll the page", + ) + .optional(), +}); + +// Schema for non-Claude models - no coordinates parameter +const defaultParametersSchema = z.object({ + percentage: z + .number() + .min(1) + .max(200) + .describe("Percentage of viewport height to scroll (1-200%, default: 80%)"), + direction: z.enum(["up", "down"]).describe("Direction to scroll"), +}); + +export const createScrollTool = (stagehand: Stagehand, isClaude = false) => { + const parametersSchema = isClaude + ? claudeParametersSchema + : defaultParametersSchema; + + return tool({ + description: + "Scroll the page by a percentage of the current viewport height. More dynamic and robust than fixed pixel amounts.", + parameters: parametersSchema as z.ZodType<{ + percentage?: number; + direction: "up" | "down"; + coordinates?: number[]; + }>, + execute: async (params) => { + const percentage = params.percentage ?? 80; + const direction = params.direction; + const coordinates = + "coordinates" in params ? params.coordinates : undefined; + const viewportHeight = await stagehand.page.evaluate( + () => window.innerHeight, + ); + const scrollDistance = Math.round((viewportHeight * percentage) / 100); + + if (coordinates && coordinates.length > 0) { + await stagehand.page.mouse.move(coordinates[0], coordinates[1]); + } await stagehand.page.mouse.wheel( 0, - direction === "up" ? -pixels : pixels, + direction === "up" ? -scrollDistance : scrollDistance, ); - return { success: true, pixels }; + return { + success: true, + message: `scrolled ${percentage}% of viewport ${direction} (${scrollDistance}px of ${viewportHeight}px viewport height)`, + }; }, }); +}; diff --git a/lib/agent/tools/search.ts b/lib/agent/tools/search.ts new file mode 100644 index 000000000..1c4ff08bc --- /dev/null +++ b/lib/agent/tools/search.ts @@ -0,0 +1,77 @@ +import { tool } from "ai"; +import { z } from "zod"; +import Exa from "exa-js"; + +export interface ExaSearchResult { + id: string; + title: string; + url: string; + publishedDate?: string; + author?: string; + favicon?: string; + score?: number; + image?: string; +} +interface SearchResponse { + data?: { + results: ExaSearchResult[]; + }; + error?: string; +} + +async function performExaSearch(query: string): Promise { + try { + const exa = new Exa(process.env.EXA_API_KEY); + + const response = await exa.search(query, { + type: "auto", + numResults: 5, + }); + + const responseObj = response; + const results: ExaSearchResult[] = []; + + if (responseObj?.results && Array.isArray(responseObj.results)) { + responseObj.results.forEach((item) => { + if (item.id && item.title && item.url) { + results.push({ + id: item.id, + title: item.title, + url: item.url, + }); + } + }); + } + + return { + data: { + results: results, + }, + }; + } catch (error) { + console.error("Search error", error); + return { + error: `Error performing search`, + data: { + results: [], + }, + }; + } +} + +export const createSearchTool = () => { + return tool({ + description: + "Perform a web search and returns results. Use this tool when you need information from the web or when you are unsure of the exact URL you want to navigate to. This can be used to find the ideal entry point, resulting in a task that is easier to complete due to starting further in the process.", + parameters: z.object({ + query: z.string().describe("The search query to look for on the page"), + }), + execute: async ({ query }: { query: string }) => { + const result = await performExaSearch(query); + return { + ...result, + timestamp: Date.now(), + }; + }, + }); +}; diff --git a/lib/agent/tools/think.ts b/lib/agent/tools/think.ts new file mode 100644 index 000000000..003333428 --- /dev/null +++ b/lib/agent/tools/think.ts @@ -0,0 +1,26 @@ +import { tool } from "ai"; +import { z } from "zod"; + +export const thinkTool = tool({ + description: `Use this tool to think through complex problems or plan a sequence of steps. This is for internal reasoning only and doesn't perform any actions. Use this to: + + 1. Plan a multi-step approach before taking action + 2. Break down complex tasks + 3. Reason through edge cases + 4. Evaluate options when you're unsure what to do next + + The output is only visible to you; use it to track your own reasoning process.`, + parameters: z.object({ + reasoning: z + .string() + .describe( + "Your step-by-step reasoning or planning process. Be as detailed as needed.", + ), + }), + execute: async ({ reasoning }: { reasoning: string }) => { + return { + acknowledged: true, + message: reasoning, + }; + }, +}); diff --git a/lib/agent/tools/type.ts b/lib/agent/tools/type.ts new file mode 100644 index 000000000..8e672cb73 --- /dev/null +++ b/lib/agent/tools/type.ts @@ -0,0 +1,31 @@ +import { tool } from "ai"; +import { z } from "zod/v3"; +import { Stagehand } from "../../index"; + +export const createTypeTool = (stagehand: Stagehand) => + tool({ + description: + "Type text into an element using its coordinates. this will click the element and then type the text into it ( this is the most reliable way to type into an element, always use this over act, unless the element is not visible in the screenshot, but shown in ariaTree)", + parameters: z.object({ + describe: z + .string() + .describe( + "Describe the element to type into in a short, specific phrase that mentions the element type and a good visual description", + ), + text: z.string().describe("The text to type into the element"), + coordinates: z + .array(z.number()) + .describe("The (x, y) coordinates to type into the element"), + }), + execute: async ({ describe, coordinates, text }) => { + try { + await stagehand.page.mouse.move(coordinates[0], coordinates[1]); + await stagehand.page.waitForTimeout(50); + await stagehand.page.mouse.click(coordinates[0], coordinates[1]); + await stagehand.page.keyboard.type(text); + } catch { + return { success: false, error: `Error typing, try again` }; + } + return { success: true, describe }; + }, + }); diff --git a/lib/agent/tools/wait.ts b/lib/agent/tools/wait.ts index 2311eed53..c478f41dc 100644 --- a/lib/agent/tools/wait.ts +++ b/lib/agent/tools/wait.ts @@ -9,6 +9,6 @@ export const createWaitTool = () => }), execute: async ({ timeMs }) => { await new Promise((resolve) => setTimeout(resolve, timeMs)); - return { success: true, waited: timeMs }; + return { success: true }; }, }); diff --git a/lib/agent/utils/actionHandler.ts b/lib/agent/utils/actionHandler.ts new file mode 100644 index 000000000..78bbf00d4 --- /dev/null +++ b/lib/agent/utils/actionHandler.ts @@ -0,0 +1,74 @@ +import { AgentAction } from "@/types/agent"; +import { AgentToolResult, AgentToolCall } from "@/lib/agent/tools"; + +export interface ActionHandlerOptions { + toolCallName: string; + toolResult: AgentToolResult; + args: AgentToolCall["args"]; + reasoning?: string; +} + +export function mapToolResultToActions({ + toolCallName, + toolResult, + args, + reasoning, +}: ActionHandlerOptions): AgentAction[] { + if (toolResult) { + if (toolResult.toolName === "act") { + const result = toolResult.result; + + const playwrightArguments = result.playwrightArguments + ? { playwrightArguments: result.playwrightArguments } + : {}; + + return [ + { + type: "act", + reasoning, + taskCompleted: false, + ...playwrightArguments, + }, + ]; + } + + if (toolResult.toolName === "fillForm") { + const result = toolResult.result; + const observeResults = Array.isArray(result.playwrightArguments) + ? result.playwrightArguments + : []; + + const actions: AgentAction[] = []; + + actions.push({ + type: "fillForm", + reasoning, + taskCompleted: false, + ...args, + }); + + for (const observeResult of observeResults) { + actions.push({ + type: "act", + reasoning: "acting from fillform tool", + taskCompleted: false, + playwrightArguments: observeResult, + }); + } + + return actions; + } + } + + return [ + { + type: toolCallName, + reasoning, + taskCompleted: + toolCallName === "close" && args && "taskComplete" in args + ? args.taskComplete + : false, + ...args, + }, + ]; +} diff --git a/lib/agent/utils/cuaKeyMapping.ts b/lib/agent/utils/cuaKeyMapping.ts index 442df5eab..f1e8f3f9c 100644 --- a/lib/agent/utils/cuaKeyMapping.ts +++ b/lib/agent/utils/cuaKeyMapping.ts @@ -60,3 +60,80 @@ export function mapKeyToPlaywright(key: string): string { const upperKey = key.toUpperCase(); return KEY_MAP[upperKey] || key; } + +/** + * Cross-client platform and key normalization helpers used by CUA and tools + */ +//to do: utilize this within other cua handlers, with OS detection for more robust key handling + multi key handling +export type Platform = "mac" | "windows" | "linux"; + +export function detectPlatformFromUserAgent( + userAgent: string, +): Platform | undefined { + const ua = userAgent.toLowerCase(); + if (ua.includes("mac os x") || ua.includes("macintosh")) return "mac"; + if (ua.includes("windows nt")) return "windows"; + if (ua.includes("linux") || ua.includes("cros")) return "linux"; + return undefined; +} + +export function resolvePlatform( + param: Platform | "auto" | undefined, + userAgent: string | undefined, +): Platform { + if (param && param !== "auto") return param; + const fromUa = userAgent ? detectPlatformFromUserAgent(userAgent) : undefined; + if (fromUa) return fromUa; + if (typeof process !== "undefined") { + if (process.platform === "darwin") return "mac"; + if (process.platform === "win32") return "windows"; + } + return "linux"; +} + +export function normalizeKeyToken(raw: string, platform: Platform): string { + const t = raw.trim(); + const lower = t.toLowerCase(); + if (lower === "mod") return platform === "mac" ? "Meta" : "Control"; + if ( + lower === "cmd" || + lower === "command" || + lower === "meta" || + lower === "⌘" + ) + return "Meta"; + if (lower === "ctrl" || lower === "control") return "Control"; + if (lower === "alt" || lower === "option" || lower === "opt") return "Alt"; + if (lower === "shift") return "Shift"; + if (lower === "enter" || lower === "return") return "Enter"; + if (lower === "esc" || lower === "escape") return "Escape"; + if (lower === "space" || lower === "spacebar") return "Space"; + if (lower === "left" || lower === "arrowleft") return "ArrowLeft"; + if (lower === "right" || lower === "arrowright") return "ArrowRight"; + if (lower === "up" || lower === "arrowup") return "ArrowUp"; + if (lower === "down" || lower === "arrowdown") return "ArrowDown"; + if (lower === "pgup" || lower === "pageup") return "PageUp"; + if (lower === "pgdn" || lower === "pagedown") return "PageDown"; + if (lower === "del" || lower === "delete") return "Delete"; + if (lower === "backspace") return "Backspace"; + if (lower === "tab") return "Tab"; + if (lower === "home") return "Home"; + if (lower === "end") return "End"; + // Upper-case single letters (a -> A), keep others as-is + if (/^[a-z]$/.test(lower)) return lower.toUpperCase(); + return t; +} + +export function normalizeKeys( + keys: string | string[], + platform: Platform, +): { combo: string; tokens: string[] } { + const tokens = Array.isArray(keys) + ? keys + : keys + .split("+") + .map((k) => k.trim()) + .filter(Boolean); + const normalizedTokens = tokens.map((k) => normalizeKeyToken(k, platform)); + return { combo: normalizedTokens.join("+"), tokens: normalizedTokens }; +} diff --git a/lib/agent/utils/messageProcessing.ts b/lib/agent/utils/messageProcessing.ts deleted file mode 100644 index faee3f33b..000000000 --- a/lib/agent/utils/messageProcessing.ts +++ /dev/null @@ -1,201 +0,0 @@ -import { type LanguageModelV1CallOptions } from "ai"; - -export interface CompressionStats { - originalSize: number; - compressedSize: number; - savedChars: number; - compressionRatio: number; - screenshotCount: number; - ariaTreeCount: number; -} - -function isToolMessage( - message: unknown, -): message is { role: "tool"; content: unknown[] } { - return ( - !!message && - typeof message === "object" && - (message as { role?: unknown }).role === "tool" && - Array.isArray((message as { content?: unknown }).content) - ); -} - -function isScreenshotPart(part: unknown): boolean { - return ( - !!part && - typeof part === "object" && - (part as { toolName?: unknown }).toolName === "screenshot" - ); -} - -function isAriaTreePart(part: unknown): boolean { - return ( - !!part && - typeof part === "object" && - (part as { toolName?: unknown }).toolName === "ariaTree" - ); -} - -export function processMessages(params: LanguageModelV1CallOptions): { - processedPrompt: LanguageModelV1CallOptions["prompt"]; - stats: CompressionStats; -} { - // Calculate original content size - const originalContentSize = JSON.stringify(params.prompt).length; - const screenshotIndices = findToolIndices(params.prompt, "screenshot"); - const ariaTreeIndices = findToolIndices(params.prompt, "ariaTree"); - - // Process messages and compress old screenshots - const processedPrompt = params.prompt.map((message, index) => { - if (isToolMessage(message)) { - if ( - (message.content as unknown[]).some((part) => isScreenshotPart(part)) - ) { - const shouldCompress = shouldCompressScreenshot( - index, - screenshotIndices, - ); - if (shouldCompress) { - return compressScreenshotMessage(message); - } - } - if ((message.content as unknown[]).some((part) => isAriaTreePart(part))) { - const shouldCompress = shouldCompressAriaTree(index, ariaTreeIndices); - if (shouldCompress) { - return compressAriaTreeMessage(message); - } - } - } - - return { ...message }; - }); - - const compressedContentSize = JSON.stringify(processedPrompt).length; - const stats = calculateCompressionStats( - originalContentSize, - compressedContentSize, - screenshotIndices.length, - ariaTreeIndices.length, - ); - - return { - processedPrompt: - processedPrompt as unknown as LanguageModelV1CallOptions["prompt"], - stats, - }; -} - -function findToolIndices( - prompt: unknown[], - toolName: "screenshot" | "ariaTree", -): number[] { - const screenshotIndices: number[] = []; - - prompt.forEach((message, index) => { - if (isToolMessage(message)) { - const hasMatch = (message.content as unknown[]).some((part) => - toolName === "screenshot" - ? isScreenshotPart(part) - : isAriaTreePart(part), - ); - if (hasMatch) { - screenshotIndices.push(index); - } - } - }); - - return screenshotIndices; -} - -function shouldCompressScreenshot( - index: number, - screenshotIndices: number[], -): boolean { - const isNewestScreenshot = index === Math.max(...screenshotIndices); - const isSecondNewestScreenshot = - screenshotIndices.length > 1 && - index === screenshotIndices.sort((a, b) => b - a)[1]; - - return !isNewestScreenshot && !isSecondNewestScreenshot; -} - -function shouldCompressAriaTree( - index: number, - ariaTreeIndices: number[], -): boolean { - const isNewestAriaTree = index === Math.max(...ariaTreeIndices); - // Only keep the most recent ARIA tree - return !isNewestAriaTree; -} - -function compressScreenshotMessage(message: { - role: "tool"; - content: unknown[]; -}): { role: "tool"; content: unknown[] } { - const updatedContent = (message.content as unknown[]).map((part) => { - if (isScreenshotPart(part)) { - return { - ...(part as object), - result: [ - { - type: "text", - text: "screenshot taken", - }, - ], - } as unknown; - } - return part; - }); - - return { - ...message, - content: updatedContent, - } as { role: "tool"; content: unknown[] }; -} - -function compressAriaTreeMessage(message: { - role: "tool"; - content: unknown[]; -}): { role: "tool"; content: unknown[] } { - const updatedContent = (message.content as unknown[]).map((part) => { - if (isAriaTreePart(part)) { - return { - ...(part as object), - result: [ - { - type: "text", - text: "ARIA tree extracted for context of page elements", - }, - ], - } as unknown; - } - return part; - }); - - return { - ...message, - content: updatedContent, - } as { role: "tool"; content: unknown[] }; -} - -function calculateCompressionStats( - originalSize: number, - compressedSize: number, - screenshotCount: number, - ariaTreeCount: number, -): CompressionStats { - const savedChars = originalSize - compressedSize; - const compressionRatio = - originalSize > 0 - ? ((originalSize - compressedSize) / originalSize) * 100 - : 0; - - return { - originalSize, - compressedSize, - savedChars, - compressionRatio, - screenshotCount, - ariaTreeCount, - }; -} diff --git a/lib/agent/utils/modelWrapper.ts b/lib/agent/utils/modelWrapper.ts new file mode 100644 index 000000000..21cae8931 --- /dev/null +++ b/lib/agent/utils/modelWrapper.ts @@ -0,0 +1,25 @@ +import { wrapLanguageModel } from "ai"; +import type { LanguageModel } from "ai"; +import type { LLMClient } from "../../llm/LLMClient"; +import { ContextManager } from "../contextManager"; + +export function modelWrapper( + llmClient: LLMClient, + contextManager: ContextManager, + sessionId: string, +): LanguageModel { + const baseModel: LanguageModel = llmClient.getLanguageModel(); + return wrapLanguageModel({ + model: baseModel, + middleware: { + transformParams: async ({ params }) => { + const processedPrompt = await contextManager.processMessages( + params.prompt, + sessionId, + llmClient, + ); + return { ...params, prompt: processedPrompt }; + }, + }, + }); +} diff --git a/lib/agent/utils/processStepFinish.ts b/lib/agent/utils/processStepFinish.ts new file mode 100644 index 000000000..cb41874c2 --- /dev/null +++ b/lib/agent/utils/processStepFinish.ts @@ -0,0 +1,97 @@ +import { AgentAction } from "@/types/agent"; +import { LogLine } from "@/types/log"; +import type { AgentToolCall, AgentToolResult } from "../tools"; +import { mapToolResultToActions } from "./actionHandler"; + +export interface StepFinishEventLike { + finishReason?: string; + text: string; + toolCalls?: Array<{ + toolName: string; + args: unknown; + }>; + toolResults?: Array; +} + +export interface ProcessedStepResult { + actionsAppended: AgentAction[]; + collectedReasoning?: string; + completed: boolean; + finalMessage?: string; +} + +export function processStepFinishEvent( + event: StepFinishEventLike, + logger: (message: LogLine) => void, + priorReasoning: string[], +): ProcessedStepResult { + const actions: AgentAction[] = []; + let completed = false; + let finalMessage: string | undefined; + + logger({ + category: "agent", + message: `Step finished: ${event.finishReason}`, + level: 2, + }); + + if (event.toolCalls && event.toolCalls.length > 0) { + for (let i = 0; i < event.toolCalls.length; i++) { + const toolCall = event.toolCalls[i]; + const typedToolCall = toolCall as AgentToolCall; + + logger({ + category: "agent", + message: `tool call: ${typedToolCall.toolName} with args: ${JSON.stringify(typedToolCall.args)}`, + level: 1, + }); + if (event.text.length > 0) { + priorReasoning.push(event.text); + logger({ + category: "agent", + message: `reasoning: ${event.text}`, + level: 1, + }); + } + if (typedToolCall.toolName === "close") { + completed = true; + if (typedToolCall.args?.taskComplete) { + const closeReasoning = typedToolCall.args.reasoning; + const allReasoning = priorReasoning.join(" "); + finalMessage = closeReasoning + ? `${allReasoning} ${closeReasoning}`.trim() + : allReasoning || "Task completed successfully"; + } + } + + const toolResult = event.toolResults?.[i] || null; + + const mapped = mapToolResultToActions({ + toolCallName: typedToolCall.toolName, + toolResult, + args: typedToolCall.args || {}, + reasoning: event.text || undefined, + }); + + actions.push(...mapped); + } + } + + return { + actionsAppended: actions, + collectedReasoning: priorReasoning.join(" "), + completed, + finalMessage, + }; +} + +export function finalizeAgentMessage( + currentFinalMessage: string | undefined, + priorReasoning: string[], + resultText?: string, +): string { + const existing = (currentFinalMessage || "").trim(); + if (existing.length > 0) return existing; + const allReasoning = priorReasoning.join(" ").trim(); + return allReasoning || resultText || "Task completed successfully"; +} diff --git a/lib/handlers/stagehandAgentHandler.ts b/lib/handlers/stagehandAgentHandler.ts index cbb309683..1645e897d 100644 --- a/lib/handlers/stagehandAgentHandler.ts +++ b/lib/handlers/stagehandAgentHandler.ts @@ -1,17 +1,19 @@ -import { - AgentAction, - AgentExecuteOptions, - AgentResult, - ActToolResult, -} from "@/types/agent"; +import { AgentAction, AgentExecuteOptions, AgentResult } from "@/types/agent"; import { LogLine } from "@/types/log"; import { LLMClient } from "../llm/LLMClient"; -import { CoreMessage, wrapLanguageModel } from "ai"; -import { LanguageModel } from "ai"; -import { processMessages } from "../agent/utils/messageProcessing"; -import { createAgentTools } from "../agent/tools"; +import { CoreMessage } from "ai"; +import { createAgentTools, type AgentTools } from "../agent/tools"; +import { buildStagehandAgentSystemPrompt } from "../prompt"; +import { + finalizeAgentMessage, + processStepFinishEvent, +} from "../agent/utils/processStepFinish"; import { ToolSet } from "ai"; +import { ContextManager } from "../agent/contextManager"; +import { modelWrapper } from "../agent/utils/modelWrapper"; +import { randomUUID } from "crypto"; import { Stagehand } from "../index"; +import { ScreenshotCollector } from "../../evals/utils/ScreenshotCollector"; export class StagehandAgentHandler { private stagehand: Stagehand; @@ -20,8 +22,8 @@ export class StagehandAgentHandler { private executionModel?: string; private systemInstructions?: string; private tools?: ToolSet; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - private screenshotCollector?: any; + private contextManager: ContextManager; + private screenshotCollector?: ScreenshotCollector; constructor( stagehand: Stagehand, @@ -37,30 +39,36 @@ export class StagehandAgentHandler { this.executionModel = executionModel; this.systemInstructions = systemInstructions; this.tools = tools; + this.contextManager = new ContextManager(logger); } public async execute( instructionOrOptions: string | AgentExecuteOptions, ): Promise { const startTime = Date.now(); + const sessionId = randomUUID(); const options = typeof instructionOrOptions === "string" ? { instruction: instructionOrOptions } : instructionOrOptions; const maxSteps = options.maxSteps || 10; + const storeActions = options.storeActions ?? true; const actions: AgentAction[] = []; let finalMessage = ""; let completed = false; const collectedReasoning: string[] = []; try { - const systemPrompt = this.buildSystemPrompt( + const systemPrompt = buildStagehandAgentSystemPrompt( + this.stagehand.page.url(), + this.llmClient?.modelName, options.instruction, this.systemInstructions, + storeActions, ); - const defaultTools = this.createTools(); - const allTools = { ...defaultTools, ...this.tools }; + const tools = this.createTools(storeActions); + const allTools: ToolSet = { ...tools, ...this.tools }; const messages: CoreMessage[] = [ { role: "user", @@ -79,16 +87,11 @@ export class StagehandAgentHandler { "StagehandAgentHandler requires an AISDK-backed LLM client. Ensure your model is configured like 'openai/gpt-4.1-mini' in the provider/model format.", ); } - const baseModel: LanguageModel = this.llmClient.getLanguageModel(); - const wrappedModel = wrapLanguageModel({ - model: baseModel, - middleware: { - transformParams: async ({ params }) => { - const { processedPrompt } = processMessages(params); - return { ...params, prompt: processedPrompt }; - }, - }, - }); + const wrappedModel = modelWrapper( + this.llmClient, + this.contextManager, + sessionId, + ); const result = await this.llmClient.generateText({ model: wrappedModel, @@ -99,77 +102,28 @@ export class StagehandAgentHandler { temperature: 1, toolChoice: "auto", onStepFinish: async (event) => { - this.logger({ - category: "agent", - message: `Step finished: ${event.finishReason}`, - level: 2, - }); - - if (event.toolCalls && event.toolCalls.length > 0) { - for (let i = 0; i < event.toolCalls.length; i++) { - const toolCall = event.toolCalls[i]; - const args = toolCall.args as Record; - - if (event.text.length > 0) { - collectedReasoning.push(event.text); - this.logger({ - category: "agent", - message: `reasoning: ${event.text}`, - level: 1, - }); - } - - if (toolCall.toolName === "close") { - completed = true; - if (args?.taskComplete) { - const closeReasoning = args.reasoning as string; - const allReasoning = collectedReasoning.join(" "); - finalMessage = closeReasoning - ? `${allReasoning} ${closeReasoning}`.trim() - : allReasoning || "Task completed successfully"; - } - } - - // Get the tool result if available - const toolResult = event.toolResults?.[i]; - - const getPlaywrightArguments = () => { - if (toolCall.toolName !== "act" || !toolResult) { - return {}; - } - const result = toolResult.result as ActToolResult; - if (result && result.playwrightArguments) { - return { playwrightArguments: result.playwrightArguments }; - } - - return {}; - }; - - const action: AgentAction = { - type: toolCall.toolName, - reasoning: event.text || undefined, - taskCompleted: - toolCall.toolName === "close" - ? (args?.taskComplete as boolean) - : false, - ...args, - ...getPlaywrightArguments(), - }; - - actions.push(action); - } - } + const processed = processStepFinishEvent( + event, + this.logger, + collectedReasoning, + ); + actions.push(...processed.actionsAppended); + if (processed.completed) completed = true; + if (processed.finalMessage) finalMessage = processed.finalMessage; }, }); - if (!finalMessage) { - const allReasoning = collectedReasoning.join(" ").trim(); - finalMessage = allReasoning || result.text; - } + finalMessage = finalizeAgentMessage( + finalMessage, + collectedReasoning, + result.text, + ); const endTime = Date.now(); const inferenceTimeMs = endTime - startTime; + this.contextManager.clearSession(sessionId); + return { success: completed, message: finalMessage || "Task execution completed", @@ -184,6 +138,7 @@ export class StagehandAgentHandler { : undefined, }; } catch (error) { + this.contextManager.clearSession(sessionId); const errorMessage = error instanceof Error ? error.message : String(error); this.logger({ @@ -201,64 +156,25 @@ export class StagehandAgentHandler { } } - // in the future if we continue to describe tools in system prompt, we need to make sure to update them in here when new tools are added or removed. still tbd on whether we want to keep them in here long term. - private buildSystemPrompt( - executionInstruction: string, - systemInstructions?: string, - ): string { - if (systemInstructions) { - return `${systemInstructions} -Your current goal: ${executionInstruction}`; - } - - return `You are a web automation assistant using browser automation tools to accomplish the user's goal. - -Your task: ${executionInstruction} - -You have access to various browser automation tools. Use them step by step to complete the task. - -IMPORTANT GUIDELINES: -1. Always start by understanding the current page state -2. Use the screenshot tool to verify page state when needed -3. Use appropriate tools for each action -4. When the task is complete, use the "close" tool with taskComplete: true -5. If the task cannot be completed, use "close" with taskComplete: false - -TOOLS OVERVIEW: -- screenshot: Take a compressed JPEG screenshot for quick visual context (use sparingly) -- ariaTree: Get an accessibility (ARIA) hybrid tree for full page context (preferred for understanding layout and elements) -- act: Perform a specific atomic action (click, type, etc.) -- extract: Extract structured data -- goto: Navigate to a URL -- wait/navback/refresh: Control timing and navigation -- scroll: Scroll the page x pixels up or down - -STRATEGY: -- Prefer ariaTree to understand the page before acting; use screenshot for quick confirmation. -- Keep actions atomic and verify outcomes before proceeding. - -For each action, provide clear reasoning about why you're taking that step.`; - } - - private createTools() { + private createTools(storeActions: boolean): AgentTools { return createAgentTools(this.stagehand, { executionModel: this.executionModel, + mainModel: this.llmClient?.modelName || undefined, logger: this.logger, + storeActions, }); } /** * Set the screenshot collector for this agent handler */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - setScreenshotCollector(collector: any): void { + setScreenshotCollector(collector: ScreenshotCollector): void { this.screenshotCollector = collector; } /** * Get the screenshot collector */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - getScreenshotCollector(): any { + getScreenshotCollector(): ScreenshotCollector | undefined { return this.screenshotCollector; } setTools(tools: ToolSet): void { diff --git a/lib/prompt.ts b/lib/prompt.ts index ba5fb7112..309f5445c 100644 --- a/lib/prompt.ts +++ b/lib/prompt.ts @@ -175,37 +175,206 @@ export function buildActObservePrompt( return instruction; } -export function buildOperatorSystemPrompt(goal: string): ChatMessage { - return { - role: "system", - content: `You are a general-purpose agent whose job is to accomplish the user's goal across multiple model calls by running actions on the page. - -You will be given a goal and a list of steps that have been taken so far. Your job is to determine if either the user's goal has been completed or if there are still steps that need to be taken. - -# Your current goal -${goal} - -# CRITICAL: You MUST use the provided tools to take actions. Do not just describe what you want to do - actually call the appropriate tools. - -# Available tools and when to use them: -- \`act\`: Use this to interact with the page (click, type, navigate, etc.) -- \`extract\`: Use this to get information from the page -- \`goto\`: Use this to navigate to a specific URL -- \`wait\`: Use this to wait for a period of time -- \`navback\`: Use this to go back to the previous page -- \`refresh\`: Use this to refresh the current page -- \`close\`: Use this ONLY when the task is complete or cannot be achieved -- External tools: Use any additional tools (like search tools) as needed for your goal - -# Important guidelines -1. ALWAYS use tools - never just provide text responses about what you plan to do -2. Break down complex actions into individual atomic steps -3. For \`act\` commands, use only one action at a time, such as: - - Single click on a specific element - - Type into a single input field - - Select a single option -4. Avoid combining multiple actions in one instruction -5. If multiple actions are needed, they should be separate steps -6. Only use \`close\` when the task is genuinely complete or impossible to achieve`, - }; +export function buildStagehandAgentSystemPrompt( + url: string, + modelName: string, + executionInstruction: string, + systemInstructions?: string, + storeActions: boolean = true, +): string { + const localeDate = new Date().toLocaleDateString(); + const isoDate = new Date().toISOString(); + const cdata = (text: string) => ``; + + const normalizedModel = (modelName || "").toLowerCase().trim(); + const isAnthropic = + normalizedModel.startsWith("claude") && storeActions === false; + + const useAnthropicCustomizations = isAnthropic === false; + + const hasSearch = process.env.EXA_API_KEY?.length > 0; + + const searchToolLine = hasSearch + ? `\n Perform a web search and return results. Prefer this over navigating to Google and searching within the page for reliability and efficiency.` + : ""; + + const toolsSection = useAnthropicCustomizations + ? ` + Take a compressed JPEG screenshot for quick visual context + Get an accessibility (ARIA) hybrid tree for full page context + Click on an element (PREFERRED - more reliable when element is visible in viewport) + Type text into an element (PREFERRED - more reliable when element is visible in viewport) + Perform a specific atomic action (click, type, etc.) - ONLY use when element is in ariaTree but NOT visible in screenshot. Less reliable but can interact with out-of-viewport elements. + Drag and drop an element + Press a keyboard key + Fill out a form + Think about the task + Extract structured data + Navigate to a URL + Control timing and navigation + Scroll the page x pixels up or down + ${searchToolLine} + ` + : ` + Take a compressed JPEG screenshot for quick visual context + Get an accessibility (ARIA) hybrid tree for full page context + Perform a specific atomic action (click, type) + Press a keyboard key + Fill out a form + Think about the task + Extract structured data + Navigate to a URL + Control timing and navigation + Scroll the page x pixels up or down + ${searchToolLine} + `; + + // Build strategy items based on whether model is Anthropic or not + const strategyItems = isAnthropic + ? [ + `Tool selection priority: Use specific tools (click, type) when elements are visible in viewport for maximum reliability.`, + `Always use screenshot to get proper grounding of the coordinates you want to type/click into.`, + `When interacting with an input, always use the type tool to type into the input, over clicking and then typing into it.`, + `Use ariaTree as a secondary tool when elements aren't visible in screenshot or to get full page context.`, + `Only use act when element is in ariaTree but NOT visible in screenshot.`, + ] + : [ + `Tool selection priority: Use act tool for all clicking and typing on a page.`, + `Always check ariaTree first to understand full page content without scrolling - it shows all elements including those below the fold.`, + `When interacting with an input, always use the act tool to type into the input, over clicking and then typing.`, + `If an element is present in the ariaTree, use act to interact with it directly - this eliminates the need to scroll.`, + `Use screenshot for visual confirmation when needed, but rely primarily on ariaTree for element detection.`, + ]; + + const strategySection = strategyItems.join("\n "); + + const commonStrategyItems = ` + CRITICAL: Use extract ONLY when the task explicitly requires structured data output (e.g., "get job listings", "extract product details"). For reading page content or understanding elements, always use ${isAnthropic ? "screenshot or ariaTree" : "ariaTree"} instead - it's faster and more reliable. + Keep actions atomic and verify outcomes before proceeding. + For each action, provide clear reasoning about why you're taking that step. + When you need to input text that could be entered character-by-character or through multiple separate inputs, prefer using the keys tool to type the entire sequence at once. This is more efficient for scenarios like verification codes split across multiple fields, or when virtual keyboards are present but direct typing would be faster. + `; + + const pageUnderstandingProtocol = isAnthropic + ? ` + + UNDERSTAND THE PAGE + + screenshot + Visual confirmation when needed. Ideally after navigating to a new page. + + + ariaTree + Get complete page context before taking actions + Eliminates the need to scroll and provides full accessible content + + + ` + : ` + + UNDERSTAND THE PAGE + + ariaTree + Get complete page context before taking actions + Eliminates the need to scroll and provides full accessible content + + + screenshot + Visual confirmation when needed. Ideally after navigating to a new page. + + + `; + + if (systemInstructions) { + return ` + You are a web automation assistant using browser automation tools to accomplish the user's goal. + ${cdata(systemInstructions)} + + ${cdata(executionInstruction)} + ${localeDate} + You may think the date is different due to knowledge cutoff, but this is the actual date. + + + you are starting your taskon this url:${url} + + + Be very intentional about your action. The initial instruction is very important, and slight variations of the actual goal can lead to failures. + If something fails to meet a single condition of the task, move on from it rather than seeing if it meets other criteria. We only care that it meets all of it + When the task is complete, do not seek more information; you have completed the task. + + + Always start by understanding the current page state + Use the screenshot tool to verify page state when needed + Use appropriate tools for each action + When the task is complete, use the "close" tool with taskComplete: true + If the task cannot be completed, use "close" with taskComplete: false + + ${pageUnderstandingProtocol} + + If you are confident in the URL, navigate directly to it. + ${hasSearch ? `If you are not confident in the URL, use the search tool to find it.` : ``} + + ${toolsSection} + + ${strategySection} + ${commonStrategyItems} + + + captchas, popups, etc. + If you see a captcha, use the wait tool. It will automatically be solved by our internal solver. + + + When you complete the task, explain any information that was found that was relevant to the original task. + + If you were asked for specific flights, list the flights you found. + If you were asked for information about a product, list the product information you were asked for. + + +`; + } + + return ` + You are a web automation assistant using browser automation tools to accomplish the user's goal. + + ${cdata(executionInstruction)} + ${localeDate} + You may think the date is different due to knowledge cutoff, but this is the actual date. + + + you are starting your taskon this url:${url} + + + Be very intentional about your action. The initial instruction is very important, and slight variations of the actual goal can lead to failures. + If something fails to meet a single condition of the task, move on from it rather than seeing if it meets other criteria. We only care that it meets all of it + When the task is complete, do not seek more information; you have completed the task. + + + Always start by understanding the current page state + Use the screenshot tool to verify page state when needed + Use appropriate tools for each action + When the task is complete, use the "close" tool with taskComplete: true + If the task cannot be completed, use "close" with taskComplete: false + + ${pageUnderstandingProtocol} + + If you are confident in the URL, navigate directly to it. + ${hasSearch ? `If you are not confident in the URL, use the search tool to find it.` : ``} + + ${toolsSection} + + ${strategySection} + ${commonStrategyItems} + + + captchas, popups, etc. + If you see a captcha, use the wait tool. It will automatically be solved by our internal solver. + + + When you complete the task, explain any information that was found that was relevant to the original task. + + If you were asked for specific flights, list the flights you found. + If you were asked for information about a product, list the product information you were asked for. + + +`; } diff --git a/package.json b/package.json index 226b708a3..f78094320 100644 --- a/package.json +++ b/package.json @@ -82,6 +82,7 @@ "devtools-protocol": "^0.0.1464554", "fetch-cookie": "^3.1.0", "openai": "^4.87.1", + "exa-js": "^1.9.3", "pino": "^9.6.0", "pino-pretty": "^13.0.0", "playwright": "^1.52.0", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index cbd994598..7d8647383 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -32,6 +32,9 @@ importers: dotenv: specifier: ^16.4.5 version: 16.6.1 + exa-js: + specifier: ^1.9.3 + version: 1.9.3(ws@8.18.3) fetch-cookie: specifier: ^3.1.0 version: 3.1.0 @@ -2692,6 +2695,9 @@ packages: typescript: optional: true + cross-fetch@4.1.0: + resolution: {integrity: sha512-uKm5PU+MHTootlWEY+mZ4vvXoCn4fLQxT9dSc1sXVMSFkINTJVN8cAQROpwcKm8bJ/c7rgZVIBWzH5T78sNZZw==} + cross-spawn@7.0.6: resolution: {integrity: sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==} engines: {node: '>= 8'} @@ -2882,6 +2888,10 @@ packages: domutils@3.2.2: resolution: {integrity: sha512-6kZKyUajlDuqlHKVX1w7gyslj9MPIXzIFiz/rGu35uC1wMi+kMhQwGhl4lt9unC9Vb9INnY9Z3/ZA3+FhASLaw==} + dotenv@16.4.7: + resolution: {integrity: sha512-47qPchRCykZC03FhkYAhrvwU4xDBFIj1QPqaarj6mdM/hgUzfPHcpkHJOn3mJAufFeeAxAzeGsr5X0M4k6fLZQ==} + engines: {node: '>=12'} + dotenv@16.6.1: resolution: {integrity: sha512-uBq4egWHTcTt33a72vpSG0z3HnPuIl6NqYcTrKEg2azoEyl2hpW0zqlxysq2pK9HlDIHyHyakeYaYnSAwd8bow==} engines: {node: '>=12'} @@ -3133,6 +3143,9 @@ packages: resolution: {integrity: sha512-CRT1WTyuQoD771GW56XEZFQ/ZoSfWid1alKGDYMmkt2yl8UXrVR4pspqWNEcqKvVIzg6PAltWjxcSSPrboA4iA==} engines: {node: '>=18.0.0'} + exa-js@1.9.3: + resolution: {integrity: sha512-4u8vO5KHstifBz6fcwcBVvU62zfwsWFpD8qomU2zQ+lLRYCwOh2Rz04xSSqEeoHrkCypGjy2VHez7elBt6ibQQ==} + express-rate-limit@7.5.1: resolution: {integrity: sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==} engines: {node: '>= 16'} @@ -4532,6 +4545,18 @@ packages: resolution: {integrity: sha512-ey2CXh1OTcTUa0AWZWuTpgA9t5GuAG3DVU1MofCRUI7fQJij8XJ3Sr0VtgxoAE69C9wbHBMCux8Z/IQZfSwHiA==} hasBin: true + openai@5.23.0: + resolution: {integrity: sha512-Cfq155NHzI7VWR67LUNJMIgPZy2oSh7Fld/OKhxq648BiUjELAvcge7g30xJ6vAfwwXf6TVK0KKuN+3nmIJG/A==} + hasBin: true + peerDependencies: + ws: ^8.18.0 + zod: ^3.23.8 + peerDependenciesMeta: + ws: + optional: true + zod: + optional: true + openapi-types@12.1.3: resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==} @@ -7697,7 +7722,7 @@ snapshots: '@stoplight/json-ref-readers@1.2.2': dependencies: - node-fetch: 2.6.7 + node-fetch: 2.7.0 tslib: 1.14.1 transitivePeerDependencies: - encoding @@ -8763,6 +8788,12 @@ snapshots: optionalDependencies: typescript: 5.9.2 + cross-fetch@4.1.0: + dependencies: + node-fetch: 2.7.0 + transitivePeerDependencies: + - encoding + cross-spawn@7.0.6: dependencies: path-key: 3.1.1 @@ -8929,6 +8960,8 @@ snapshots: domelementtype: 2.3.0 domhandler: 5.0.3 + dotenv@16.4.7: {} + dotenv@16.6.1: {} dotenv@8.6.0: {} @@ -9330,6 +9363,17 @@ snapshots: dependencies: eventsource-parser: 3.0.6 + exa-js@1.9.3(ws@8.18.3): + dependencies: + cross-fetch: 4.1.0 + dotenv: 16.4.7 + openai: 5.23.0(ws@8.18.3)(zod@3.25.76) + zod: 3.25.76 + zod-to-json-schema: 3.24.6(zod@3.25.76) + transitivePeerDependencies: + - encoding + - ws + express-rate-limit@7.5.1(express@5.1.0): dependencies: express: 5.1.0 @@ -11239,6 +11283,11 @@ snapshots: transitivePeerDependencies: - encoding + openai@5.23.0(ws@8.18.3)(zod@3.25.76): + optionalDependencies: + ws: 8.18.3 + zod: 3.25.76 + openapi-types@12.1.3: {} openapi3-ts@4.5.0: diff --git a/types/agent.ts b/types/agent.ts index 9ab4cba0e..bb31e25dc 100644 --- a/types/agent.ts +++ b/types/agent.ts @@ -1,14 +1,6 @@ import { LogLine } from "./log"; import { ObserveResult } from "./stagehand"; - -export interface ActToolResult { - success: boolean; - action?: string; - error?: string; - isIframe?: boolean; - playwrightArguments?: ObserveResult | null; -} - +import { ScreenshotCollector } from "../evals/utils/ScreenshotCollector"; export interface AgentAction { type: string; reasoning?: string; @@ -19,7 +11,7 @@ export interface AgentAction { pageText?: string; // ariaTree tool pageUrl?: string; // ariaTree tool instruction?: string; // various tools - playwrightArguments?: ObserveResult | null; // act tool + playwrightArguments?: ObserveResult; [key: string]: unknown; } @@ -41,6 +33,12 @@ export interface AgentOptions { autoScreenshot?: boolean; waitBetweenActions?: number; context?: string; + /** + * When true (default), the agent will store actions and avoid using + * Claude-specific custom tools/prompts. Set to false to enable + * Claude-optimized toolset and prompts. + */ + storeActions?: boolean; } export interface AgentExecuteOptions extends AgentOptions { @@ -168,5 +166,5 @@ export interface AgentInstance { execute: ( instructionOrOptions: string | AgentExecuteOptions, ) => Promise; - setScreenshotCollector?: (collector: unknown) => void; + setScreenshotCollector?: (collector: ScreenshotCollector) => void; }