diff --git a/.changeset/vegan-intrepid-mustang.md b/.changeset/vegan-intrepid-mustang.md new file mode 100644 index 0000000..488cd76 --- /dev/null +++ b/.changeset/vegan-intrepid-mustang.md @@ -0,0 +1,5 @@ +--- +"stagehand": patch +--- + +Added frame_id_map to support multi-tab handling on API diff --git a/stagehand/context.py b/stagehand/context.py index bfb6f2f..aa99f07 100644 --- a/stagehand/context.py +++ b/stagehand/context.py @@ -14,6 +14,8 @@ def __init__(self, context: BrowserContext, stagehand): # Use a weak key dictionary to map Playwright Pages to our StagehandPage wrappers self.page_map = weakref.WeakKeyDictionary() self.active_stagehand_page = None + # Map frame IDs to StagehandPage instances + self.frame_id_map = {} async def new_page(self) -> StagehandPage: pw_page: Page = await self._context.new_page() @@ -23,9 +25,13 @@ async def new_page(self) -> StagehandPage: async def create_stagehand_page(self, pw_page: Page) -> StagehandPage: # Create a StagehandPage wrapper for the given Playwright page - stagehand_page = StagehandPage(pw_page, self.stagehand) + stagehand_page = StagehandPage(pw_page, self.stagehand, self) await self.inject_custom_scripts(pw_page) self.page_map[pw_page] = stagehand_page + + # Initialize frame tracking for this page + await self._attach_frame_navigated_listener(pw_page, stagehand_page) + return stagehand_page async def inject_custom_scripts(self, pw_page: Page): @@ -69,9 +75,21 @@ def set_active_page(self, stagehand_page: StagehandPage): def get_active_page(self) -> StagehandPage: return self.active_stagehand_page + def register_frame_id(self, frame_id: str, page: StagehandPage): + """Register a frame ID to StagehandPage mapping.""" + self.frame_id_map[frame_id] = page + + def unregister_frame_id(self, frame_id: str): + """Unregister a frame ID from the mapping.""" + if frame_id in self.frame_id_map: + del self.frame_id_map[frame_id] + + def get_stagehand_page_by_frame_id(self, frame_id: str) -> StagehandPage: + """Get StagehandPage by frame ID.""" + return self.frame_id_map.get(frame_id) + @classmethod async def init(cls, context: BrowserContext, stagehand): - stagehand.logger.debug("StagehandContext.init() called", category="context") instance = cls(context, stagehand) # Pre-initialize StagehandPages for any existing pages stagehand.logger.debug( @@ -150,3 +168,67 @@ async def wrapped_pages(): return wrapped_pages return attr + + async def _attach_frame_navigated_listener( + self, pw_page: Page, stagehand_page: StagehandPage + ): + """ + Attach CDP listener for frame navigation events to track frame IDs. + This mirrors the TypeScript implementation's frame tracking. + """ + try: + # Create CDP session for the page + cdp_session = await self._context.new_cdp_session(pw_page) + await cdp_session.send("Page.enable") + + # Get the current root frame ID + frame_tree = await cdp_session.send("Page.getFrameTree") + root_frame_id = frame_tree.get("frameTree", {}).get("frame", {}).get("id") + + if root_frame_id: + # Initialize the page with its frame ID + stagehand_page.update_root_frame_id(root_frame_id) + self.register_frame_id(root_frame_id, stagehand_page) + + # Set up event listener for frame navigation + def on_frame_navigated(params): + """Handle Page.frameNavigated events""" + frame = params.get("frame", {}) + frame_id = frame.get("id") + parent_id = frame.get("parentId") + + # Only track root frames (no parent) + if not parent_id and frame_id: + # Skip if it's the same frame ID + if frame_id == stagehand_page.frame_id: + return + + # Unregister old frame ID if exists + old_id = stagehand_page.frame_id + if old_id: + self.unregister_frame_id(old_id) + + # Register new frame ID + self.register_frame_id(frame_id, stagehand_page) + stagehand_page.update_root_frame_id(frame_id) + + self.stagehand.logger.debug( + f"Frame navigated from {old_id} to {frame_id}", + category="context", + ) + + # Register the event listener + cdp_session.on("Page.frameNavigated", on_frame_navigated) + + # Clean up frame ID when page closes + def on_page_close(): + if stagehand_page.frame_id: + self.unregister_frame_id(stagehand_page.frame_id) + + pw_page.once("close", on_page_close) + + except Exception as e: + self.stagehand.logger.error( + f"Failed to attach frame navigation listener: {str(e)}", + category="context", + ) diff --git a/stagehand/page.py b/stagehand/page.py index 1bc1d5f..c3e8831 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -1,3 +1,5 @@ +import asyncio +import time from typing import Optional, Union from playwright.async_api import CDPSession, Page @@ -26,16 +28,29 @@ class StagehandPage: _cdp_client: Optional[CDPSession] = None - def __init__(self, page: Page, stagehand_client): + def __init__(self, page: Page, stagehand_client, context=None): """ Initialize a StagehandPage instance. Args: page (Page): The underlying Playwright page. stagehand_client: The client used to interface with the Stagehand server. + context: The StagehandContext instance (optional). """ self._page = page self._stagehand = stagehand_client + self._context = context + self._frame_id = None + + @property + def frame_id(self) -> Optional[str]: + """Get the current root frame ID.""" + return self._frame_id + + def update_root_frame_id(self, new_id: str): + """Update the root frame ID.""" + self._frame_id = new_id + self._stagehand.logger.debug(f"Updated frame ID to {new_id}", category="page") # TODO try catch here async def ensure_injection(self): @@ -98,6 +113,10 @@ async def goto( if options: payload["options"] = options + # Add frame ID if available + if self._frame_id: + payload["frameId"] = self._frame_id + lock = self._stagehand._get_lock_for_session() async with lock: result = await self._stagehand._execute("navigate", payload) @@ -168,6 +187,10 @@ async def act( result = await self._act_handler.act(payload) return result + # Add frame ID if available + if self._frame_id: + payload["frameId"] = self._frame_id + lock = self._stagehand._get_lock_for_session() async with lock: result = await self._stagehand._execute("act", payload) @@ -237,6 +260,10 @@ async def observe( return result + # Add frame ID if available + if self._frame_id: + payload["frameId"] = self._frame_id + lock = self._stagehand._get_lock_for_session() async with lock: result = await self._stagehand._execute("observe", payload) @@ -361,6 +388,10 @@ async def extract( return result.data # Use API + # Add frame ID if available + if self._frame_id: + payload["frameId"] = self._frame_id + lock = self._stagehand._get_lock_for_session() async with lock: result_dict = await self._stagehand._execute("extract", payload) @@ -487,8 +518,6 @@ async def _wait_for_settled_dom(self, timeout_ms: int = None): timeout_ms (int, optional): Maximum time to wait in milliseconds. If None, uses the stagehand client's dom_settle_timeout_ms. """ - import asyncio - import time timeout = timeout_ms or getattr(self._stagehand, "dom_settle_timeout_ms", 30000) client = await self.get_cdp_client() diff --git a/tests/integration/api/test_frame_id_integration.py b/tests/integration/api/test_frame_id_integration.py new file mode 100644 index 0000000..f19b4cc --- /dev/null +++ b/tests/integration/api/test_frame_id_integration.py @@ -0,0 +1,204 @@ +""" +Integration tests for frame ID functionality with the API. +Tests that frame IDs are properly tracked and sent to the server. +""" + +import pytest +import os +from unittest.mock import patch, AsyncMock, MagicMock +from stagehand import Stagehand + + +@pytest.mark.skipif( + not os.getenv("BROWSERBASE_API_KEY") or not os.getenv("BROWSERBASE_PROJECT_ID"), + reason="Browserbase credentials not configured" +) +@pytest.mark.asyncio +class TestFrameIdIntegration: + """Integration tests for frame ID tracking with the API.""" + + async def test_frame_id_initialization_and_api_calls(self): + """Test that frame IDs are initialized and included in API calls.""" + # Mock the HTTP client to capture API calls + with patch('stagehand.main.httpx.AsyncClient') as MockClient: + mock_client = AsyncMock() + MockClient.return_value = mock_client + + # Mock session creation response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "success": True, + "data": { + "sessionId": "test-session-123", + "available": True + } + } + mock_client.post = AsyncMock(return_value=mock_response) + + # Mock streaming response for execute calls + mock_stream_response = AsyncMock() + mock_stream_response.status_code = 200 + mock_stream_response.__aenter__ = AsyncMock(return_value=mock_stream_response) + mock_stream_response.__aexit__ = AsyncMock() + + # Mock the async iterator for streaming lines + async def mock_aiter_lines(): + yield 'data: {"type": "system", "data": {"status": "finished", "result": {"success": true}}}' + + mock_stream_response.aiter_lines = mock_aiter_lines + mock_client.stream = MagicMock(return_value=mock_stream_response) + + # Initialize Stagehand + stagehand = Stagehand( + env="BROWSERBASE", + use_api=True, + browserbase_api_key="test-api-key", + browserbase_project_id="test-project", + model_api_key="test-model-key" + ) + + try: + # Initialize browser (this will create session via API) + await stagehand.init() + + # Verify session was created + assert mock_client.post.called + + # Get the page and context + page = stagehand.page + context = stagehand.context + + # Verify frame tracking attributes exist + assert hasattr(page, 'frame_id') + assert hasattr(context, 'frame_id_map') + + # Simulate setting a frame ID (normally done by CDP listener) + test_frame_id = "test-frame-456" + page.update_root_frame_id(test_frame_id) + context.register_frame_id(test_frame_id, page) + + # Test that frame ID is included in navigate call + await page.goto("https://example.com") + + # Check the stream call was made with frameId + stream_call_args = mock_client.stream.call_args + if stream_call_args: + payload = stream_call_args[1].get('json', {}) + assert 'frameId' in payload + assert payload['frameId'] == test_frame_id + + finally: + await stagehand.close() + + async def test_multiple_pages_frame_id_tracking(self): + """Test frame ID tracking with multiple pages.""" + with patch('stagehand.main.httpx.AsyncClient') as MockClient: + mock_client = AsyncMock() + MockClient.return_value = mock_client + + # Setup mocks as in previous test + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "success": True, + "data": { + "sessionId": "test-session-789", + "available": True + } + } + mock_client.post = AsyncMock(return_value=mock_response) + + stagehand = Stagehand( + env="BROWSERBASE", + use_api=True, + browserbase_api_key="test-api-key", + browserbase_project_id="test-project", + model_api_key="test-model-key" + ) + + try: + await stagehand.init() + + # Get first page + page1 = stagehand.page + context = stagehand.context + + # Simulate frame IDs for testing + frame_id_1 = "frame-page1" + page1.update_root_frame_id(frame_id_1) + context.register_frame_id(frame_id_1, page1) + + # Create second page + page2 = await context.new_page() + frame_id_2 = "frame-page2" + page2.update_root_frame_id(frame_id_2) + context.register_frame_id(frame_id_2, page2) + + # Verify both pages are tracked + assert len(context.frame_id_map) == 2 + assert context.get_stagehand_page_by_frame_id(frame_id_1) == page1 + assert context.get_stagehand_page_by_frame_id(frame_id_2) == page2 + + # Verify each page has its own frame ID + assert page1.frame_id == frame_id_1 + assert page2.frame_id == frame_id_2 + + finally: + await stagehand.close() + + async def test_frame_id_persistence_across_navigation(self): + """Test that frame IDs are updated when navigating to new pages.""" + with patch('stagehand.main.httpx.AsyncClient') as MockClient: + mock_client = AsyncMock() + MockClient.return_value = mock_client + + # Setup basic mocks + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "success": True, + "data": { + "sessionId": "test-session-nav", + "available": True + } + } + mock_client.post = AsyncMock(return_value=mock_response) + + stagehand = Stagehand( + env="BROWSERBASE", + use_api=True, + browserbase_api_key="test-api-key", + browserbase_project_id="test-project", + model_api_key="test-model-key" + ) + + try: + await stagehand.init() + + page = stagehand.page + context = stagehand.context + + # Initial frame ID + initial_frame_id = "frame-initial" + page.update_root_frame_id(initial_frame_id) + context.register_frame_id(initial_frame_id, page) + + assert page.frame_id == initial_frame_id + assert initial_frame_id in context.frame_id_map + + # Simulate navigation causing frame ID change + # (In real scenario, CDP listener would handle this) + new_frame_id = "frame-after-nav" + context.unregister_frame_id(initial_frame_id) + page.update_root_frame_id(new_frame_id) + context.register_frame_id(new_frame_id, page) + + # Verify frame ID was updated + assert page.frame_id == new_frame_id + assert initial_frame_id not in context.frame_id_map + assert new_frame_id in context.frame_id_map + assert context.get_stagehand_page_by_frame_id(new_frame_id) == page + + finally: + await stagehand.close() \ No newline at end of file diff --git a/tests/unit/core/test_frame_id_tracking.py b/tests/unit/core/test_frame_id_tracking.py new file mode 100644 index 0000000..f18214b --- /dev/null +++ b/tests/unit/core/test_frame_id_tracking.py @@ -0,0 +1,213 @@ +""" +Unit tests for frame ID tracking functionality. +Tests the implementation of frame ID map in StagehandContext and StagehandPage. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from stagehand.context import StagehandContext +from stagehand.page import StagehandPage + + +@pytest.fixture +def mock_stagehand(): + """Create a mock Stagehand client.""" + mock = MagicMock() + mock.logger = MagicMock() + mock.logger.debug = MagicMock() + mock.logger.error = MagicMock() + mock._page_switch_lock = AsyncMock() + return mock + + +@pytest.fixture +def mock_browser_context(): + """Create a mock Playwright BrowserContext.""" + mock_context = MagicMock() + mock_context.pages = [] + mock_context.new_page = AsyncMock() + mock_context.new_cdp_session = AsyncMock() + return mock_context + + +@pytest.fixture +def mock_page(): + """Create a mock Playwright Page.""" + page = MagicMock() + page.url = "https://example.com" + page.evaluate = AsyncMock(return_value=False) + page.add_init_script = AsyncMock() + page.context = MagicMock() + page.once = MagicMock() + return page + + +class TestFrameIdTracking: + """Test suite for frame ID tracking functionality.""" + + def test_stagehand_context_initialization(self, mock_browser_context, mock_stagehand): + """Test that StagehandContext initializes with frame_id_map.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + + assert hasattr(context, 'frame_id_map') + assert isinstance(context.frame_id_map, dict) + assert len(context.frame_id_map) == 0 + + def test_register_frame_id(self, mock_browser_context, mock_stagehand, mock_page): + """Test registering a frame ID.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + # Register frame ID + frame_id = "frame-123" + context.register_frame_id(frame_id, stagehand_page) + + assert frame_id in context.frame_id_map + assert context.frame_id_map[frame_id] == stagehand_page + + def test_unregister_frame_id(self, mock_browser_context, mock_stagehand, mock_page): + """Test unregistering a frame ID.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + # Register and then unregister + frame_id = "frame-456" + context.register_frame_id(frame_id, stagehand_page) + context.unregister_frame_id(frame_id) + + assert frame_id not in context.frame_id_map + + def test_get_stagehand_page_by_frame_id(self, mock_browser_context, mock_stagehand, mock_page): + """Test retrieving a StagehandPage by frame ID.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + frame_id = "frame-789" + context.register_frame_id(frame_id, stagehand_page) + + retrieved_page = context.get_stagehand_page_by_frame_id(frame_id) + assert retrieved_page == stagehand_page + + # Test non-existent frame ID + non_existent = context.get_stagehand_page_by_frame_id("non-existent") + assert non_existent is None + + def test_stagehand_page_frame_id_property(self, mock_page, mock_stagehand): + """Test StagehandPage frame_id property and update method.""" + stagehand_page = StagehandPage(mock_page, mock_stagehand) + + # Initially None + assert stagehand_page.frame_id is None + + # Update frame ID + new_frame_id = "frame-abc" + stagehand_page.update_root_frame_id(new_frame_id) + + assert stagehand_page.frame_id == new_frame_id + mock_stagehand.logger.debug.assert_called_with( + f"Updated frame ID to {new_frame_id}", category="page" + ) + + @pytest.mark.asyncio + async def test_attach_frame_navigated_listener(self, mock_browser_context, mock_stagehand, mock_page): + """Test attaching CDP frame navigation listener.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + # Mock CDP session + mock_cdp_session = MagicMock() + mock_cdp_session.send = AsyncMock() + mock_cdp_session.on = MagicMock() + mock_browser_context.new_cdp_session = AsyncMock(return_value=mock_cdp_session) + + # Mock frame tree response + mock_cdp_session.send.return_value = { + "frameTree": { + "frame": { + "id": "initial-frame-id" + } + } + } + + # Attach listener + await context._attach_frame_navigated_listener(mock_page, stagehand_page) + + # Verify CDP session was created and Page domain was enabled + mock_browser_context.new_cdp_session.assert_called_once_with(mock_page) + mock_cdp_session.send.assert_any_call("Page.enable") + mock_cdp_session.send.assert_any_call("Page.getFrameTree") + + # Verify frame ID was set + assert stagehand_page.frame_id == "initial-frame-id" + assert "initial-frame-id" in context.frame_id_map + + # Verify event listener was registered + mock_cdp_session.on.assert_called_once() + assert mock_cdp_session.on.call_args[0][0] == "Page.frameNavigated" + + @pytest.mark.asyncio + async def test_frame_id_in_api_calls(self, mock_page, mock_stagehand): + """Test that frame ID is included in API payloads.""" + stagehand_page = StagehandPage(mock_page, mock_stagehand) + stagehand_page.update_root_frame_id("test-frame-123") + + # Mock the stagehand client for API mode + mock_stagehand.use_api = True + mock_stagehand._get_lock_for_session = MagicMock() + mock_stagehand._get_lock_for_session.return_value = AsyncMock() + mock_stagehand._execute = AsyncMock(return_value={"success": True}) + + # Test goto with frame ID + await stagehand_page.goto("https://example.com") + + # Verify frame ID was included in the payload + call_args = mock_stagehand._execute.call_args + assert call_args[0][0] == "navigate" + assert "frameId" in call_args[0][1] + assert call_args[0][1]["frameId"] == "test-frame-123" + + @pytest.mark.asyncio + async def test_frame_navigation_event_handling(self, mock_browser_context, mock_stagehand, mock_page): + """Test handling of frame navigation events.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + # Set initial frame ID + initial_frame_id = "frame-initial" + stagehand_page.update_root_frame_id(initial_frame_id) + context.register_frame_id(initial_frame_id, stagehand_page) + + # Mock CDP session + mock_cdp_session = MagicMock() + mock_cdp_session.send = AsyncMock() + mock_cdp_session.on = MagicMock() + mock_browser_context.new_cdp_session = AsyncMock(return_value=mock_cdp_session) + + # Mock initial frame tree + mock_cdp_session.send.return_value = { + "frameTree": { + "frame": { + "id": initial_frame_id + } + } + } + + await context._attach_frame_navigated_listener(mock_page, stagehand_page) + + # Get the registered event handler + event_handler = mock_cdp_session.on.call_args[0][1] + + # Simulate frame navigation event + new_frame_id = "frame-new" + event_handler({ + "frame": { + "id": new_frame_id, + "parentId": None # Root frame has no parent + } + }) + + # Verify old frame ID was unregistered and new one registered + assert initial_frame_id not in context.frame_id_map + assert new_frame_id in context.frame_id_map + assert stagehand_page.frame_id == new_frame_id + assert context.frame_id_map[new_frame_id] == stagehand_page \ No newline at end of file