|
| 1 | +from typing import Optional, Union |
| 2 | + |
| 3 | +from ..handlers.cua_handler import CUAHandler |
| 4 | +from ..types.agent import ( |
| 5 | + AgentConfig, |
| 6 | + AgentExecuteOptions, |
| 7 | + AgentResult, |
| 8 | + AgentUsage, |
| 9 | +) |
| 10 | +from .anthropic_cua import AnthropicCUAClient |
| 11 | +from .client import AgentClient |
| 12 | +from .openai_cua import OpenAICUAClient |
| 13 | + |
| 14 | +MODEL_TO_CLIENT_CLASS_MAP: dict[str, type[AgentClient]] = { |
| 15 | + "computer-use-preview": OpenAICUAClient, |
| 16 | + "claude-3-5-sonnet-20240620": AnthropicCUAClient, |
| 17 | + "claude-3-7-sonnet-20250219": AnthropicCUAClient, |
| 18 | +} |
| 19 | + |
| 20 | +AGENT_METRIC_FUNCTION_NAME = "AGENT_EXECUTE_TASK" |
| 21 | + |
| 22 | + |
| 23 | +class Agent: |
| 24 | + |
| 25 | + def __init__(self, stagehand_client, **kwargs): |
| 26 | + self.stagehand = stagehand_client |
| 27 | + self.config = AgentConfig(**kwargs) if kwargs else AgentConfig() |
| 28 | + self.logger = self.stagehand.logger |
| 29 | + |
| 30 | + if not hasattr(self.stagehand, "page") or not hasattr( |
| 31 | + self.stagehand.page, "_page" |
| 32 | + ): |
| 33 | + self.logger.error( |
| 34 | + "Stagehand page object not available for CUAHandler initialization." |
| 35 | + ) |
| 36 | + raise ValueError("Stagehand page not initialized. Cannot create Agent.") |
| 37 | + |
| 38 | + self.cua_handler = CUAHandler( |
| 39 | + stagehand=self.stagehand, page=self.stagehand.page._page, logger=self.logger |
| 40 | + ) |
| 41 | + |
| 42 | + self.client: AgentClient = self._get_client() |
| 43 | + |
| 44 | + def _get_client(self) -> AgentClient: |
| 45 | + ClientClass = MODEL_TO_CLIENT_CLASS_MAP.get(self.config.model) # noqa: N806 |
| 46 | + if not ClientClass: |
| 47 | + self.logger.error( |
| 48 | + f"Unsupported model or client not mapped: {self.config.model}" |
| 49 | + ) |
| 50 | + raise ValueError( |
| 51 | + f"Unsupported model or client not mapped: {self.config.model}" |
| 52 | + ) |
| 53 | + |
| 54 | + return ClientClass( |
| 55 | + model=self.config.model, |
| 56 | + instructions=( |
| 57 | + self.config.instructions |
| 58 | + if self.config.instructions |
| 59 | + else "Your browser is in full screen mode. There is no search bar, or navigation bar, or shortcut to control it. You can use the goto tool to navigate to different urls. Do not try to access a top navigation bar or other browser features." |
| 60 | + ), |
| 61 | + config=self.config, |
| 62 | + logger=self.logger, |
| 63 | + handler=self.cua_handler, |
| 64 | + ) |
| 65 | + |
| 66 | + async def execute( |
| 67 | + self, options_or_instruction: Union[AgentExecuteOptions, str] |
| 68 | + ) -> AgentResult: |
| 69 | + |
| 70 | + options: Optional[AgentExecuteOptions] = None |
| 71 | + instruction: str |
| 72 | + |
| 73 | + if isinstance(options_or_instruction, str): |
| 74 | + instruction = options_or_instruction |
| 75 | + options = AgentExecuteOptions(instruction=instruction) |
| 76 | + elif isinstance(options_or_instruction, dict): |
| 77 | + options = AgentExecuteOptions(**options_or_instruction) |
| 78 | + instruction = options.instruction |
| 79 | + else: |
| 80 | + options = options_or_instruction |
| 81 | + instruction = options.instruction |
| 82 | + |
| 83 | + if not instruction: |
| 84 | + self.logger.error("No instruction provided for agent execution.") |
| 85 | + return AgentResult( |
| 86 | + message="No instruction provided.", completed=True, actions=[], usage={} |
| 87 | + ) |
| 88 | + |
| 89 | + self.logger.info( |
| 90 | + f"Agent starting execution for instruction: '{instruction}'", |
| 91 | + category="agent", |
| 92 | + ) |
| 93 | + |
| 94 | + try: |
| 95 | + agent_result = await self.client.run_task( |
| 96 | + instruction=instruction, |
| 97 | + max_steps=self.config.max_steps, |
| 98 | + options=options, |
| 99 | + ) |
| 100 | + except Exception as e: |
| 101 | + self.logger.error( |
| 102 | + f"Exception during client.run_task: {e}", category="agent" |
| 103 | + ) |
| 104 | + empty_usage = AgentUsage( |
| 105 | + input_tokens=0, output_tokens=0, inference_time_ms=0 |
| 106 | + ) |
| 107 | + return AgentResult( |
| 108 | + message=f"Error: {str(e)}", |
| 109 | + completed=True, |
| 110 | + actions=[], |
| 111 | + usage=empty_usage, |
| 112 | + ) |
| 113 | + |
| 114 | + # Update metrics if usage data is available in the result |
| 115 | + if agent_result.usage: |
| 116 | + # self.stagehand.update_metrics( |
| 117 | + # AGENT_METRIC_FUNCTION_NAME, |
| 118 | + # agent_result.usage.get("input_tokens", 0), |
| 119 | + # agent_result.usage.get("output_tokens", 0), |
| 120 | + # agent_result.usage.get("inference_time_ms", 0), |
| 121 | + # ) |
| 122 | + pass # Placeholder if metrics are to be handled differently or not at all |
| 123 | + |
| 124 | + self.logger.info( |
| 125 | + f"Agent execution finished. Success: {agent_result.completed}. Message: {agent_result.message}", |
| 126 | + category="agent", |
| 127 | + ) |
| 128 | + # To clean up pydantic model output |
| 129 | + actions_repr = [action.root for action in agent_result.actions] |
| 130 | + self.logger.debug( |
| 131 | + f"Agent actions: {actions_repr}", |
| 132 | + category="agent", |
| 133 | + ) |
| 134 | + agent_result.actions = actions_repr |
| 135 | + return agent_result |
0 commit comments