Skip to content

Frame ID map for multi-tab support (API) #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/vegan-intrepid-mustang.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"stagehand": patch
---

Added `frame_id_map` to support multi-tab handling in the API.
86 changes: 84 additions & 2 deletions stagehand/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def __init__(self, context: BrowserContext, stagehand):
# Use a weak key dictionary to map Playwright Pages to our StagehandPage wrappers
self.page_map = weakref.WeakKeyDictionary()
self.active_stagehand_page = None
# Map frame IDs to StagehandPage instances
self.frame_id_map = {}

async def new_page(self) -> StagehandPage:
pw_page: Page = await self._context.new_page()
Expand All @@ -23,9 +25,13 @@ async def new_page(self) -> StagehandPage:

async def create_stagehand_page(self, pw_page: Page) -> StagehandPage:
    """
    Wrap a Playwright page in a StagehandPage and register it with this context.

    The wrapper is built once, with this context passed as its third
    argument. A preceding context-less construction was discarded
    immediately and has been removed.

    Args:
        pw_page: The Playwright page to wrap.

    Returns:
        The StagehandPage stored in ``page_map`` under ``pw_page``.
    """
    stagehand_page = StagehandPage(pw_page, self.stagehand, self)
    await self.inject_custom_scripts(pw_page)
    self.page_map[pw_page] = stagehand_page

    # Initialize frame tracking for this page
    await self._attach_frame_navigated_listener(pw_page, stagehand_page)

    return stagehand_page

async def inject_custom_scripts(self, pw_page: Page):
Expand Down Expand Up @@ -69,9 +75,21 @@ def set_active_page(self, stagehand_page: StagehandPage):
def get_active_page(self) -> StagehandPage:
    """Return the page stored in ``active_stagehand_page`` (None until one is set)."""
    active = self.active_stagehand_page
    return active

def register_frame_id(self, frame_id: str, page: StagehandPage):
    """Associate ``frame_id`` with ``page`` in the frame-ID map, replacing any prior entry."""
    registry = self.frame_id_map
    registry[frame_id] = page

def unregister_frame_id(self, frame_id: str):
    """
    Remove ``frame_id`` from the frame-ID map.

    Unknown IDs are ignored, so calling this for a frame that was never
    registered (or was already removed) is a no-op.

    Args:
        frame_id: The frame ID whose mapping should be dropped.
    """
    # pop() with a default does the membership test and removal in one lookup.
    self.frame_id_map.pop(frame_id, None)

def get_stagehand_page_by_frame_id(self, frame_id: str) -> StagehandPage:
    """Look up the page registered under ``frame_id``; yields None when the ID is unknown."""
    page = self.frame_id_map.get(frame_id)
    return page

