Skip to content

☒ Fix race condition - make on_frame_navigated callback async #175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: miguel/stg-607-feature-parity-page-map
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/snobbish-fortunate-tamarin.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"stagehand": patch
---

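Fixed several multi-tab issues with the frame ID map: updates are now guarded by an async lock, pages are held by weak reference, and CDP sessions are detached when a page closes.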
86 changes: 70 additions & 16 deletions stagehand/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ def __init__(self, context: BrowserContext, stagehand):
self.page_map = weakref.WeakKeyDictionary()
self.active_stagehand_page = None
# Map frame IDs to StagehandPage instances
self.frame_id_map = {}
# Using WeakValueDictionary to prevent memory leaks
self.frame_id_map = weakref.WeakValueDictionary()
# Lock for frame ID operations to prevent race conditions
self._frame_lock = asyncio.Lock()

async def new_page(self) -> StagehandPage:
pw_page: Page = await self._context.new_page()
Expand Down Expand Up @@ -176,29 +179,66 @@ async def _attach_frame_navigated_listener(
Attach CDP listener for frame navigation events to track frame IDs.
This mirrors the TypeScript implementation's frame tracking.
"""
cdp_session = None
try:
# Create CDP session for the page
cdp_session = await self._context.new_cdp_session(pw_page)
await cdp_session.send("Page.enable")

# Store CDP session reference for cleanup
stagehand_page._cdp_session = cdp_session

# Get the current root frame ID
frame_tree = await cdp_session.send("Page.getFrameTree")
root_frame_id = frame_tree.get("frameTree", {}).get("frame", {}).get("id")

if root_frame_id:
# Initialize the page with its frame ID
stagehand_page.update_root_frame_id(root_frame_id)
self.register_frame_id(root_frame_id, stagehand_page)
async with self._frame_lock:
stagehand_page.update_root_frame_id(root_frame_id)
self.register_frame_id(root_frame_id, stagehand_page)

# Set up event listener for frame navigation
def on_frame_navigated(params):
"""Handle Page.frameNavigated events"""
frame = params.get("frame", {})
frame_id = frame.get("id")
parent_id = frame.get("parentId")
# Schedule async work to handle the event
asyncio.create_task(
self._handle_frame_navigated(params, stagehand_page)
)

# Register the event listener
cdp_session.on("Page.frameNavigated", on_frame_navigated)

# Clean up frame ID and CDP session when page closes
def on_page_close():
asyncio.create_task(self._cleanup_page_resources(stagehand_page))

pw_page.once("close", on_page_close)

except Exception as e:
self.stagehand.logger.error(
f"Failed to attach frame navigation listener: {str(e)}",
category="context",
)
# Clean up CDP session if it was created
if cdp_session:
try:
await cdp_session.detach()
except Exception:
pass

async def _handle_frame_navigated(
self, params: dict, stagehand_page: StagehandPage
):
"""Async handler for frame navigation events."""
try:
frame = params.get("frame", {})
frame_id = frame.get("id")
parent_id = frame.get("parentId")

# Only track root frames (no parent)
if not parent_id and frame_id:
# Only track root frames (no parent)
if not parent_id and frame_id:
async with self._frame_lock:
# Skip if it's the same frame ID
if frame_id == stagehand_page.frame_id:
return
Expand All @@ -216,19 +256,33 @@ def on_frame_navigated(params):
f"Frame navigated from {old_id} to {frame_id}",
category="context",
)
except Exception as e:
self.stagehand.logger.error(
f"Error handling frame navigation: {str(e)}",
category="context",
)

# Register the event listener
cdp_session.on("Page.frameNavigated", on_frame_navigated)

# Clean up frame ID when page closes
def on_page_close():
async def _cleanup_page_resources(self, stagehand_page: StagehandPage):
"""Clean up resources when a page closes."""
try:
async with self._frame_lock:
if stagehand_page.frame_id:
self.unregister_frame_id(stagehand_page.frame_id)

pw_page.once("close", on_page_close)

# Clean up CDP session if it exists
if (
hasattr(stagehand_page, "_cdp_session")
and stagehand_page._cdp_session
):
try:
await stagehand_page._cdp_session.detach()
except Exception as e:
self.stagehand.logger.debug(
f"Error detaching CDP session: {str(e)}",
category="context",
)
except Exception as e:
self.stagehand.logger.error(
f"Failed to attach frame navigation listener: {str(e)}",
f"Error cleaning up page resources: {str(e)}",
category="context",
)
16 changes: 8 additions & 8 deletions stagehand/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ async def goto(
if options:
payload["options"] = options

# Add frame ID if available
if self._frame_id:
# Add frame ID if available and context exists
if self._context and self._frame_id:
payload["frameId"] = self._frame_id

lock = self._stagehand._get_lock_for_session()
Expand Down Expand Up @@ -187,8 +187,8 @@ async def act(
result = await self._act_handler.act(payload)
return result

# Add frame ID if available
if self._frame_id:
# Add frame ID if available and context exists
if self._context and self._frame_id:
payload["frameId"] = self._frame_id

lock = self._stagehand._get_lock_for_session()
Expand Down Expand Up @@ -260,8 +260,8 @@ async def observe(

return result

# Add frame ID if available
if self._frame_id:
# Add frame ID if available and context exists
if self._context and self._frame_id:
payload["frameId"] = self._frame_id

lock = self._stagehand._get_lock_for_session()
Expand Down Expand Up @@ -388,8 +388,8 @@ async def extract(
return result.data

# Use API
# Add frame ID if available
if self._frame_id:
# Add frame ID if available and context exists
if self._context and self._frame_id:
payload["frameId"] = self._frame_id

lock = self._stagehand._get_lock_for_session()
Expand Down
144 changes: 140 additions & 4 deletions tests/unit/core/test_frame_id_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
Tests the implementation of frame ID map in StagehandContext and StagehandPage.
"""

import asyncio
import pytest
import weakref
from unittest.mock import AsyncMock, MagicMock, patch
from stagehand.context import StagehandContext
from stagehand.page import StagehandPage
Expand Down Expand Up @@ -50,8 +52,10 @@ def test_stagehand_context_initialization(self, mock_browser_context, mock_stage
context = StagehandContext(mock_browser_context, mock_stagehand)

assert hasattr(context, 'frame_id_map')
assert isinstance(context.frame_id_map, dict)
assert isinstance(context.frame_id_map, weakref.WeakValueDictionary)
assert len(context.frame_id_map) == 0
assert hasattr(context, '_frame_lock')
assert isinstance(context._frame_lock, asyncio.Lock)

def test_register_frame_id(self, mock_browser_context, mock_stagehand, mock_page):
"""Test registering a frame ID."""
Expand Down Expand Up @@ -146,9 +150,10 @@ async def test_attach_frame_navigated_listener(self, mock_browser_context, mock_
assert mock_cdp_session.on.call_args[0][0] == "Page.frameNavigated"

@pytest.mark.asyncio
async def test_frame_id_in_api_calls(self, mock_page, mock_stagehand):
async def test_frame_id_in_api_calls(self, mock_page, mock_stagehand, mock_browser_context):
"""Test that frame ID is included in API payloads."""
stagehand_page = StagehandPage(mock_page, mock_stagehand)
context = StagehandContext(mock_browser_context, mock_stagehand)
stagehand_page = StagehandPage(mock_page, mock_stagehand, context)
stagehand_page.update_root_frame_id("test-frame-123")

# Mock the stagehand client for API mode
Expand Down Expand Up @@ -198,6 +203,7 @@ async def test_frame_navigation_event_handling(self, mock_browser_context, mock_
event_handler = mock_cdp_session.on.call_args[0][1]

# Simulate frame navigation event
# The event handler schedules async work, so we need to wait for it
new_frame_id = "frame-new"
event_handler({
"frame": {
Expand All @@ -206,8 +212,138 @@ async def test_frame_navigation_event_handling(self, mock_browser_context, mock_
}
})

# Wait for async tasks to complete
await asyncio.sleep(0.1)

# Verify old frame ID was unregistered and new one registered
assert initial_frame_id not in context.frame_id_map
assert new_frame_id in context.frame_id_map
assert stagehand_page.frame_id == new_frame_id
assert context.frame_id_map[new_frame_id] == stagehand_page
assert context.frame_id_map[new_frame_id] == stagehand_page

@pytest.mark.asyncio
async def test_cdp_session_failure(
    self, mock_browser_context, mock_stagehand, mock_page
):
    """A failing ``new_cdp_session`` call is logged and no frame ID is recorded."""
    ctx = StagehandContext(mock_browser_context, mock_stagehand)
    page = StagehandPage(mock_page, mock_stagehand, ctx)

    # Arrange: the browser context refuses to create a CDP session.
    creation_error = Exception("CDP connection failed")
    mock_browser_context.new_cdp_session = AsyncMock(side_effect=creation_error)

    # Act: attaching must not propagate the exception to the caller.
    await ctx._attach_frame_navigated_listener(mock_page, page)

    # Assert: the failure was reported through the logger ...
    expected_message = (
        "Failed to attach frame navigation listener: CDP connection failed"
    )
    mock_stagehand.logger.error.assert_called_with(
        expected_message,
        category="context",
    )

    # ... and neither the page nor the context picked up a frame ID.
    assert stagehand_frame_id_is_unset(page) if False else page.frame_id is None
    assert len(ctx.frame_id_map) == 0

@pytest.mark.asyncio
async def test_cdp_session_cleanup_on_failure(
    self, mock_browser_context, mock_stagehand, mock_page
):
    """A CDP session that fails after creation is detached and the error logged."""
    ctx = StagehandContext(mock_browser_context, mock_stagehand)
    page = StagehandPage(mock_page, mock_stagehand, ctx)

    # The session itself can be created, but every ``send`` call raises.
    failing_send = AsyncMock(side_effect=Exception("Page.enable failed"))
    session = MagicMock(send=failing_send, detach=AsyncMock())
    mock_browser_context.new_cdp_session = AsyncMock(return_value=session)

    # Attaching is expected to swallow the error rather than raise it.
    await ctx._attach_frame_navigated_listener(mock_page, page)

    # The half-initialised session must have been detached exactly once.
    session.detach.assert_called_once()

    # The failure must have been reported through the logger.
    mock_stagehand.logger.error.assert_called()

@pytest.mark.asyncio
async def test_page_cleanup_with_cdp_session(
    self, mock_browser_context, mock_stagehand, mock_page
):
    """Cleaning up a page's resources detaches the CDP session stored on it."""
    ctx = StagehandContext(mock_browser_context, mock_stagehand)
    page = StagehandPage(mock_page, mock_stagehand, ctx)

    # Every CDP command answers with a frame tree whose root is "test-frame-id".
    frame_tree_reply = {"frameTree": {"frame": {"id": "test-frame-id"}}}
    session = MagicMock(
        send=AsyncMock(return_value=frame_tree_reply),
        on=MagicMock(),
        detach=AsyncMock(),
    )
    mock_browser_context.new_cdp_session = AsyncMock(return_value=session)

    await ctx._attach_frame_navigated_listener(mock_page, page)

    # Attaching must leave the session reachable from the page.
    assert hasattr(page, '_cdp_session')
    assert page._cdp_session == session

    # Call the cleanup coroutine directly instead of emitting a real close event.
    await ctx._cleanup_page_resources(page)

    session.detach.assert_called_once()

@pytest.mark.asyncio
async def test_frame_id_race_condition_prevention(
    self, mock_browser_context, mock_stagehand, mock_page
):
    """Run three root-frame navigations concurrently and count the map updates."""
    ctx = StagehandContext(mock_browser_context, mock_stagehand)
    page = StagehandPage(mock_page, mock_stagehand, ctx)

    def navigation_event(index):
        # A frame without a parent, i.e. a root-frame navigation.
        return {"frame": {"id": f"frame-{index}", "parentId": None}}

    # Spy on the real register/unregister methods; ``wraps`` keeps them working.
    with patch.object(
        ctx, 'register_frame_id', wraps=ctx.register_frame_id
    ) as register_spy:
        with patch.object(
            ctx, 'unregister_frame_id', wraps=ctx.unregister_frame_id
        ) as unregister_spy:
            # Fire all navigations at once so the handlers overlap.
            await asyncio.gather(
                *(
                    ctx._handle_frame_navigated(navigation_event(i), page)
                    for i in range(3)
                )
            )

            # One register per navigation.
            assert register_spy.call_count == 3
            # The first navigation might have no old frame ID to drop.
            assert unregister_spy.call_count >= 2

@pytest.mark.asyncio
async def test_weak_reference_cleanup(
    self, mock_browser_context, mock_stagehand, mock_page
):
    """A registered page drops out of the frame ID map once it is collected."""
    import gc

    ctx = StagehandContext(mock_browser_context, mock_stagehand)
    frame_id = "frame-weak-ref"

    # Register a page while this test still holds the only strong reference.
    page = StagehandPage(mock_page, mock_stagehand, ctx)
    ctx.register_frame_id(frame_id, page)

    assert frame_id in ctx.frame_id_map
    assert ctx.frame_id_map[frame_id] == page

    # Drop that reference and force a collection pass.
    del page
    gc.collect()

    # The map must not have kept the page alive.
    assert frame_id not in ctx.frame_id_map