Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 9 additions & 62 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'

- name: Install dependencies
run: |
Expand All @@ -52,60 +52,12 @@ jobs:
# run: |
# pytest

- name: Calculate new version
id: version
- name: Get project version
id: get_version
run: |
# Get current version from pyproject.toml
CURRENT_VERSION=$(python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])")
echo "current_version=$CURRENT_VERSION" >> $GITHUB_OUTPUT

# Parse version components
IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT_VERSION"

# Calculate new version based on release type
case "${{ github.event.inputs.release_type }}" in
"major")
NEW_MAJOR=$((MAJOR + 1))
NEW_MINOR=0
NEW_PATCH=0
;;
"minor")
NEW_MAJOR=$MAJOR
NEW_MINOR=$((MINOR + 1))
NEW_PATCH=0
;;
"patch")
NEW_MAJOR=$MAJOR
NEW_MINOR=$MINOR
NEW_PATCH=$((PATCH + 1))
;;
esac

NEW_VERSION="${NEW_MAJOR}.${NEW_MINOR}.${NEW_PATCH}"
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
echo "Bumping version from $CURRENT_VERSION to $NEW_VERSION"

- name: Update version files
run: |
CURRENT_VERSION="${{ steps.version.outputs.current_version }}"
NEW_VERSION="${{ steps.version.outputs.new_version }}"

# Update pyproject.toml
sed -i "s/version = \"$CURRENT_VERSION\"/version = \"$NEW_VERSION\"/" pyproject.toml

# Update __init__.py
sed -i "s/__version__ = \"$CURRENT_VERSION\"/__version__ = \"$NEW_VERSION\"/" stagehand/__init__.py

echo "Updated version to $NEW_VERSION in pyproject.toml and __init__.py"

- name: Commit version bump
run: |
git config --local user.email "[email protected]"
git config --local user.name "GitHub Action"
git add pyproject.toml stagehand/__init__.py
git commit -m "Bump version to ${{ steps.version.outputs.new_version }}"
git tag "v${{ steps.version.outputs.new_version }}"

VERSION=$(python -c "import tomli; print(tomli.load(open('pyproject.toml', 'rb'))['project']['version'])")
echo "version=$VERSION" >> $GITHUB_OUTPUT