@classmethod
async def init(cls, context: BrowserContext, stagehand):
stagehand.logger.debug("StagehandContext.init() called", category="context")
instance = cls(context, stagehand)
# Pre-initialize StagehandPages for any existing pages
stagehand.logger.debug(
Expand Down Expand Up @@ -150,3 +168,67 @@ async def wrapped_pages():

return wrapped_pages
return attr

async def _attach_frame_navigated_listener(
    self, pw_page: Page, stagehand_page: StagehandPage
):
    """
    Keep the frame-ID map in sync with the root frame of ``pw_page``.

    Opens a CDP session for the page, sends ``Page.enable``, records the
    current root frame ID on ``stagehand_page`` and in the frame-ID map, and
    subscribes to ``Page.frameNavigated`` so later root-frame changes re-key
    the mapping. When the Playwright page emits ``close`` the page's current
    frame ID is unregistered.

    Any exception raised during setup is logged at error level and not
    re-raised.

    NOTE(review): the original author states this mirrors the TypeScript
    implementation's frame tracking; that cannot be confirmed from this file.

    Args:
        pw_page: The Playwright page to observe.
        stagehand_page: The wrapper whose root frame ID is kept up to date.
    """
    try:
        session = await self._context.new_cdp_session(pw_page)
        await session.send("Page.enable")

        # Seed the mapping with whatever root frame the page reports right now.
        tree = await session.send("Page.getFrameTree")
        root_frame = tree.get("frameTree", {}).get("frame", {})
        root_frame_id = root_frame.get("id")
        if root_frame_id:
            stagehand_page.update_root_frame_id(root_frame_id)
            self.register_frame_id(root_frame_id, stagehand_page)

        def on_frame_navigated(params):
            """Re-key the page in the map when its root frame ID changes."""
            frame = params.get("frame", {})
            frame_id = frame.get("id")

            # Frames carrying a parentId are not root frames, and an event
            # without an id gives nothing to track.
            if frame.get("parentId") or not frame_id:
                return

            old_id = stagehand_page.frame_id
            if frame_id == old_id:
                # Same root frame as before: nothing to re-key.
                return

            if old_id:
                self.unregister_frame_id(old_id)

            self.register_frame_id(frame_id, stagehand_page)
            stagehand_page.update_root_frame_id(frame_id)

            self.stagehand.logger.debug(
                f"Frame navigated from {old_id} to {frame_id}",
                category="context",
            )

        session.on("Page.frameNavigated", on_frame_navigated)

        def on_page_close():
            """Drop the page's current frame ID from the map once it closes."""
            closing_id = stagehand_page.frame_id
            if closing_id:
                self.unregister_frame_id(closing_id)

        pw_page.once("close", on_page_close)

    except Exception as e:
        self.stagehand.logger.error(
            f"Failed to attach frame navigation listener: {str(e)}",
            category="context",
        )
35 changes: 32 additions & 3 deletions stagehand/page.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import time
from typing import Optional, Union

from playwright.async_api import CDPSession, Page
Expand Down Expand Up @@ -26,16 +28,29 @@ class StagehandPage:

_cdp_client: Optional[CDPSession] = None

def __init__(self, page: Page, stagehand_client, context=None):
    """
    Initialize a StagehandPage instance.

    The signature stays compatible with two-argument callers because
    ``context`` defaults to None.

    Args:
        page (Page): The underlying Playwright page.
        stagehand_client: The client used to interface with the Stagehand server.
        context: The StagehandContext instance (optional).
    """
    self._page = page
    self._stagehand = stagehand_client
    self._context = context
    # Root frame ID; stays None until update_root_frame_id() is called.
    self._frame_id = None

@property
def frame_id(self) -> Optional[str]:
    """Return the root frame ID recorded for this page, or None if none has been set."""
    current_id = self._frame_id
    return current_id

def update_root_frame_id(self, new_id: str):
    """Store ``new_id`` as this page's root frame ID, then emit a debug log entry for it."""
    self._frame_id = new_id
    log = self._stagehand.logger
    log.debug(f"Updated frame ID to {new_id}", category="page")

# TODO: wrap the body of ensure_injection in try/except
async def ensure_injection(self):
Expand Down Expand Up @@ -98,6 +113,10 @@ async def goto(
if options:
payload["options"] = options

# Add frame ID if available
if self._frame_id:
payload["frameId"] = self._frame_id

lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("navigate", payload)
Expand Down Expand Up @@ -168,6 +187,10 @@ async def act(
result = await self._act_handler.act(payload)
return result

# Add frame ID if available
if self._frame_id:
payload["frameId"] = self._frame_id

lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("act", payload)
Expand Down Expand Up @@ -237,6 +260,10 @@ async def observe(

return result

# Add frame ID if available
if self._frame_id:
payload["frameId"] = self._frame_id

lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("observe", payload)
Expand Down Expand Up @@ -361,6 +388,10 @@ async def extract(
return result.data

# Use API
# Add frame ID if available
if self._frame_id:
payload["frameId"] = self._frame_id

lock = self._stagehand._get_lock_for_session()
async with lock:
result_dict = await self._stagehand._execute("extract", payload)
Expand Down Expand Up @@ -487,8 +518,6 @@ async def _wait_for_settled_dom(self, timeout_ms: int = None):
timeout_ms (int, optional): Maximum time to wait in milliseconds.
If None, uses the stagehand client's dom_settle_timeout_ms.
"""
import asyncio
import time

timeout = timeout_ms or getattr(self._stagehand, "dom_settle_timeout_ms", 30000)
client = await self.get_cdp_client()
Expand Down
Loading