Skip to content
This repository was archived by the owner on Nov 10, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 44 additions & 52 deletions crewai_tools/tools/stagehand_tool/stagehand_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import logging
from typing import Dict, List, Optional, Type, Union, Any

from pydantic import BaseModel, Field

# Define a flag to track whether stagehand is available
Expand All @@ -26,16 +25,16 @@
ActOptions = Any
ExtractOptions = Any
ObserveOptions = Any

# Mock configure_logging function
def configure_logging(level=None, remove_logger_name=None, quiet_dependencies=None):
    """No-op placeholder for the logging configurator.

    Args:
        level: Accepted and ignored.
        remove_logger_name: Accepted and ignored.
        quiet_dependencies: Accepted and ignored.

    Returns:
        None. Nothing is configured.
    """
    # NOTE(review): presumably only used when the real stagehand package
    # could not be imported -- confirm against the surrounding import guard.
    return None

# Define only what's needed for class defaults
class AvailableModel:
    """Minimal placeholder that carries only what is needed for class defaults."""

    # Model identifier string exposed as a class-level constant.
    CLAUDE_3_7_SONNET_LATEST = "anthropic.claude-3-7-sonnet-20240607"

from crewai.tools import BaseTool
from crewai.tools import BaseTool, EnvVar


class StagehandCommandType(str):
Expand Down Expand Up @@ -78,10 +77,10 @@ class StagehandToolSchema(BaseModel):
)
command_type: Optional[str] = Field(
"act",
description="""The type of command to execute (choose one):
description="""The type of command to execute (choose one):
- 'act': Perform an action like clicking buttons, filling forms, etc. (default)
- 'navigate': Specifically navigate to a URL
- 'extract': Extract structured data from the page
- 'extract': Extract structured data from the page
- 'observe': Identify and analyze elements on the page
""",
)
Expand Down Expand Up @@ -136,15 +135,15 @@ class StagehandTool(BaseTool):

name: str = "Web Automation Tool"
description: str = """Use this tool to control a web browser and interact with websites using natural language.

Capabilities:
- Navigate to websites and follow links
- Click buttons, links, and other elements
- Fill in forms and input fields
- Search within websites
- Extract information from web pages
- Identify and analyze elements on a page

To use this tool, provide a natural language instruction describing what you want to do.
For different types of tasks, specify the command_type:
- 'act': For performing actions (default)
Expand Down Expand Up @@ -173,6 +172,13 @@ class StagehandTool(BaseTool):
_logger: Optional[logging.Logger] = None
_testing: bool = False

env_vars: List[EnvVar] = [
EnvVar(name="BROWSERBASE_API_KEY", description="API key for Browserbase", required=False),
EnvVar(name="BROWSERBASE_PROJECT_ID", description="Project ID for Browserbase", required=False),
EnvVar(name="OPENAI_API_KEY", description="Model API key for OpenAI", required=False),
EnvVar(name="ANTHROPIC_API_KEY", description="Model API key for Anthropic", required=False),
]

def __init__(
self,
api_key: Optional[str] = None,
Expand All @@ -197,8 +203,9 @@ def __init__(
self._logger = logging.getLogger(__name__)

# For backward compatibility
browserbase_api_key = kwargs.get("browserbase_api_key")
browserbase_project_id = kwargs.get("browserbase_project_id")
browserbase_api_key = kwargs.get("browserbase_api_key") or os.getenv("BROWSERBASE_API_KEY")
browserbase_project_id = kwargs.get("browserbase_project_id") or os.getenv("BROWSERBASE_PROJECT_ID")
model_api_key = model_api_key or os.getenv("OPENAI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")

if api_key:
self.api_key = api_key
Expand Down Expand Up @@ -251,7 +258,7 @@ def _check_required_credentials(self):
raise ImportError(
"`stagehand-py` package not found, please run `uv add stagehand-py`"
)

if not self.api_key:
raise ValueError("api_key is required (or set BROWSERBASE_API_KEY in env).")
if not self.project_id:
Expand All @@ -265,7 +272,7 @@ def _check_required_credentials(self):

async def _setup_stagehand(self, session_id: Optional[str] = None):
"""Initialize Stagehand if not already set up."""

# If we're in testing mode, return mock objects
if self._testing:
if not self._stagehand:
Expand All @@ -275,43 +282,42 @@ def act(self, options):
mock_result = type('MockResult', (), {})()
mock_result.model_dump = lambda: {"message": "Action completed successfully"}
return mock_result

def goto(self, url):
    """Mock navigation: ignore ``url`` and return ``None``."""
    return

def extract(self, options):
    """Mock extraction: return a stub whose ``model_dump()`` yields canned data.

    ``options`` is ignored. The returned object is an instance of a
    throwaway ``MockResult`` type with ``model_dump`` attached as an
    instance attribute.
    """
    stub = type('MockResult', (), {})()
    setattr(stub, "model_dump", lambda: {"data": "Extracted content"})
    return stub

def observe(self, options):
    """Mock observation: return a one-element list describing a fake element.

    ``options`` is ignored. The single result exposes ``description`` and
    ``method`` as class-level attributes of a throwaway ``MockResult`` type.
    """
    element_type = type('MockResult', (), {"description": "Test element", "method": "click"})
    return [element_type()]

class MockStagehand:
    """Test-mode stand-in exposing ``page``, ``session_id``, ``init`` and ``close``."""

    def __init__(self):
        # Fresh mock page plus a fixed session identifier.
        self.page = MockPage()
        self.session_id = "test-session-id"

    def init(self):
        """Synchronous no-op, so the caller need not await it."""
        return

    def close(self):
        """Synchronous no-op; there is nothing to release."""
        return

self._stagehand = MockStagehand()
# No need to await the init call in test mode
self._stagehand.init()
self._page = self._stagehand.page
self._session_id = self._stagehand.session_id

return self._stagehand, self._page

# Normal initialization for non-testing mode
if not self._stagehand:
self._logger.debug("Initializing Stagehand")
# Create model client options with the API key
model_client_options = {"apiKey": self.model_api_key}

# Build the StagehandConfig object
config = StagehandConfig(
Expand All @@ -323,13 +329,12 @@ def close(self):
model_name=self.model_name,
self_heal=self.self_heal,
wait_for_captcha_solves=self.wait_for_captcha_solves,
model_client_options=model_client_options,
verbose=self.verbose,
session_id=session_id or self._session_id,
browserbase_session_id=session_id or self._session_id,
)

# Initialize Stagehand with config and server_url
self._stagehand = Stagehand(config=config, server_url=self.server_url)
self._stagehand = Stagehand(config=config, server_url=self.server_url, model_api_key=self.model_api_key)

# Initialize the Stagehand instance
await self._stagehand.init()
Expand All @@ -355,7 +360,7 @@ async def _async_run(
# Return predefined mock results based on command type
if command_type.lower() == "act":
return StagehandResult(
success=True,
success=True,
data={"message": "Action completed successfully"}
)
elif command_type.lower() == "navigate":
Expand All @@ -368,7 +373,7 @@ async def _async_run(
)
elif command_type.lower() == "extract":
return StagehandResult(
success=True,
success=True,
data={"data": "Extracted content", "metadata": {"source": "test"}}
)
elif command_type.lower() == "observe":
Expand All @@ -380,11 +385,11 @@ async def _async_run(
)
else:
return StagehandResult(
success=False,
data={},
success=False,
data={},
error=f"Unknown command type: {command_type}"
)

# Normal execution for non-test mode
stagehand, page = await self._setup_stagehand(self._session_id)

Expand All @@ -394,15 +399,8 @@ async def _async_run(

# Process according to command type
if command_type.lower() == "act":
# Create act options
act_options = ActOptions(
action=instruction,
model_name=self.model_name,
dom_settle_timeout_ms=self.dom_settle_timeout_ms,
)

# Execute the act command
result = await page.act(act_options)
result = await page.act(instruction)
self._logger.info(f"Act operation completed: {result}")
return StagehandResult(success=True, data=result.model_dump())

Expand All @@ -427,31 +425,25 @@ async def _async_run(
)

elif command_type.lower() == "extract":
# Create extract options
extract_options = ExtractOptions(
# Execute the extract command
result = await page.extract(
instruction=instruction,
model_name=self.model_name,
dom_settle_timeout_ms=self.dom_settle_timeout_ms,
use_text_extract=True,
use_text_extract=True
)

# Execute the extract command
result = await page.extract(extract_options)
self._logger.info(f"Extract operation completed successfully {result}")
return StagehandResult(success=True, data=result.model_dump())

elif command_type.lower() == "observe":
# Create observe options
observe_options = ObserveOptions(
# Execute the observe command
results = await page.observe(
instruction=instruction,
model_name=self.model_name,
only_visible=True,
dom_settle_timeout_ms=self.dom_settle_timeout_ms,
)

# Execute the observe command
results = await page.observe(observe_options)

# Format the observation results
formatted_results = []
for i, result in enumerate(results):
Expand Down Expand Up @@ -551,7 +543,7 @@ async def _async_close(self):
self._stagehand = None
self._page = None
return

if self._stagehand:
await self._stagehand.close()
self._stagehand = None
Expand All @@ -565,7 +557,7 @@ def close(self):
self._stagehand = None
self._page = None
return

if self._stagehand:
try:
# Handle both synchronous and asynchronous cases
Expand All @@ -586,9 +578,9 @@ def close(self):
# Log but don't raise - we're cleaning up
if self._logger:
self._logger.error(f"Error closing Stagehand: {str(e)}")

self._stagehand = None

if self._page:
self._page = None

Expand Down
Loading