- name: Build package
run: |
python -m build
Expand All @@ -117,15 +69,10 @@ jobs:
run: |
twine upload dist/*

- name: Push version bump
run: |
git push
git push --tags

- name: Create GitHub Release
if: ${{ github.event.inputs.create_release == 'true' }}
uses: softprops/action-gh-release@v1
with:
tag_name: v${{ steps.version.outputs.new_version }}
name: Release v${{ steps.version.outputs.new_version }}
tag_name: v${{ steps.get_version.outputs.version }}
name: Release v${{ steps.get_version.outputs.version }}
generate_release_notes: true
4 changes: 2 additions & 2 deletions stagehand/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, stagehand_client, **kwargs):
self.stagehand = stagehand_client
self.config = AgentConfig(**kwargs) if kwargs else AgentConfig()
self.logger = self.stagehand.logger
if self.stagehand.env == "BROWSERBASE":
if self.stagehand.use_api:
if self.config.model in MODEL_TO_PROVIDER_MAP:
self.provider = MODEL_TO_PROVIDER_MAP[self.config.model]
else:
Expand Down Expand Up @@ -120,7 +120,7 @@ async def execute(

instruction = options.instruction

if self.stagehand.env == "LOCAL":
if not self.stagehand.use_api:
self.logger.info(
f"Agent starting execution for instruction: '{instruction}'",
category="agent",
Expand Down
1 change: 0 additions & 1 deletion stagehand/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ async def _create_session(self):
},
}
),
"proxies": True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be a param?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just the default config if none is passed

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should go inside browserbase_session_create_params anyway

}

# Add the new parameters if they have values
Expand Down
28 changes: 24 additions & 4 deletions stagehand/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Optional

from browserbase import Browserbase
from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams
from playwright.async_api import (
Browser,
BrowserContext,
Expand Down Expand Up @@ -40,12 +41,31 @@ async def connect_browserbase_browser(
# Connect to remote browser via Browserbase SDK and CDP
bb = Browserbase(api_key=browserbase_api_key)
try:
session = bb.sessions.retrieve(session_id)
if session.status != "RUNNING":
raise RuntimeError(
f"Browserbase session {session_id} is not running (status: {session.status})"
if session_id:
session = bb.sessions.retrieve(session_id)
if session.status != "RUNNING":
raise RuntimeError(
f"Browserbase session {session_id} is not running (status: {session.status})"
)
else:
browserbase_session_create_params = (
BrowserbaseSessionCreateParams(
project_id=stagehand_instance.browserbase_project_id,
browser_settings={
"viewport": {
"width": 1024,
"height": 768,
},
},
)
if not stagehand_instance.browserbase_session_create_params
else stagehand_instance.browserbase_session_create_params
)
session = bb.sessions.create(**browserbase_session_create_params)
if not session.id:
raise Exception("Could not create Browserbase session")
connect_url = session.connectUrl
stagehand_instance.session_id = session.id
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move up

except Exception as e:
logger.error(f"Error retrieving or validating Browserbase session: {str(e)}")
raise
Expand Down
10 changes: 10 additions & 0 deletions stagehand/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ class StagehandConfig(BaseModel):
alias="localBrowserLaunchOptions",
description="Local browser launch options",
)
use_api: Optional[bool] = Field(
True,
alias="useAPI",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None

description="Whether to use the Stagehand API",
)
experimental: Optional[bool] = Field(
False,
alias="experimental",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

none

description="Whether to use experimental features",
)

model_config = ConfigDict(populate_by_name=True)

Expand Down
35 changes: 23 additions & 12 deletions stagehand/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,23 @@ def __init__(
self._playwright_page: Optional[PlaywrightPage] = None
self.page: Optional[StagehandPage] = None
self.context: Optional[StagehandContext] = None
self.use_api = self.config.use_api
self.experimental = self.config.experimental
if self.experimental:
self.use_api = False
if (
self.browserbase_session_create_params
and self.browserbase_session_create_params.get("region")
and self.browserbase_session_create_params.get("region") != "us-west-2"
):
self.use_api = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice


self._initialized = False # Flag to track if init() has run
self._closed = False # Flag to track if resources have been closed

# Setup LLM client if LOCAL mode
self.llm = None
if self.env == "LOCAL":
if not self.use_api:
self.llm = LLMClient(
stagehand_logger=self.logger,
api_key=self.model_api_key,
Expand Down Expand Up @@ -385,15 +395,16 @@ async def init(self):

if self.env == "BROWSERBASE":
# Create session if we don't have one
if not self.session_id:
await self._create_session() # Uses self._client and api_url
self.logger.debug(
f"Created new Browserbase session via Stagehand server: {self.session_id}"
)
else:
self.logger.debug(
f"Using existing Browserbase session: {self.session_id}"
)
if self.use_api:
if not self.session_id:
await self._create_session() # Uses self._client and api_url
self.logger.debug(
f"Created new Browserbase session via Stagehand server: {self.session_id}"
)
else:
self.logger.debug(
f"Using existing Browserbase session: {self.session_id}"
)

# Connect to remote browser
try:
Expand Down Expand Up @@ -470,8 +481,8 @@ async def close(self):

self.logger.debug("Closing resources...")

if self.env == "BROWSERBASE":
# --- BROWSERBASE Cleanup ---
if self.use_api:
# --- BROWSERBASE Cleanup (API) ---
# End the session on the server if we have a session ID
if self.session_id and self._client: # Check if client was initialized
try:
Expand Down
25 changes: 11 additions & 14 deletions stagehand/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def goto(
Returns:
The result from the Stagehand server's navigation execution.
"""
if self._stagehand.env == "LOCAL":
if not self._stagehand.use_api:
await self._page.goto(
url, referer=referer, timeout=timeout, wait_until=wait_until
)
Expand Down Expand Up @@ -142,7 +142,7 @@ async def act(
)

# TODO: Temporary until we move api based logic to client
if self._stagehand.env == "LOCAL":
if not self._stagehand.use_api:
# TODO: revisit passing user_provided_instructions
if not hasattr(self, "_observe_handler"):
# TODO: revisit handlers initialization on page creation
Expand Down Expand Up @@ -207,7 +207,7 @@ async def observe(
payload = options_obj.model_dump(exclude_none=True, by_alias=True)

# If in LOCAL mode, use local implementation
if self._stagehand.env == "LOCAL":
if not self._stagehand.use_api:
self._stagehand.logger.debug(
"observe", category="observe", auxiliary=payload
)
Expand Down Expand Up @@ -324,8 +324,7 @@ async def extract(
else:
schema_to_validate_with = DefaultExtractSchema

# If in LOCAL mode, use local implementation
if self._stagehand.env == "LOCAL":
if not self._stagehand.use_api:
# If we don't have an extract handler yet, create one
if not hasattr(self, "_extract_handler"):
self._extract_handler = ExtractHandler(
Expand Down Expand Up @@ -391,18 +390,16 @@ async def screenshot(self, options: Optional[dict] = None) -> str:
Returns:
str: Base64-encoded screenshot data.
"""
if self._stagehand.env == "LOCAL":
self._stagehand.logger.info(
"Local execution of screenshot is not implemented"
)
return None
payload = options or {}

lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("screenshot", payload)
if self._stagehand.use_api:
lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("screenshot", payload)

return result
return result
else:
return await self._page.screenshot(options)

# Method to get or initialize the persistent CDP client
async def get_cdp_client(self) -> CDPSession:
Expand Down
Loading