diff --git a/.changeset/snobbish-fortunate-tamarin.md b/.changeset/snobbish-fortunate-tamarin.md new file mode 100644 index 0000000..7dcfd44 --- /dev/null +++ b/.changeset/snobbish-fortunate-tamarin.md @@ -0,0 +1,5 @@ +--- +"stagehand": patch +--- + +Fix multi-tab support issues in the frame ID map: hold pages by weak reference, guard frame ID updates with a lock, and detach CDP sessions on page close diff --git a/stagehand/context.py b/stagehand/context.py index aa99f07..a3fdd16 100644 --- a/stagehand/context.py +++ b/stagehand/context.py @@ -15,7 +15,10 @@ def __init__(self, context: BrowserContext, stagehand): self.page_map = weakref.WeakKeyDictionary() self.active_stagehand_page = None # Map frame IDs to StagehandPage instances - self.frame_id_map = {} + # Using WeakValueDictionary to prevent memory leaks + self.frame_id_map = weakref.WeakValueDictionary() + # Lock for frame ID operations to prevent race conditions + self._frame_lock = asyncio.Lock() async def new_page(self) -> StagehandPage: pw_page: Page = await self._context.new_page() @@ -176,29 +179,66 @@ async def _attach_frame_navigated_listener( Attach CDP listener for frame navigation events to track frame IDs. This mirrors the TypeScript implementation's frame tracking. 
""" + cdp_session = None try: # Create CDP session for the page cdp_session = await self._context.new_cdp_session(pw_page) await cdp_session.send("Page.enable") + # Store CDP session reference for cleanup + stagehand_page._cdp_session = cdp_session + # Get the current root frame ID frame_tree = await cdp_session.send("Page.getFrameTree") root_frame_id = frame_tree.get("frameTree", {}).get("frame", {}).get("id") if root_frame_id: # Initialize the page with its frame ID - stagehand_page.update_root_frame_id(root_frame_id) - self.register_frame_id(root_frame_id, stagehand_page) + async with self._frame_lock: + stagehand_page.update_root_frame_id(root_frame_id) + self.register_frame_id(root_frame_id, stagehand_page) # Set up event listener for frame navigation def on_frame_navigated(params): """Handle Page.frameNavigated events""" - frame = params.get("frame", {}) - frame_id = frame.get("id") - parent_id = frame.get("parentId") + # Schedule async work to handle the event + asyncio.create_task( + self._handle_frame_navigated(params, stagehand_page) + ) + + # Register the event listener + cdp_session.on("Page.frameNavigated", on_frame_navigated) + + # Clean up frame ID and CDP session when page closes + def on_page_close(): + asyncio.create_task(self._cleanup_page_resources(stagehand_page)) + + pw_page.once("close", on_page_close) + + except Exception as e: + self.stagehand.logger.error( + f"Failed to attach frame navigation listener: {str(e)}", + category="context", + ) + # Clean up CDP session if it was created + if cdp_session: + try: + await cdp_session.detach() + except Exception: + pass + + async def _handle_frame_navigated( + self, params: dict, stagehand_page: StagehandPage + ): + """Async handler for frame navigation events.""" + try: + frame = params.get("frame", {}) + frame_id = frame.get("id") + parent_id = frame.get("parentId") - # Only track root frames (no parent) - if not parent_id and frame_id: + # Only track root frames (no parent) + if not parent_id 
and frame_id: + async with self._frame_lock: # Skip if it's the same frame ID if frame_id == stagehand_page.frame_id: return @@ -216,19 +256,33 @@ def on_frame_navigated(params): f"Frame navigated from {old_id} to {frame_id}", category="context", ) + except Exception as e: + self.stagehand.logger.error( + f"Error handling frame navigation: {str(e)}", + category="context", + ) - # Register the event listener - cdp_session.on("Page.frameNavigated", on_frame_navigated) - - # Clean up frame ID when page closes - def on_page_close(): + async def _cleanup_page_resources(self, stagehand_page: StagehandPage): + """Clean up resources when a page closes.""" + try: + async with self._frame_lock: if stagehand_page.frame_id: self.unregister_frame_id(stagehand_page.frame_id) - pw_page.once("close", on_page_close) - + # Clean up CDP session if it exists + if ( + hasattr(stagehand_page, "_cdp_session") + and stagehand_page._cdp_session + ): + try: + await stagehand_page._cdp_session.detach() + except Exception as e: + self.stagehand.logger.debug( + f"Error detaching CDP session: {str(e)}", + category="context", + ) except Exception as e: self.stagehand.logger.error( - f"Failed to attach frame navigation listener: {str(e)}", + f"Error cleaning up page resources: {str(e)}", category="context", ) diff --git a/stagehand/page.py b/stagehand/page.py index c3e8831..ef2c33b 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -113,8 +113,8 @@ async def goto( if options: payload["options"] = options - # Add frame ID if available - if self._frame_id: + # Add frame ID if available and context exists + if self._context and self._frame_id: payload["frameId"] = self._frame_id lock = self._stagehand._get_lock_for_session() @@ -187,8 +187,8 @@ async def act( result = await self._act_handler.act(payload) return result - # Add frame ID if available - if self._frame_id: + # Add frame ID if available and context exists + if self._context and self._frame_id: payload["frameId"] = self._frame_id 
lock = self._stagehand._get_lock_for_session() @@ -260,8 +260,8 @@ async def observe( return result - # Add frame ID if available - if self._frame_id: + # Add frame ID if available and context exists + if self._context and self._frame_id: payload["frameId"] = self._frame_id lock = self._stagehand._get_lock_for_session() @@ -388,8 +388,8 @@ async def extract( return result.data # Use API - # Add frame ID if available - if self._frame_id: + # Add frame ID if available and context exists + if self._context and self._frame_id: payload["frameId"] = self._frame_id lock = self._stagehand._get_lock_for_session() diff --git a/tests/unit/core/test_frame_id_tracking.py b/tests/unit/core/test_frame_id_tracking.py index f18214b..5643cf7 100644 --- a/tests/unit/core/test_frame_id_tracking.py +++ b/tests/unit/core/test_frame_id_tracking.py @@ -3,7 +3,9 @@ Tests the implementation of frame ID map in StagehandContext and StagehandPage. """ +import asyncio import pytest +import weakref from unittest.mock import AsyncMock, MagicMock, patch from stagehand.context import StagehandContext from stagehand.page import StagehandPage @@ -50,8 +52,10 @@ def test_stagehand_context_initialization(self, mock_browser_context, mock_stage context = StagehandContext(mock_browser_context, mock_stagehand) assert hasattr(context, 'frame_id_map') - assert isinstance(context.frame_id_map, dict) + assert isinstance(context.frame_id_map, weakref.WeakValueDictionary) assert len(context.frame_id_map) == 0 + assert hasattr(context, '_frame_lock') + assert isinstance(context._frame_lock, asyncio.Lock) def test_register_frame_id(self, mock_browser_context, mock_stagehand, mock_page): """Test registering a frame ID.""" @@ -146,9 +150,10 @@ async def test_attach_frame_navigated_listener(self, mock_browser_context, mock_ assert mock_cdp_session.on.call_args[0][0] == "Page.frameNavigated" @pytest.mark.asyncio - async def test_frame_id_in_api_calls(self, mock_page, mock_stagehand): + async def 
test_frame_id_in_api_calls(self, mock_page, mock_stagehand, mock_browser_context): """Test that frame ID is included in API payloads.""" - stagehand_page = StagehandPage(mock_page, mock_stagehand) + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) stagehand_page.update_root_frame_id("test-frame-123") # Mock the stagehand client for API mode @@ -198,6 +203,7 @@ async def test_frame_navigation_event_handling(self, mock_browser_context, mock_ event_handler = mock_cdp_session.on.call_args[0][1] # Simulate frame navigation event + # The event handler schedules async work, so we need to wait for it new_frame_id = "frame-new" event_handler({ "frame": { @@ -206,8 +212,138 @@ async def test_frame_navigation_event_handling(self, mock_browser_context, mock_ } }) + # Wait for async tasks to complete + await asyncio.sleep(0.1) + # Verify old frame ID was unregistered and new one registered assert initial_frame_id not in context.frame_id_map assert new_frame_id in context.frame_id_map assert stagehand_page.frame_id == new_frame_id - assert context.frame_id_map[new_frame_id] == stagehand_page \ No newline at end of file + assert context.frame_id_map[new_frame_id] == stagehand_page + + @pytest.mark.asyncio + async def test_cdp_session_failure(self, mock_browser_context, mock_stagehand, mock_page): + """Test handling of CDP session creation failure.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + # Mock CDP session creation to fail + mock_browser_context.new_cdp_session = AsyncMock(side_effect=Exception("CDP connection failed")) + + # Attach listener should handle the error gracefully + await context._attach_frame_navigated_listener(mock_page, stagehand_page) + + # Verify error was logged + mock_stagehand.logger.error.assert_called_with( + "Failed to attach frame navigation listener: CDP connection 
failed", + category="context", + ) + + # Verify page still works without frame ID + assert stagehand_page.frame_id is None + assert len(context.frame_id_map) == 0 + + @pytest.mark.asyncio + async def test_cdp_session_cleanup_on_failure(self, mock_browser_context, mock_stagehand, mock_page): + """Test that CDP session is cleaned up if attachment fails after creation.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + # Mock CDP session + mock_cdp_session = MagicMock() + mock_cdp_session.send = AsyncMock(side_effect=Exception("Page.enable failed")) + mock_cdp_session.detach = AsyncMock() + mock_browser_context.new_cdp_session = AsyncMock(return_value=mock_cdp_session) + + # Attach listener should handle the error + await context._attach_frame_navigated_listener(mock_page, stagehand_page) + + # Verify CDP session was detached after failure + mock_cdp_session.detach.assert_called_once() + + # Verify error was logged + mock_stagehand.logger.error.assert_called() + + @pytest.mark.asyncio + async def test_page_cleanup_with_cdp_session(self, mock_browser_context, mock_stagehand, mock_page): + """Test that CDP session is properly cleaned up when page closes.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + # Mock CDP session + mock_cdp_session = MagicMock() + mock_cdp_session.send = AsyncMock() + mock_cdp_session.on = MagicMock() + mock_cdp_session.detach = AsyncMock() + mock_browser_context.new_cdp_session = AsyncMock(return_value=mock_cdp_session) + + # Mock frame tree response + mock_cdp_session.send.return_value = { + "frameTree": { + "frame": { + "id": "test-frame-id" + } + } + } + + # Attach listener + await context._attach_frame_navigated_listener(mock_page, stagehand_page) + + # Verify CDP session was stored + assert hasattr(stagehand_page, '_cdp_session') + assert 
stagehand_page._cdp_session == mock_cdp_session + + # Simulate page close by calling cleanup + await context._cleanup_page_resources(stagehand_page) + + # Verify CDP session was detached + mock_cdp_session.detach.assert_called_once() + + @pytest.mark.asyncio + async def test_frame_id_race_condition_prevention(self, mock_browser_context, mock_stagehand, mock_page): + """Test that frame lock prevents race conditions.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + + # Track actual calls to the real methods through patching + with patch.object(context, 'register_frame_id', wraps=context.register_frame_id) as mock_reg, \ + patch.object(context, 'unregister_frame_id', wraps=context.unregister_frame_id) as mock_unreg: + + # Simulate multiple concurrent frame navigations + tasks = [ + context._handle_frame_navigated( + {"frame": {"id": f"frame-{i}", "parentId": None}}, + stagehand_page + ) + for i in range(3) + ] + + await asyncio.gather(*tasks) + + # Verify methods were called the expected number of times + # Each navigation causes an unregister (if there's an old frame) and a register + assert mock_reg.call_count == 3 # 3 registers + assert mock_unreg.call_count >= 2 # At least 2 unregisters (first might not have old frame) + + @pytest.mark.asyncio + async def test_weak_reference_cleanup(self, mock_browser_context, mock_stagehand, mock_page): + """Test that weak references allow garbage collection.""" + context = StagehandContext(mock_browser_context, mock_stagehand) + + # Create a page and register it + stagehand_page = StagehandPage(mock_page, mock_stagehand, context) + frame_id = "frame-weak-ref" + context.register_frame_id(frame_id, stagehand_page) + + assert frame_id in context.frame_id_map + assert context.frame_id_map[frame_id] == stagehand_page + + # Delete the only strong reference to the page + del stagehand_page + + # Force garbage collection + import gc + 
gc.collect() + + # The weak reference should be gone + assert frame_id not in context.frame_id_map \ No newline at end of file