From 14a699465f7d2421bf9c88cadff371aeec1347bc Mon Sep 17 00:00:00 2001 From: Navin Chandra Date: Tue, 15 Jul 2025 16:30:30 +0530 Subject: [PATCH 1/8] add event classes and tests --- .../webdriver/common/bidi/browsing_context.py | 239 +++++++++++++++--- .../common/bidi_browsing_context_tests.py | 214 ++++++++++++++++ 2 files changed, 415 insertions(+), 38 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 27656be2a1c6e..acbbf0590d744 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -399,6 +399,154 @@ def from_json(cls, json: dict) -> "BrowsingContextEvent": return cls(event_class=event_class, **json) +class ContextCreated: + """Event class for browsingContext.contextCreated event.""" + + event_class = "browsingContext.contextCreated" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, BrowsingContextInfo): + return json + return BrowsingContextInfo.from_json(json) + + +class ContextDestroyed: + """Event class for browsingContext.contextDestroyed event.""" + + event_class = "browsingContext.contextDestroyed" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, BrowsingContextInfo): + return json + return BrowsingContextInfo.from_json(json) + + +class NavigationStarted: + """Event class for browsingContext.navigationStarted event.""" + + event_class = "browsingContext.navigationStarted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationCommitted: + """Event class for browsingContext.navigationCommitted event.""" + + event_class = "browsingContext.navigationCommitted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationFailed: + """Event class for browsingContext.navigationFailed event.""" + + event_class = "browsingContext.navigationFailed" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationAborted: + """Event class for browsingContext.navigationAborted event.""" + + event_class = "browsingContext.navigationAborted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class DomContentLoaded: + """Event class for browsingContext.domContentLoaded event.""" + + event_class = "browsingContext.domContentLoaded" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class Load: + """Event class for browsingContext.load event.""" + + event_class = "browsingContext.load" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class FragmentNavigated: + """Event class for browsingContext.fragmentNavigated event.""" + + event_class = "browsingContext.fragmentNavigated" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class DownloadWillBegin: + """Event class for browsingContext.downloadWillBegin event.""" + + event_class = "browsingContext.downloadWillBegin" + + @classmethod + def from_json(cls, json: dict): + return DownloadWillBeginParams.from_json(json) + + +class UserPromptOpened: + """Event class for browsingContext.userPromptOpened event.""" + + event_class = "browsingContext.userPromptOpened" + + @classmethod + def from_json(cls, json: dict): + return UserPromptOpenedParams.from_json(json) + + +class UserPromptClosed: + """Event class for browsingContext.userPromptClosed event.""" + + event_class = "browsingContext.userPromptClosed" + + @classmethod + def from_json(cls, json: dict): + return UserPromptClosedParams.from_json(json) + + +class HistoryUpdated: + """Event class for browsingContext.historyUpdated event.""" + + event_class = "browsingContext.historyUpdated" + + @classmethod + def from_json(cls, json: dict): + return HistoryUpdatedParams.from_json(json) + + class BrowsingContext: """BiDi implementation of the browsingContext module.""" @@ -418,6 +566,22 @@ class BrowsingContext: "user_prompt_opened": "browsingContext.userPromptOpened", } + EVENT_CLASSES = { + "browsingContext.contextCreated": ContextCreated, + "browsingContext.contextDestroyed": ContextDestroyed, + "browsingContext.domContentLoaded": DomContentLoaded, + "browsingContext.downloadWillBegin": DownloadWillBegin, + "browsingContext.fragmentNavigated": FragmentNavigated, + "browsingContext.historyUpdated": HistoryUpdated, + "browsingContext.load": Load, + "browsingContext.navigationAborted": NavigationAborted, + "browsingContext.navigationCommitted": NavigationCommitted, + "browsingContext.navigationFailed": NavigationFailed, + "browsingContext.navigationStarted": NavigationStarted, + "browsingContext.userPromptClosed": UserPromptClosed, + "browsingContext.userPromptOpened": UserPromptOpened, + } + def __init__(self, conn): self.conn = conn self.subscriptions = {} @@ -751,30 +915,16 @@ def _on_event(self, event_name: str, callback: Callable) -> int: ------- int: callback id """ - event = BrowsingContextEvent(event_name) + event_class = self.EVENT_CLASSES.get(event_name) + if not event_class: + raise Exception(f"Event class for {event_name} not found") def _callback(event_data): - if event_name == self.EVENTS["context_created"] or event_name == self.EVENTS["context_destroyed"]: - info = BrowsingContextInfo.from_json(event_data.params) - callback(info) - elif event_name == self.EVENTS["download_will_begin"]: - params = DownloadWillBeginParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["user_prompt_opened"]: - params = UserPromptOpenedParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["user_prompt_closed"]: - params = UserPromptClosedParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["history_updated"]: - params = HistoryUpdatedParams.from_json(event_data.params) - callback(params) - else: - # For navigation events - info = NavigationInfo.from_json(event_data.params) - callback(info) + # Parse the event data using the appropriate event class + parsed_data = event_class.from_json(event_data) + callback(parsed_data) - callback_id = self.conn.add_callback(event, _callback) + callback_id = self.conn.add_callback(event_class, _callback) if event_name in self.callbacks: self.callbacks[event_name].append(callback_id) @@ -806,11 +956,11 @@ def add_event_handler(self, event: str, callback: Callable, contexts: Optional[l if event_name in self.subscriptions: self.subscriptions[event_name].append(callback_id) else: - params = {"events": [event_name]} - if contexts is not None: - params["browsingContexts"] = contexts session = Session(self.conn) - self.conn.execute(session.subscribe(**params)) + if contexts is not None: + self.conn.execute(session.subscribe(event_name, browsing_contexts=contexts)) + else: + self.conn.execute(session.subscribe(event_name)) self.subscriptions[event_name] = [callback_id] return callback_id @@ -828,26 +978,39 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: except KeyError: raise Exception(f"Event {event} not found") - event_obj = BrowsingContextEvent(event_name) + event_class = self.EVENT_CLASSES.get(event_name) + if not event_class: + raise Exception(f"Event class for {event_name} not found") - self.conn.remove_callback(event_obj, callback_id) - if event_name in self.subscriptions: - callbacks = self.subscriptions[event_name] + self.conn.remove_callback(event_class, callback_id) + + # Remove from callbacks tracking + if event_name in self.callbacks: + callbacks = self.callbacks[event_name] if callback_id in callbacks: callbacks.remove(callback_id) if not callbacks: - params = {"events": [event_name]} + del self.callbacks[event_name] + + # Remove from subscriptions and unsubscribe if no more callbacks + if event_name in self.subscriptions: + subscription_callbacks = self.subscriptions[event_name] + if callback_id in subscription_callbacks: + subscription_callbacks.remove(callback_id) + if not subscription_callbacks: session = Session(self.conn) - self.conn.execute(session.unsubscribe(**params)) + self.conn.execute(session.unsubscribe(event_name)) del self.subscriptions[event_name] def clear_event_handlers(self) -> None: """Clear all event handlers from the browsing context.""" - for event_name in self.subscriptions: - event = BrowsingContextEvent(event_name) - for callback_id in self.subscriptions[event_name]: - self.conn.remove_callback(event, callback_id) - params = {"events": [event_name]} - session = Session(self.conn) - self.conn.execute(session.unsubscribe(**params)) + for event_name in list(self.subscriptions.keys()): + event_class = self.EVENT_CLASSES.get(event_name) + if event_class: + for callback_id in self.subscriptions[event_name]: + self.conn.remove_callback(event_class, callback_id) + session = Session(self.conn) + self.conn.execute(session.unsubscribe(event_name)) + self.subscriptions = {} + self.callbacks = {} diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 74a0f53aaf640..9a75cb3a535c4 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -16,6 +16,7 @@ # under the License. import base64 +import threading import pytest @@ -525,3 +526,216 @@ def test_locate_nodes_given_start_nodes(driver, pages): ) # The login form should have 3 input elements (email, age, and submit button) assert len(elements) == 3 + + +# Tests for event handlers + + +def test_add_event_handler_context_created(driver): + """Test adding event handler for context_created event.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + assert callback_id is not None + + # Create a new context to trigger the event + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify the event was received (might be > 1 since default context is also included) + assert len(events_received) >= 1 + assert events_received[0].context == context_id or events_received[1].context == context_id + + driver.browsing_context.close(context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id) + + +def test_add_event_handler_context_destroyed(driver): + """Test adding event handler for context_destroyed event.""" + events_received = [] + + def on_context_destroyed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_destroyed", on_context_destroyed) + assert callback_id is not None + + # Create and then close a context to trigger the event + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(context_id) + + assert len(events_received) == 1 + assert events_received[0].context == context_id + + driver.browsing_context.remove_event_handler("context_destroyed", callback_id) + + +def test_add_event_handler_navigation_committed(driver, pages): + """Test adding event handler for navigation_committed event.""" + events_received = [] + + def on_navigation_committed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("navigation_committed", on_navigation_committed) + assert callback_id is not None + + # Navigate to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) >= 1 + assert any(url in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("navigation_committed", callback_id) + + +def test_add_event_handler_with_specific_contexts(driver): + """Test adding event handler with specific browsing contexts.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Add event handler for specific context + callback_id = driver.browsing_context.add_event_handler( + "context_created", on_context_created, contexts=[context_id] + ) + assert callback_id is not None + + # Create another context (should trigger event) + new_context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + assert len(events_received) >= 1 + + driver.browsing_context.close(context_id) + driver.browsing_context.close(new_context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id) + + +def test_remove_event_handler(driver): + """Test removing event handler.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + + # Create a context to trigger the event + context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) + + initial_events = len(events_received) + + # Remove the event handler + driver.browsing_context.remove_event_handler("context_created", callback_id) + + # Create another context (should not trigger event after removal) + context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify no new events were received after removal + assert len(events_received) == initial_events + + driver.browsing_context.close(context_id_1) + driver.browsing_context.close(context_id_2) + + +def test_multiple_event_handlers_same_event(driver): + """Test adding multiple event handlers for the same event.""" + events_received_1 = [] + events_received_2 = [] + + def on_context_created_1(info): + events_received_1.append(info) + + def on_context_created_2(info): + events_received_2.append(info) + + # Add multiple event handlers for the same event + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + + # Create a context to trigger both handlers + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify both handlers received the event + assert len(events_received_1) >= 1 + assert len(events_received_2) >= 1 + # Check any of the events has the required context ID + assert any(event.context == context_id for event in events_received_1) + assert any(event.context == context_id for event in events_received_2) + + driver.browsing_context.close(context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id_1) + driver.browsing_context.remove_event_handler("context_created", callback_id_2) + + +def test_remove_specific_event_handler_multiple_handlers(driver): + """Test removing a specific event handler when multiple handlers exist.""" + events_received_1 = [] + events_received_2 = [] + + def on_context_created_1(info): + events_received_1.append(info) + + def on_context_created_2(info): + events_received_2.append(info) + + # Add multiple event handlers + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + + # Create a context to trigger both handlers + context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify both handlers received the event + assert len(events_received_1) >= 1 + assert len(events_received_2) >= 1 + + # store the initial event counts + initial_count_1 = len(events_received_1) + initial_count_2 = len(events_received_2) + + # Remove only the first handler + driver.browsing_context.remove_event_handler("context_created", callback_id_1) + + # Create another context + context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify only the second handler received the new event + assert len(events_received_1) == initial_count_1 # No new events + assert len(events_received_2) == initial_count_2 + 1 # 1 new event + + driver.browsing_context.close(context_id_1) + driver.browsing_context.close(context_id_2) + driver.browsing_context.remove_event_handler("context_created", callback_id_2) + + +def test_event_handler_thread_safety(driver): + """Test event handlers are thread-safe.""" + events_received = [] + event_lock = threading.Lock() + + def on_context_created(info): + with event_lock: + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + + # Create multiple contexts in rapid succession + context_ids = [] + for i in range(3): + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + context_ids.append(context_id) + + # Verify all events were received (might be 1 more than 3 due to default context) + assert len(events_received) >= 3 + + for context_id in context_ids: + driver.browsing_context.close(context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id) From 597437808d38b84b09276fc6c6fc3cff6f3afea8 Mon Sep 17 00:00:00 2001 From: Navin Chandra Date: Wed, 16 Jul 2025 17:49:27 +0530 Subject: [PATCH 2/8] add `get_event_names` classmethod and refactor event dict --- .../webdriver/common/bidi/browsing_context.py | 130 +++++++++--------- 1 file changed, 68 insertions(+), 62 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index acbbf0590d744..25a4cf73ff00b 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +from dataclasses import dataclass from typing import Any, Callable, Optional, Union from selenium.webdriver.common.bidi.common import command_builder @@ -547,39 +548,34 @@ def from_json(cls, json: dict): return HistoryUpdatedParams.from_json(json) +@dataclass +class EventConfig: + event_key: str + bidi_event: str + event_class: type + + class BrowsingContext: """BiDi implementation of the browsingContext module.""" - EVENTS = { - "context_created": "browsingContext.contextCreated", - "context_destroyed": "browsingContext.contextDestroyed", - "dom_content_loaded": "browsingContext.domContentLoaded", - "download_will_begin": "browsingContext.downloadWillBegin", - "fragment_navigated": "browsingContext.fragmentNavigated", - "history_updated": "browsingContext.historyUpdated", - "load": "browsingContext.load", - "navigation_aborted": "browsingContext.navigationAborted", - "navigation_committed": "browsingContext.navigationCommitted", - "navigation_failed": "browsingContext.navigationFailed", - "navigation_started": "browsingContext.navigationStarted", - "user_prompt_closed": "browsingContext.userPromptClosed", - "user_prompt_opened": "browsingContext.userPromptOpened", - } - - EVENT_CLASSES = { - "browsingContext.contextCreated": ContextCreated, - "browsingContext.contextDestroyed": ContextDestroyed, - "browsingContext.domContentLoaded": DomContentLoaded, - "browsingContext.downloadWillBegin": DownloadWillBegin, - "browsingContext.fragmentNavigated": FragmentNavigated, - "browsingContext.historyUpdated": HistoryUpdated, - "browsingContext.load": Load, - "browsingContext.navigationAborted": NavigationAborted, - "browsingContext.navigationCommitted": NavigationCommitted, - "browsingContext.navigationFailed": NavigationFailed, - "browsingContext.navigationStarted": NavigationStarted, - "browsingContext.userPromptClosed": UserPromptClosed, - "browsingContext.userPromptOpened": UserPromptOpened, + EVENT_CONFIGS = { + "context_created": EventConfig("context_created", "browsingContext.contextCreated", ContextCreated), + "context_destroyed": EventConfig("context_destroyed", "browsingContext.contextDestroyed", ContextDestroyed), + "dom_content_loaded": EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", DomContentLoaded), + "download_will_begin": EventConfig( + "download_will_begin", "browsingContext.downloadWillBegin", DownloadWillBegin + ), + "fragment_navigated": EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", FragmentNavigated), + "history_updated": EventConfig("history_updated", "browsingContext.historyUpdated", HistoryUpdated), + "load": EventConfig("load", "browsingContext.load", Load), + "navigation_aborted": EventConfig("navigation_aborted", "browsingContext.navigationAborted", NavigationAborted), + "navigation_committed": EventConfig( + "navigation_committed", "browsingContext.navigationCommitted", NavigationCommitted + ), + "navigation_failed": EventConfig("navigation_failed", "browsingContext.navigationFailed", NavigationFailed), + "navigation_started": EventConfig("navigation_started", "browsingContext.navigationStarted", NavigationStarted), + "user_prompt_closed": EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", UserPromptClosed), + "user_prompt_opened": EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", UserPromptOpened), } def __init__(self, conn): @@ -587,6 +583,16 @@ def __init__(self, conn): self.subscriptions = {} self.callbacks = {} + @classmethod + def get_event_names(cls) -> list[str]: + """Get a list of all available event names. + + Returns: + ------- + List[str]: A list of event names that can be used with event handlers. + """ + return list(cls.EVENT_CONFIGS.keys()) + def activate(self, context: str) -> None: """Activates and focuses the given top-level traversable. @@ -903,21 +909,19 @@ def traverse_history(self, context: str, delta: int) -> dict: result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) return result - def _on_event(self, event_name: str, callback: Callable) -> int: + def _on_event(self, event_name: str, event_class: type, callback: Callable) -> int: """Set a callback function to subscribe to a browsing context event. Parameters: ---------- - event_name: The event to subscribe to. + event_name: The BiDi event name to subscribe to. + event_class: The event class for parsing. callback: The callback function to execute on event. Returns: ------- int: callback id """ - event_class = self.EVENT_CLASSES.get(event_name) - if not event_class: - raise Exception(f"Event class for {event_name} not found") def _callback(event_data): # Parse the event data using the appropriate event class @@ -946,22 +950,22 @@ def add_event_handler(self, event: str, callback: Callable, contexts: Optional[l ------- int: callback id """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") + event_config = self.EVENT_CONFIGS.get(event) + if not event_config: + available = ", ".join(sorted(self.get_event_names())) + raise ValueError(f"Event '{event}' not found. Available events: {available}") - callback_id = self._on_event(event_name, callback) + callback_id = self._on_event(event_config.bidi_event, event_config.event_class, callback) - if event_name in self.subscriptions: - self.subscriptions[event_name].append(callback_id) + if event_config.bidi_event in self.subscriptions: + self.subscriptions[event_config.bidi_event].append(callback_id) else: session = Session(self.conn) if contexts is not None: - self.conn.execute(session.subscribe(event_name, browsing_contexts=contexts)) + self.conn.execute(session.subscribe(event_config.bidi_event, browsing_contexts=contexts)) else: - self.conn.execute(session.subscribe(event_name)) - self.subscriptions[event_name] = [callback_id] + self.conn.execute(session.subscribe(event_config.bidi_event)) + self.subscriptions[event_config.bidi_event] = [callback_id] return callback_id @@ -973,39 +977,41 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: event: The event to unsubscribe from. callback_id: The callback id to remove. """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") + event_config = self.EVENT_CONFIGS.get(event) + if not event_config: + available = ", ".join(sorted(self.get_event_names())) + raise ValueError(f"Event '{event}' not found. Available events: {available}") - event_class = self.EVENT_CLASSES.get(event_name) - if not event_class: - raise Exception(f"Event class for {event_name} not found") - - self.conn.remove_callback(event_class, callback_id) + self.conn.remove_callback(event_config.event_class, callback_id) # Remove from callbacks tracking - if event_name in self.callbacks: - callbacks = self.callbacks[event_name] + if event_config.bidi_event in self.callbacks: + callbacks = self.callbacks[event_config.bidi_event] if callback_id in callbacks: callbacks.remove(callback_id) if not callbacks: - del self.callbacks[event_name] + del self.callbacks[event_config.bidi_event] # Remove from subscriptions and unsubscribe if no more callbacks - if event_name in self.subscriptions: - subscription_callbacks = self.subscriptions[event_name] + if event_config.bidi_event in self.subscriptions: + subscription_callbacks = self.subscriptions[event_config.bidi_event] if callback_id in subscription_callbacks: subscription_callbacks.remove(callback_id) if not subscription_callbacks: session = Session(self.conn) - self.conn.execute(session.unsubscribe(event_name)) - del self.subscriptions[event_name] + self.conn.execute(session.unsubscribe(event_config.bidi_event)) + del self.subscriptions[event_config.bidi_event] def clear_event_handlers(self) -> None: """Clear all event handlers from the browsing context.""" for event_name in list(self.subscriptions.keys()): - event_class = self.EVENT_CLASSES.get(event_name) + # Find the event class for this BiDi event name + event_class = None + for config in self.EVENT_CONFIGS.values(): + if config.bidi_event == event_name: + event_class = config.event_class + break + if event_class: for callback_id in self.subscriptions[event_name]: self.conn.remove_callback(event_class, callback_id) From 43cc3f102113696ebf0c5d1d2cb0467e8d62894e Mon Sep 17 00:00:00 2001 From: Navin Chandra Date: Mon, 28 Jul 2025 12:24:44 +0530 Subject: [PATCH 3/8] remove unused `BrowsingContextEvent` class --- .../webdriver/common/bidi/browsing_context.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 25a4cf73ff00b..bd20c19a5a67e 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -374,32 +374,6 @@ def from_json(cls, json: dict) -> "HistoryUpdatedParams": ) -class BrowsingContextEvent: - """Base class for browsing context events.""" - - def __init__(self, event_class: str, **kwargs): - self.event_class = event_class - self.params = kwargs - - @classmethod - def from_json(cls, json: dict) -> "BrowsingContextEvent": - """Creates a BrowsingContextEvent instance from a dictionary. - - Parameters: - ----------- - json: A dictionary containing the event information. - - Returns: - ------- - BrowsingContextEvent: A new instance of BrowsingContextEvent. - """ - event_class = json.get("event_class") - if event_class is None or not isinstance(event_class, str): - raise ValueError("event_class is required and must be a string") - - return cls(event_class=event_class, **json) - - class ContextCreated: """Event class for browsingContext.contextCreated event.""" From 9fb9d4a1b28fd594658a2fe9e88018705ddd9497 Mon Sep 17 00:00:00 2001 From: Navin Chandra Date: Mon, 28 Jul 2025 14:59:25 +0530 Subject: [PATCH 4/8] refactor event handling to `_EventManager` class --- .../webdriver/common/bidi/browsing_context.py | 199 ++++++++++-------- 1 file changed, 111 insertions(+), 88 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index bd20c19a5a67e..30b7049333e20 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -529,6 +529,113 @@ class EventConfig: event_class: type +class _EventManager: + """Class to manage event subscriptions and callbacks for BrowsingContext.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions = {} + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def create_event_callback(self, event_class: type, user_callback: Callable) -> Callable: + """ + Create a wrapped callback that parses event data. + """ + + def _callback(event_data): + parsed_data = event_class.from_json(event_data) + user_callback(parsed_data) + + return _callback + + def subscribe_to_event(self, bidi_event: str, contexts: Optional[list[str]] = None) -> None: + """Subscribe to a BiDi event if not already subscribed. + + Parameters: + ---------- + bidi_event: The BiDi event name. + contexts: Optional browsing context IDs to subscribe to. + """ + if bidi_event not in self.subscriptions: + session = Session(self.conn) + self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts)) + self.subscriptions[bidi_event] = [] + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist. + + Parameters: + ---------- + bidi_event: The BiDi event name. + """ + callback_list = self.subscriptions.get(bidi_event) + if callback_list is not None and not callback_list: + session = Session(self.conn) + self.conn.execute(session.unsubscribe(bidi_event)) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + self.subscriptions[bidi_event].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + callback_list = self.subscriptions.get(bidi_event) + if callback_list and callback_id in callback_list: + callback_list.remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: + event_config = self.validate_event(event) + + # Create and register the wrapped callback + wrapped_callback = self.create_event_callback(event_config.event_class, callback) + callback_id = self.conn.add_callback(event_config.event_class, wrapped_callback) + + # Subscribe to the event if needed + self.subscribe_to_event(event_config.bidi_event, contexts) + + # Track the callback + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + + # Remove the callback from the connection + self.conn.remove_callback(event_config.event_class, callback_id) + + # Remove from tracking collections + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + + # Unsubscribe if no more callbacks exist + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers from the browsing context.""" + if not self.subscriptions: + return + + session = Session(self.conn) + + for bidi_event, callback_ids in list(self.subscriptions.items()): + event_class = self._bidi_to_class.get(bidi_event) + if event_class: + # Remove all callbacks for this event + for callback_id in callback_ids: + self.conn.remove_callback(event_class, callback_id) + + self.conn.execute(session.unsubscribe(bidi_event)) + + self.subscriptions.clear() + + class BrowsingContext: """BiDi implementation of the browsingContext module.""" @@ -554,8 +661,7 @@ class BrowsingContext: def __init__(self, conn): self.conn = conn - self.subscriptions = {} - self.callbacks = {} + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @classmethod def get_event_names(cls) -> list[str]: @@ -883,34 +989,6 @@ def traverse_history(self, context: str, delta: int) -> dict: result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) return result - def _on_event(self, event_name: str, event_class: type, callback: Callable) -> int: - """Set a callback function to subscribe to a browsing context event. - - Parameters: - ---------- - event_name: The BiDi event name to subscribe to. - event_class: The event class for parsing. - callback: The callback function to execute on event. - - Returns: - ------- - int: callback id - """ - - def _callback(event_data): - # Parse the event data using the appropriate event class - parsed_data = event_class.from_json(event_data) - callback(parsed_data) - - callback_id = self.conn.add_callback(event_class, _callback) - - if event_name in self.callbacks: - self.callbacks[event_name].append(callback_id) - else: - self.callbacks[event_name] = [callback_id] - - return callback_id - def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: """Add an event handler to the browsing context. @@ -924,24 +1002,7 @@ def add_event_handler(self, event: str, callback: Callable, contexts: Optional[l ------- int: callback id """ - event_config = self.EVENT_CONFIGS.get(event) - if not event_config: - available = ", ".join(sorted(self.get_event_names())) - raise ValueError(f"Event '{event}' not found. Available events: {available}") - - callback_id = self._on_event(event_config.bidi_event, event_config.event_class, callback) - - if event_config.bidi_event in self.subscriptions: - self.subscriptions[event_config.bidi_event].append(callback_id) - else: - session = Session(self.conn) - if contexts is not None: - self.conn.execute(session.subscribe(event_config.bidi_event, browsing_contexts=contexts)) - else: - self.conn.execute(session.subscribe(event_config.bidi_event)) - self.subscriptions[event_config.bidi_event] = [callback_id] - - return callback_id + return self._event_manager.add_event_handler(event, callback, contexts) def remove_event_handler(self, event: str, callback_id: int) -> None: """Remove an event handler from the browsing context. @@ -951,46 +1012,8 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: event: The event to unsubscribe from. callback_id: The callback id to remove. """ - event_config = self.EVENT_CONFIGS.get(event) - if not event_config: - available = ", ".join(sorted(self.get_event_names())) - raise ValueError(f"Event '{event}' not found. Available events: {available}") - - self.conn.remove_callback(event_config.event_class, callback_id) - - # Remove from callbacks tracking - if event_config.bidi_event in self.callbacks: - callbacks = self.callbacks[event_config.bidi_event] - if callback_id in callbacks: - callbacks.remove(callback_id) - if not callbacks: - del self.callbacks[event_config.bidi_event] - - # Remove from subscriptions and unsubscribe if no more callbacks - if event_config.bidi_event in self.subscriptions: - subscription_callbacks = self.subscriptions[event_config.bidi_event] - if callback_id in subscription_callbacks: - subscription_callbacks.remove(callback_id) - if not subscription_callbacks: - session = Session(self.conn) - self.conn.execute(session.unsubscribe(event_config.bidi_event)) - del self.subscriptions[event_config.bidi_event] + self._event_manager.remove_event_handler(event, callback_id) def clear_event_handlers(self) -> None: """Clear all event handlers from the browsing context.""" - for event_name in list(self.subscriptions.keys()): - # Find the event class for this BiDi event name - event_class = None - for config in self.EVENT_CONFIGS.values(): - if config.bidi_event == event_name: - event_class = config.event_class - break - - if event_class: - for callback_id in self.subscriptions[event_name]: - self.conn.remove_callback(event_class, callback_id) - session = Session(self.conn) - self.conn.execute(session.unsubscribe(event_name)) - - self.subscriptions = {} - self.callbacks = {} + self._event_manager.clear_event_handlers() From 0baf75dbf48a1960e55cff06463c6de7dc1063b3 Mon Sep 17 00:00:00 2001 From: Navin Chandra Date: Mon, 28 Jul 2025 22:48:39 +0530 Subject: [PATCH 5/8] fix callback --- .../webdriver/common/bidi/browsing_context.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 30b7049333e20..a90d3d47a54fd 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -535,7 +535,7 @@ class _EventManager: def __init__(self, conn, event_configs: dict[str, EventConfig]): self.conn = conn self.event_configs = event_configs - self.subscriptions = {} + self.subscriptions: dict = {} self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} self._available_events = ", ".join(sorted(event_configs.keys())) @@ -545,17 +545,6 @@ def validate_event(self, event: str) -> EventConfig: raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") return event_config - def create_event_callback(self, event_class: type, user_callback: Callable) -> Callable: - """ - Create a wrapped callback that parses event data. - """ - - def _callback(event_data): - parsed_data = event_class.from_json(event_data) - user_callback(parsed_data) - - return _callback - def subscribe_to_event(self, bidi_event: str, contexts: Optional[list[str]] = None) -> None: """Subscribe to a BiDi event if not already subscribed. @@ -593,9 +582,7 @@ def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> No def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: event_config = self.validate_event(event) - # Create and register the wrapped callback - wrapped_callback = self.create_event_callback(event_config.event_class, callback) - callback_id = self.conn.add_callback(event_config.event_class, wrapped_callback) + callback_id = self.conn.add_callback(event_config.event_class, callback) # Subscribe to the event if needed self.subscribe_to_event(event_config.bidi_event, contexts) From 086bc513bc65d8a4352c827fe58067287af06e21 Mon Sep 17 00:00:00 2001 From: Navin Chandra Date: Mon, 28 Jul 2025 22:49:13 +0530 Subject: [PATCH 6/8] add tests for events --- .../common/bidi_browsing_context_tests.py | 214 ++++++++++++++++++ 1 file changed, 214 insertions(+) diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 9a75cb3a535c4..b7d450a9fa583 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -593,6 +593,220 @@ def on_navigation_committed(info): driver.browsing_context.remove_event_handler("navigation_committed", callback_id) +def test_add_event_handler_dom_content_loaded(driver, pages): + """Test adding event handler for dom_content_loaded event.""" + events_received = [] + + def on_dom_content_loaded(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("dom_content_loaded", on_dom_content_loaded) + assert callback_id is not None + + # Navigate to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("dom_content_loaded", callback_id) + + +def test_add_event_handler_load(driver, pages): + """Test adding event handler for load event.""" + events_received = [] + + def on_load(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("load", on_load) + assert callback_id is not None + + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("load", callback_id) + + +def test_add_event_handler_navigation_started(driver, pages): + """Test adding event handler for navigation_started event.""" + events_received = [] + + def on_navigation_started(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("navigation_started", on_navigation_started) + assert callback_id is not None + + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("navigation_started", callback_id) + + +def test_add_event_handler_fragment_navigated(driver, pages): + """Test adding event handler for fragment_navigated event.""" + events_received = [] + + def on_fragment_navigated(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("fragment_navigated", on_fragment_navigated) + assert callback_id is not None + + # First navigate to a page + context_id = driver.current_window_handle + url = pages.url("linked_image.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + # Then navigate to the same page with a fragment to trigger the event + fragment_url = url + "#link" + driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("link" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("fragment_navigated", callback_id) + + +@pytest.mark.xfail_firefox +def test_add_event_handler_navigation_failed(driver): + """Test adding event handler for navigation_failed event.""" + events_received = [] + + def on_navigation_failed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("navigation_failed", on_navigation_failed) + assert callback_id is not None + + # Navigate to an invalid URL to trigger the event + context_id = driver.current_window_handle + try: + driver.browsing_context.navigate(context=context_id, url="http://invalid-domain-that-does-not-exist.test/") + except Exception: + # Expect an exception due to navigation failure + pass + + assert len(events_received) == 1 + assert events_received[0].url == "http://invalid-domain-that-does-not-exist.test/" + assert events_received[0].context == context_id + + driver.browsing_context.remove_event_handler("navigation_failed", callback_id) + + +def test_add_event_handler_user_prompt_opened(driver, pages): + """Test adding event handler for user_prompt_opened event.""" + events_received = [] + + def on_user_prompt_opened(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("user_prompt_opened", on_user_prompt_opened) + assert callback_id is not None + + # Create an alert to trigger the event + create_alert_page(driver, pages) + driver.find_element(By.ID, "alert").click() + WebDriverWait(driver, 5).until(EC.alert_is_present()) + + assert len(events_received) == 1 + assert events_received[0].type == "alert" + assert events_received[0].message == "cheese" + + # Clean up the alert + driver.browsing_context.handle_user_prompt(context=driver.current_window_handle) + driver.browsing_context.remove_event_handler("user_prompt_opened", callback_id) + + +def test_add_event_handler_user_prompt_closed(driver, pages): + """Test adding event handler for user_prompt_closed event.""" + events_received = [] + + def on_user_prompt_closed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("user_prompt_closed", on_user_prompt_closed) + assert callback_id is not None + + create_prompt_page(driver, pages) + driver.execute_script("prompt('Enter something')") + WebDriverWait(driver, 5).until(EC.alert_is_present()) + + driver.browsing_context.handle_user_prompt( + context=driver.current_window_handle, accept=True, user_text="test input" + ) + + assert len(events_received) == 1 + assert events_received[0].accepted is True + assert events_received[0].user_text == "test input" + + driver.browsing_context.remove_event_handler("user_prompt_closed", callback_id) + + +@pytest.mark.xfail_chrome +@pytest.mark.xfail_firefox +@pytest.mark.xfail_edge +def test_add_event_handler_history_updated(driver, pages): + """Test adding event handler for history_updated event.""" + events_received = [] + + def on_history_updated(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("history_updated", on_history_updated) + assert callback_id is not None + + # Navigate to a page and use history API to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + # Use history.pushState to trigger history updated event + driver.execute_script("history.pushState({}, '', '/new-path');") + + assert len(events_received) == 1 + assert any("/new-path" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("history_updated", callback_id) + + +@pytest.mark.xfail_firefox +def test_add_event_handler_download_will_begin(driver, pages): + """Test adding event handler for download_will_begin event.""" + events_received = [] + + def on_download_will_begin(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("download_will_begin", on_download_will_begin) + assert callback_id is not None + + # click on a download link to trigger the event + context_id = driver.current_window_handle + url = pages.url("downloads/download.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + download_xpath_file_1_txt = '//*[@id="file-1"]' + driver.find_element(By.XPATH, download_xpath_file_1_txt).click() + WebDriverWait(driver, 5).until(lambda d: len(events_received) > 0) + + assert len(events_received) == 1 + assert events_received[0].suggested_filename == "file_1.txt" + + driver.browsing_context.remove_event_handler("download_will_begin", callback_id) + + def test_add_event_handler_with_specific_contexts(driver): """Test adding event handler with specific browsing contexts.""" events_received = [] From 06c05a208dae282c0a2e8753cd7c0245c02ba854 Mon Sep 17 00:00:00 2001 From: Navin Chandra Date: Tue, 29 Jul 2025 13:39:56 +0530 Subject: [PATCH 7/8] add thread safety lock and tests --- .../webdriver/common/bidi/browsing_context.py | 56 ++--- .../common/bidi_browsing_context_tests.py | 201 ++++++++++++++++-- 2 files changed, 218 insertions(+), 39 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index a90d3d47a54fd..9cd2913d80dc3 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import threading from dataclasses import dataclass from typing import Any, Callable, Optional, Union @@ -538,6 +539,8 @@ def __init__(self, conn, event_configs: dict[str, EventConfig]): self.subscriptions: dict = {} self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} self._available_events = ", ".join(sorted(event_configs.keys())) + # Thread safety lock for subscription operations + self._subscription_lock = threading.Lock() def validate_event(self, event: str) -> EventConfig: event_config = self.event_configs.get(event) @@ -553,10 +556,11 @@ def subscribe_to_event(self, bidi_event: str, contexts: Optional[list[str]] = No bidi_event: The BiDi event name. contexts: Optional browsing context IDs to subscribe to. """ - if bidi_event not in self.subscriptions: - session = Session(self.conn) - self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts)) - self.subscriptions[bidi_event] = [] + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts)) + self.subscriptions[bidi_event] = [] def unsubscribe_from_event(self, bidi_event: str) -> None: """Unsubscribe from a BiDi event if no more callbacks exist. @@ -565,19 +569,22 @@ def unsubscribe_from_event(self, bidi_event: str) -> None: ---------- bidi_event: The BiDi event name. """ - callback_list = self.subscriptions.get(bidi_event) - if callback_list is not None and not callback_list: - session = Session(self.conn) - self.conn.execute(session.unsubscribe(bidi_event)) - del self.subscriptions[bidi_event] + with self._subscription_lock: + callback_list = self.subscriptions.get(bidi_event) + if callback_list is not None and not callback_list: + session = Session(self.conn) + self.conn.execute(session.unsubscribe(bidi_event)) + del self.subscriptions[bidi_event] def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - self.subscriptions[bidi_event].append(callback_id) + with self._subscription_lock: + self.subscriptions[bidi_event].append(callback_id) def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - callback_list = self.subscriptions.get(bidi_event) - if callback_list and callback_id in callback_list: - callback_list.remove(callback_id) + with self._subscription_lock: + callback_list = self.subscriptions.get(bidi_event) + if callback_list and callback_id in callback_list: + callback_list.remove(callback_id) def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: event_config = self.validate_event(event) @@ -606,21 +613,22 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: def clear_event_handlers(self) -> None: """Clear all event handlers from the browsing context.""" - if not self.subscriptions: - return + with self._subscription_lock: + if not self.subscriptions: + return - session = Session(self.conn) + session = Session(self.conn) - for bidi_event, callback_ids in list(self.subscriptions.items()): - event_class = self._bidi_to_class.get(bidi_event) - if event_class: - # Remove all callbacks for this event - for callback_id in callback_ids: - self.conn.remove_callback(event_class, callback_id) + for bidi_event, callback_ids in list(self.subscriptions.items()): + event_class = self._bidi_to_class.get(bidi_event) + if event_class: + # Remove all callbacks for this event + for callback_id in callback_ids: + self.conn.remove_callback(event_class, callback_id) - self.conn.execute(session.unsubscribe(bidi_event)) + self.conn.execute(session.unsubscribe(bidi_event)) - self.subscriptions.clear() + self.subscriptions.clear() class BrowsingContext: diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index b7d450a9fa583..a683c7b41beae 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -931,25 +931,196 @@ def on_context_created_2(info): def test_event_handler_thread_safety(driver): - """Test event handlers are thread-safe.""" + """Test thread safety with multiple non-atomic operations in callbacks.""" + import concurrent.futures + import time + events_received = [] - event_lock = threading.Lock() + context_counts = {} + event_type_counts = {} + processing_times = [] + consistency_errors = [] + thread_errors = [] + + data_lock = threading.Lock() + callback_ids = [] + registration_complete = threading.Event() + + def complex_event_callback(info): + """Callback with multiple non-atomic operations that require thread synchronization.""" + start_time = time.time() + time.sleep(0.02) # Create race condition window + + with data_lock: + # Multiple operations that could race without proper locking + initial_event_count = len(events_received) + _ = sum(context_counts.values()) if context_counts else 0 + _ = sum(event_type_counts.values()) if event_type_counts else 0 - def on_context_created(info): - with event_lock: events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + context_id = info.context + if context_id not in context_counts: + context_counts[context_id] = 0 + context_counts[context_id] += 1 + + event_type = info.__class__.__name__ + if event_type not in event_type_counts: + event_type_counts[event_type] = 0 + event_type_counts[event_type] += 1 + + processing_time = time.time() - start_time + processing_times.append(processing_time) + + # Verify data consistency + final_event_count = len(events_received) + final_context_total = sum(context_counts.values()) + final_type_total = sum(event_type_counts.values()) + final_processing_count = len(processing_times) + + expected_count = initial_event_count + 1 + if not ( + final_event_count == final_context_total == final_type_total == final_processing_count == expected_count + ): + error_msg = ( + f"Data consistency error! Events: {final_event_count}, " + f"Contexts: {final_context_total}, Types: {final_type_total}, " + f"Times: {final_processing_count}, Expected: {expected_count}" + ) + consistency_errors.append(error_msg) + + def register_handler(thread_id): + try: + callback_id = driver.browsing_context.add_event_handler("context_created", complex_event_callback) + with data_lock: + callback_ids.append(callback_id) + if len(callback_ids) == 5: + registration_complete.set() + return callback_id + except Exception as e: + with data_lock: + thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") + return None + + def remove_handler(callback_id, thread_id): + try: + driver.browsing_context.remove_event_handler("context_created", callback_id) + except Exception as e: + with data_lock: + thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") + + initial_context = driver.browsing_context.create(type=WindowTypes.TAB) + + # Concurrent registration + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = {} + for i in range(5): + future = executor.submit(register_handler, f"reg-{i}") + futures[future] = f"reg-{i}" + + for future in futures: + thread_id = futures[future] + try: + future.result(timeout=15) + except concurrent.futures.TimeoutError: + with data_lock: + thread_errors.append(f"Thread {thread_id}: Registration timed out") + except Exception as e: + with data_lock: + thread_errors.append(f"Thread {thread_id}: Registration exception: {e}") + + registration_complete.wait(timeout=5) + + with data_lock: + successful_registrations = len(callback_ids) + + # Trigger events while handlers are active + if successful_registrations > 0: + test_contexts = [] + for i in range(3): + try: + context = driver.browsing_context.create(type=WindowTypes.TAB) + test_contexts.append(context) + time.sleep(0.1) + except Exception as e: + thread_errors.append(f"Failed to create test context {i}: {e}") + + time.sleep(1.0) # Allow event processing + + for context in test_contexts: + try: + driver.browsing_context.close(context) + except Exception: + pass + + # Concurrent removal + if callback_ids: + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = {} + for i, callback_id in enumerate(callback_ids): + future = executor.submit(remove_handler, callback_id, f"rem-{i}") + futures[future] = f"rem-{i}" + + for future in futures: + thread_id = futures[future] + try: + future.result(timeout=15) + except concurrent.futures.TimeoutError: + with data_lock: + thread_errors.append(f"Thread {thread_id}: Removal timed out") + except Exception as e: + with data_lock: + thread_errors.append(f"Thread {thread_id}: Removal exception: {e}") + + time.sleep(0.5) + + # Verify handlers are removed + with data_lock: + events_before_removal_test = len(events_received) + + try: + post_removal_context = driver.browsing_context.create(type=WindowTypes.TAB) + time.sleep(0.8) + driver.browsing_context.close(post_removal_context) + except Exception as e: + thread_errors.append(f"Failed to create post-removal test context: {e}") + + with data_lock: + events_after_removal = len(events_received) - events_before_removal_test - # Create multiple contexts in rapid succession - context_ids = [] - for i in range(3): - context_id = driver.browsing_context.create(type=WindowTypes.TAB) - context_ids.append(context_id) + # Cleanup + try: + driver.browsing_context.close(initial_context) + except Exception as e: + thread_errors.append(f"Cleanup error: {e}") - # Verify all events were received (might be 1 more than 3 due to default context) - assert len(events_received) >= 3 + # Assertions + all_errors = thread_errors + consistency_errors + if all_errors: + pytest.fail("Thread safety test failed with errors:\n" + "\n".join(all_errors)) - for context_id in context_ids: - driver.browsing_context.close(context_id) - driver.browsing_context.remove_event_handler("context_created", callback_id) + assert successful_registrations > 0, f"No handlers were successfully registered (got {successful_registrations})" + assert len(events_received) > 0, "No events were received during test" + + # Verify data consistency across multiple counters + with data_lock: + total_context_events = sum(context_counts.values()) if context_counts else 0 + total_type_events = sum(event_type_counts.values()) if event_type_counts else 0 + + assert len(events_received) == total_context_events, ( + f"Context count mismatch: {len(events_received)} vs {total_context_events}" + ) + assert len(events_received) == total_type_events, ( + f"Type count mismatch: {len(events_received)} vs {total_type_events}" + ) + assert len(events_received) == len(processing_times), ( + f"Processing time count mismatch: {len(events_received)} vs {len(processing_times)}" + ) + + # Verify handlers were properly removed + assert events_after_removal == 0, f"Handlers still active after removal! Got {events_after_removal} events" + + # Verify event object + for i, event in enumerate(events_received): + assert hasattr(event, "context"), f"Event {i} missing 'context' attribute" + assert isinstance(event.context, str), f"Event {i} 'context' is not string: {type(event.context)}" From 9ea80f9c4b236b951826eb8cdd1446fe85da6e18 Mon Sep 17 00:00:00 2001 From: Navin Chandra Date: Wed, 30 Jul 2025 13:59:58 +0530 Subject: [PATCH 8/8] break thread test into 4 tests --- .../common/bidi_browsing_context_tests.py | 300 ++++++++---------- 1 file changed, 127 insertions(+), 173 deletions(-) diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index a683c7b41beae..768640d7f71a8 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -16,7 +16,9 @@ # under the License. import base64 +import concurrent.futures import threading +import time import pytest @@ -930,197 +932,149 @@ def on_context_created_2(info): driver.browsing_context.remove_event_handler("context_created", callback_id_2) -def test_event_handler_thread_safety(driver): - """Test thread safety with multiple non-atomic operations in callbacks.""" - import concurrent.futures - import time - - events_received = [] - context_counts = {} - event_type_counts = {} - processing_times = [] - consistency_errors = [] - thread_errors = [] - - data_lock = threading.Lock() - callback_ids = [] - registration_complete = threading.Event() - - def complex_event_callback(info): - """Callback with multiple non-atomic operations that require thread synchronization.""" - start_time = time.time() - time.sleep(0.02) # Create race condition window - - with data_lock: - # Multiple operations that could race without proper locking - initial_event_count = len(events_received) - _ = sum(context_counts.values()) if context_counts else 0 - _ = sum(event_type_counts.values()) if event_type_counts else 0 - - events_received.append(info) - - context_id = info.context - if context_id not in context_counts: - context_counts[context_id] = 0 - context_counts[context_id] += 1 - - event_type = info.__class__.__name__ - if event_type not in event_type_counts: - event_type_counts[event_type] = 0 - event_type_counts[event_type] += 1 - - processing_time = time.time() - start_time - processing_times.append(processing_time) - - # Verify data consistency - final_event_count = len(events_received) - final_context_total = sum(context_counts.values()) - final_type_total = sum(event_type_counts.values()) - final_processing_count = len(processing_times) - - expected_count = initial_event_count + 1 - if not ( - final_event_count == final_context_total == final_type_total == final_processing_count == expected_count - ): - error_msg = ( - f"Data consistency error! Events: {final_event_count}, " - f"Contexts: {final_context_total}, Types: {final_type_total}, " - f"Times: {final_processing_count}, Expected: {expected_count}" - ) - consistency_errors.append(error_msg) - - def register_handler(thread_id): +class _EventHandlerTestHelper: + def __init__(self, driver): + self.driver = driver + self.events_received = [] + self.context_counts = {} + self.event_type_counts = {} + self.processing_times = [] + self.consistency_errors = [] + self.thread_errors = [] + self.callback_ids = [] + self.data_lock = threading.Lock() + self.registration_complete = threading.Event() + + def make_callback(self): + def callback(info): + start_time = time.time() + time.sleep(0.02) # Simulate race window + + with self.data_lock: + initial_event_count = len(self.events_received) + + self.events_received.append(info) + + context_id = info.context + self.context_counts.setdefault(context_id, 0) + self.context_counts[context_id] += 1 + + event_type = info.__class__.__name__ + self.event_type_counts.setdefault(event_type, 0) + self.event_type_counts[event_type] += 1 + + processing_time = time.time() - start_time + self.processing_times.append(processing_time) + + final_event_count = len(self.events_received) + final_context_total = sum(self.context_counts.values()) + final_type_total = sum(self.event_type_counts.values()) + final_processing_count = len(self.processing_times) + + expected_count = initial_event_count + 1 + if not ( + final_event_count + == final_context_total + == final_type_total + == final_processing_count + == expected_count + ): + self.consistency_errors.append("Data consistency error") + + return callback + + def register_handler(self, thread_id): try: - callback_id = driver.browsing_context.add_event_handler("context_created", complex_event_callback) - with data_lock: - callback_ids.append(callback_id) - if len(callback_ids) == 5: - registration_complete.set() + callback = self.make_callback() + callback_id = self.driver.browsing_context.add_event_handler("context_created", callback) + with self.data_lock: + self.callback_ids.append(callback_id) + if len(self.callback_ids) == 5: + self.registration_complete.set() return callback_id except Exception as e: - with data_lock: - thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") + with self.data_lock: + self.thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") return None - def remove_handler(callback_id, thread_id): + def remove_handler(self, callback_id, thread_id): try: - driver.browsing_context.remove_event_handler("context_created", callback_id) + self.driver.browsing_context.remove_event_handler("context_created", callback_id) except Exception as e: - with data_lock: - thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") + with self.data_lock: + self.thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") + - initial_context = driver.browsing_context.create(type=WindowTypes.TAB) +def test_concurrent_event_handler_registration(driver): + helper = _EventHandlerTestHelper(driver) - # Concurrent registration with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = {} - for i in range(5): - future = executor.submit(register_handler, f"reg-{i}") - futures[future] = f"reg-{i}" + futures = [executor.submit(helper.register_handler, f"reg-{i}") for i in range(5)] + for future in futures: + future.result(timeout=15) + + helper.registration_complete.wait(timeout=5) + assert len(helper.callback_ids) == 5, f"Expected 5 handlers, got {len(helper.callback_ids)}" + assert not helper.thread_errors, "Errors during registration: \n" + "\n".join(helper.thread_errors) + + +def test_event_callback_data_consistency(driver): + helper = _EventHandlerTestHelper(driver) + + for i in range(5): + helper.register_handler(f"reg-{i}") + + test_contexts = [] + for _ in range(3): + context = driver.browsing_context.create(type=WindowTypes.TAB) + test_contexts.append(context) + + for ctx in test_contexts: + driver.browsing_context.close(ctx) + + with helper.data_lock: + assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors) + assert len(helper.events_received) > 0, "No events received" + assert len(helper.events_received) == sum(helper.context_counts.values()) + assert len(helper.events_received) == sum(helper.event_type_counts.values()) + assert len(helper.events_received) == len(helper.processing_times) + +def test_concurrent_event_handler_removal(driver): + helper = _EventHandlerTestHelper(driver) + + for i in range(5): + helper.register_handler(f"reg-{i}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit(helper.remove_handler, callback_id, f"rem-{i}") + for i, callback_id in enumerate(helper.callback_ids) + ] for future in futures: - thread_id = futures[future] - try: - future.result(timeout=15) - except concurrent.futures.TimeoutError: - with data_lock: - thread_errors.append(f"Thread {thread_id}: Registration timed out") - except Exception as e: - with data_lock: - thread_errors.append(f"Thread {thread_id}: Registration exception: {e}") - - registration_complete.wait(timeout=5) - - with data_lock: - successful_registrations = len(callback_ids) - - # Trigger events while handlers are active - if successful_registrations > 0: - test_contexts = [] - for i in range(3): - try: - context = driver.browsing_context.create(type=WindowTypes.TAB) - test_contexts.append(context) - time.sleep(0.1) - except Exception as e: - thread_errors.append(f"Failed to create test context {i}: {e}") - - time.sleep(1.0) # Allow event processing - - for context in test_contexts: - try: - driver.browsing_context.close(context) - except Exception: - pass - - # Concurrent removal - if callback_ids: - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = {} - for i, callback_id in enumerate(callback_ids): - future = executor.submit(remove_handler, callback_id, f"rem-{i}") - futures[future] = f"rem-{i}" - - for future in futures: - thread_id = futures[future] - try: - future.result(timeout=15) - except concurrent.futures.TimeoutError: - with data_lock: - thread_errors.append(f"Thread {thread_id}: Removal timed out") - except Exception as e: - with data_lock: - thread_errors.append(f"Thread {thread_id}: Removal exception: {e}") - - time.sleep(0.5) - - # Verify handlers are removed - with data_lock: - events_before_removal_test = len(events_received) + future.result(timeout=15) - try: - post_removal_context = driver.browsing_context.create(type=WindowTypes.TAB) - time.sleep(0.8) - driver.browsing_context.close(post_removal_context) - except Exception as e: - thread_errors.append(f"Failed to create post-removal test context: {e}") + assert not helper.thread_errors, "Errors during removal: \n" + "\n".join(helper.thread_errors) - with data_lock: - events_after_removal = len(events_received) - events_before_removal_test - # Cleanup - try: - driver.browsing_context.close(initial_context) - except Exception as e: - thread_errors.append(f"Cleanup error: {e}") +def test_no_event_after_handler_removal(driver): + helper = _EventHandlerTestHelper(driver) - # Assertions - all_errors = thread_errors + consistency_errors - if all_errors: - pytest.fail("Thread safety test failed with errors:\n" + "\n".join(all_errors)) + for i in range(5): + helper.register_handler(f"reg-{i}") - assert successful_registrations > 0, f"No handlers were successfully registered (got {successful_registrations})" - assert len(events_received) > 0, "No events were received during test" + context = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(context) - # Verify data consistency across multiple counters - with data_lock: - total_context_events = sum(context_counts.values()) if context_counts else 0 - total_type_events = sum(event_type_counts.values()) if event_type_counts else 0 + events_before = len(helper.events_received) - assert len(events_received) == total_context_events, ( - f"Context count mismatch: {len(events_received)} vs {total_context_events}" - ) - assert len(events_received) == total_type_events, ( - f"Type count mismatch: {len(events_received)} vs {total_type_events}" - ) - assert len(events_received) == len(processing_times), ( - f"Processing time count mismatch: {len(events_received)} vs {len(processing_times)}" - ) + for i, callback_id in enumerate(helper.callback_ids): + helper.remove_handler(callback_id, f"rem-{i}") + + post_context = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(post_context) - # Verify handlers were properly removed - assert events_after_removal == 0, f"Handlers still active after removal! Got {events_after_removal} events" + with helper.data_lock: + new_events = len(helper.events_received) - events_before - # Verify event object - for i, event in enumerate(events_received): - assert hasattr(event, "context"), f"Event {i} missing 'context' attribute" - assert isinstance(event.context, str), f"Event {i} 'context' is not string: {type(event.context)}" + assert new_events == 0, f"Expected 0 new events after removal, got {new_events}"