diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 27656be2a1c6e..9cd2913d80dc3 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import threading +from dataclasses import dataclass from typing import Any, Callable, Optional, Union from selenium.webdriver.common.bidi.common import command_builder @@ -373,55 +375,298 @@ def from_json(cls, json: dict) -> "HistoryUpdatedParams": ) -class BrowsingContextEvent: - """Base class for browsing context events.""" +class ContextCreated: + """Event class for browsingContext.contextCreated event.""" - def __init__(self, event_class: str, **kwargs): - self.event_class = event_class - self.params = kwargs + event_class = "browsingContext.contextCreated" @classmethod - def from_json(cls, json: dict) -> "BrowsingContextEvent": - """Creates a BrowsingContextEvent instance from a dictionary. + def from_json(cls, json: dict): + if isinstance(json, BrowsingContextInfo): + return json + return BrowsingContextInfo.from_json(json) + + +class ContextDestroyed: + """Event class for browsingContext.contextDestroyed event.""" + + event_class = "browsingContext.contextDestroyed" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, BrowsingContextInfo): + return json + return BrowsingContextInfo.from_json(json) + + +class NavigationStarted: + """Event class for browsingContext.navigationStarted event.""" + + event_class = "browsingContext.navigationStarted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationCommitted: + """Event class for browsingContext.navigationCommitted event.""" + + event_class = "browsingContext.navigationCommitted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationFailed: + """Event class for browsingContext.navigationFailed event.""" + + event_class = "browsingContext.navigationFailed" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class NavigationAborted: + """Event class for browsingContext.navigationAborted event.""" + + event_class = "browsingContext.navigationAborted" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class DomContentLoaded: + """Event class for browsingContext.domContentLoaded event.""" + + event_class = "browsingContext.domContentLoaded" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class Load: + """Event class for browsingContext.load event.""" + + event_class = "browsingContext.load" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class FragmentNavigated: + """Event class for browsingContext.fragmentNavigated event.""" + + event_class = "browsingContext.fragmentNavigated" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, NavigationInfo): + return json + return NavigationInfo.from_json(json) + + +class DownloadWillBegin: + """Event class for browsingContext.downloadWillBegin event.""" + + event_class = "browsingContext.downloadWillBegin" + + @classmethod + def from_json(cls, json: dict): + return DownloadWillBeginParams.from_json(json) + + +class UserPromptOpened: + """Event class for browsingContext.userPromptOpened event.""" + + event_class = "browsingContext.userPromptOpened" + + @classmethod + def from_json(cls, json: dict): + return UserPromptOpenedParams.from_json(json) + + +class UserPromptClosed: + """Event class for browsingContext.userPromptClosed event.""" + + event_class = "browsingContext.userPromptClosed" + + @classmethod + def from_json(cls, json: dict): + return UserPromptClosedParams.from_json(json) + + +class HistoryUpdated: + """Event class for browsingContext.historyUpdated event.""" + + event_class = "browsingContext.historyUpdated" + + @classmethod + def from_json(cls, json: dict): + return HistoryUpdatedParams.from_json(json) + + +@dataclass +class EventConfig: + event_key: str + bidi_event: str + event_class: type + + +class _EventManager: + """Class to manage event subscriptions and callbacks for BrowsingContext.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + # Thread safety lock for subscription operations + self._subscription_lock = threading.Lock() + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: Optional[list[str]] = None) -> None: + """Subscribe to a BiDi event if not already subscribed. Parameters: - ----------- - json: A dictionary containing the event information. + ---------- + bidi_event: The BiDi event name. + contexts: Optional browsing context IDs to subscribe to. + """ + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts)) + self.subscriptions[bidi_event] = [] - Returns: - ------- - BrowsingContextEvent: A new instance of BrowsingContextEvent. + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist. + + Parameters: + ---------- + bidi_event: The BiDi event name. """ - event_class = json.get("event_class") - if event_class is None or not isinstance(event_class, str): - raise ValueError("event_class is required and must be a string") + with self._subscription_lock: + callback_list = self.subscriptions.get(bidi_event) + if callback_list is not None and not callback_list: + session = Session(self.conn) + self.conn.execute(session.unsubscribe(bidi_event)) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + callback_list = self.subscriptions.get(bidi_event) + if callback_list and callback_id in callback_list: + callback_list.remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: + event_config = self.validate_event(event) + + callback_id = self.conn.add_callback(event_config.event_class, callback) + + # Subscribe to the event if needed + self.subscribe_to_event(event_config.bidi_event, contexts) + + # Track the callback + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + + # Remove the callback from the connection + self.conn.remove_callback(event_config.event_class, callback_id) - return cls(event_class=event_class, **json) + # Remove from tracking collections + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + + # Unsubscribe if no more callbacks exist + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers from the browsing context.""" + with self._subscription_lock: + if not self.subscriptions: + return + + session = Session(self.conn) + + for bidi_event, callback_ids in list(self.subscriptions.items()): + event_class = self._bidi_to_class.get(bidi_event) + if event_class: + # Remove all callbacks for this event + for callback_id in callback_ids: + self.conn.remove_callback(event_class, callback_id) + + self.conn.execute(session.unsubscribe(bidi_event)) + + self.subscriptions.clear() class BrowsingContext: """BiDi implementation of the browsingContext module.""" - EVENTS = { - "context_created": "browsingContext.contextCreated", - "context_destroyed": "browsingContext.contextDestroyed", - "dom_content_loaded": "browsingContext.domContentLoaded", - "download_will_begin": "browsingContext.downloadWillBegin", - "fragment_navigated": "browsingContext.fragmentNavigated", - "history_updated": "browsingContext.historyUpdated", - "load": "browsingContext.load", - "navigation_aborted": "browsingContext.navigationAborted", - "navigation_committed": "browsingContext.navigationCommitted", - "navigation_failed": "browsingContext.navigationFailed", - "navigation_started": "browsingContext.navigationStarted", - "user_prompt_closed": "browsingContext.userPromptClosed", - "user_prompt_opened": "browsingContext.userPromptOpened", + EVENT_CONFIGS = { + "context_created": EventConfig("context_created", "browsingContext.contextCreated", ContextCreated), + "context_destroyed": EventConfig("context_destroyed", "browsingContext.contextDestroyed", ContextDestroyed), + "dom_content_loaded": EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", DomContentLoaded), + "download_will_begin": EventConfig( + "download_will_begin", "browsingContext.downloadWillBegin", DownloadWillBegin + ), + "fragment_navigated": EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", FragmentNavigated), + "history_updated": EventConfig("history_updated", "browsingContext.historyUpdated", HistoryUpdated), + "load": EventConfig("load", "browsingContext.load", Load), + "navigation_aborted": EventConfig("navigation_aborted", "browsingContext.navigationAborted", NavigationAborted), + "navigation_committed": EventConfig( + "navigation_committed", "browsingContext.navigationCommitted", NavigationCommitted + ), + "navigation_failed": EventConfig("navigation_failed", "browsingContext.navigationFailed", NavigationFailed), + "navigation_started": EventConfig("navigation_started", "browsingContext.navigationStarted", NavigationStarted), + "user_prompt_closed": EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", UserPromptClosed), + "user_prompt_opened": EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", UserPromptOpened), } def __init__(self, conn): self.conn = conn - self.subscriptions = {} - self.callbacks = {} + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + @classmethod + def get_event_names(cls) -> list[str]: + """Get a list of all available event names. + + Returns: + ------- + List[str]: A list of event names that can be used with event handlers. + """ + return list(cls.EVENT_CONFIGS.keys()) def activate(self, context: str) -> None: """Activates and focuses the given top-level traversable. @@ -739,50 +984,6 @@ def traverse_history(self, context: str, delta: int) -> dict: result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) return result - def _on_event(self, event_name: str, callback: Callable) -> int: - """Set a callback function to subscribe to a browsing context event. - - Parameters: - ---------- - event_name: The event to subscribe to. - callback: The callback function to execute on event. - - Returns: - ------- - int: callback id - """ - event = BrowsingContextEvent(event_name) - - def _callback(event_data): - if event_name == self.EVENTS["context_created"] or event_name == self.EVENTS["context_destroyed"]: - info = BrowsingContextInfo.from_json(event_data.params) - callback(info) - elif event_name == self.EVENTS["download_will_begin"]: - params = DownloadWillBeginParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["user_prompt_opened"]: - params = UserPromptOpenedParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["user_prompt_closed"]: - params = UserPromptClosedParams.from_json(event_data.params) - callback(params) - elif event_name == self.EVENTS["history_updated"]: - params = HistoryUpdatedParams.from_json(event_data.params) - callback(params) - else: - # For navigation events - info = NavigationInfo.from_json(event_data.params) - callback(info) - - callback_id = self.conn.add_callback(event, _callback) - - if event_name in self.callbacks: - self.callbacks[event_name].append(callback_id) - else: - self.callbacks[event_name] = [callback_id] - - return callback_id - def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: """Add an event handler to the browsing context. @@ -796,24 +997,7 @@ def add_event_handler(self, event: str, callback: Callable, contexts: Optional[l ------- int: callback id """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") - - callback_id = self._on_event(event_name, callback) - - if event_name in self.subscriptions: - self.subscriptions[event_name].append(callback_id) - else: - params = {"events": [event_name]} - if contexts is not None: - params["browsingContexts"] = contexts - session = Session(self.conn) - self.conn.execute(session.subscribe(**params)) - self.subscriptions[event_name] = [callback_id] - - return callback_id + return self._event_manager.add_event_handler(event, callback, contexts) def remove_event_handler(self, event: str, callback_id: int) -> None: """Remove an event handler from the browsing context. @@ -823,31 +1007,8 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: event: The event to unsubscribe from. callback_id: The callback id to remove. """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") - - event_obj = BrowsingContextEvent(event_name) - - self.conn.remove_callback(event_obj, callback_id) - if event_name in self.subscriptions: - callbacks = self.subscriptions[event_name] - if callback_id in callbacks: - callbacks.remove(callback_id) - if not callbacks: - params = {"events": [event_name]} - session = Session(self.conn) - self.conn.execute(session.unsubscribe(**params)) - del self.subscriptions[event_name] + self._event_manager.remove_event_handler(event, callback_id) def clear_event_handlers(self) -> None: """Clear all event handlers from the browsing context.""" - for event_name in self.subscriptions: - event = BrowsingContextEvent(event_name) - for callback_id in self.subscriptions[event_name]: - self.conn.remove_callback(event, callback_id) - params = {"events": [event_name]} - session = Session(self.conn) - self.conn.execute(session.unsubscribe(**params)) - self.subscriptions = {} + self._event_manager.clear_event_handlers() diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 74a0f53aaf640..768640d7f71a8 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -16,6 +16,9 @@ # under the License. import base64 +import concurrent.futures +import threading +import time import pytest @@ -525,3 +528,553 @@ def test_locate_nodes_given_start_nodes(driver, pages): ) # The login form should have 3 input elements (email, age, and submit button) assert len(elements) == 3 + + +# Tests for event handlers + + +def test_add_event_handler_context_created(driver): + """Test adding event handler for context_created event.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + assert callback_id is not None + + # Create a new context to trigger the event + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify the event was received (might be > 1 since default context is also included) + assert len(events_received) >= 1 + assert events_received[0].context == context_id or events_received[1].context == context_id + + driver.browsing_context.close(context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id) + + +def test_add_event_handler_context_destroyed(driver): + """Test adding event handler for context_destroyed event.""" + events_received = [] + + def on_context_destroyed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_destroyed", on_context_destroyed) + assert callback_id is not None + + # Create and then close a context to trigger the event + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(context_id) + + assert len(events_received) == 1 + assert events_received[0].context == context_id + + driver.browsing_context.remove_event_handler("context_destroyed", callback_id) + + +def test_add_event_handler_navigation_committed(driver, pages): + """Test adding event handler for navigation_committed event.""" + events_received = [] + + def on_navigation_committed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("navigation_committed", on_navigation_committed) + assert callback_id is not None + + # Navigate to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) >= 1 + assert any(url in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("navigation_committed", callback_id) + + +def test_add_event_handler_dom_content_loaded(driver, pages): + """Test adding event handler for dom_content_loaded event.""" + events_received = [] + + def on_dom_content_loaded(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("dom_content_loaded", on_dom_content_loaded) + assert callback_id is not None + + # Navigate to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("dom_content_loaded", callback_id) + + +def test_add_event_handler_load(driver, pages): + """Test adding event handler for load event.""" + events_received = [] + + def on_load(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("load", on_load) + assert callback_id is not None + + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("load", callback_id) + + +def test_add_event_handler_navigation_started(driver, pages): + """Test adding event handler for navigation_started event.""" + events_received = [] + + def on_navigation_started(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("navigation_started", on_navigation_started) + assert callback_id is not None + + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("simpleTest" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("navigation_started", callback_id) + + +def test_add_event_handler_fragment_navigated(driver, pages): + """Test adding event handler for fragment_navigated event.""" + events_received = [] + + def on_fragment_navigated(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("fragment_navigated", on_fragment_navigated) + assert callback_id is not None + + # First navigate to a page + context_id = driver.current_window_handle + url = pages.url("linked_image.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + # Then navigate to the same page with a fragment to trigger the event + fragment_url = url + "#link" + driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE) + + assert len(events_received) == 1 + assert any("link" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("fragment_navigated", callback_id) + + +@pytest.mark.xfail_firefox +def test_add_event_handler_navigation_failed(driver): + """Test adding event handler for navigation_failed event.""" + events_received = [] + + def on_navigation_failed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("navigation_failed", on_navigation_failed) + assert callback_id is not None + + # Navigate to an invalid URL to trigger the event + context_id = driver.current_window_handle + try: + driver.browsing_context.navigate(context=context_id, url="http://invalid-domain-that-does-not-exist.test/") + except Exception: + # Expect an exception due to navigation failure + pass + + assert len(events_received) == 1 + assert events_received[0].url == "http://invalid-domain-that-does-not-exist.test/" + assert events_received[0].context == context_id + + driver.browsing_context.remove_event_handler("navigation_failed", callback_id) + + +def test_add_event_handler_user_prompt_opened(driver, pages): + """Test adding event handler for user_prompt_opened event.""" + events_received = [] + + def on_user_prompt_opened(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("user_prompt_opened", on_user_prompt_opened) + assert callback_id is not None + + # Create an alert to trigger the event + create_alert_page(driver, pages) + driver.find_element(By.ID, "alert").click() + WebDriverWait(driver, 5).until(EC.alert_is_present()) + + assert len(events_received) == 1 + assert events_received[0].type == "alert" + assert events_received[0].message == "cheese" + + # Clean up the alert + driver.browsing_context.handle_user_prompt(context=driver.current_window_handle) + driver.browsing_context.remove_event_handler("user_prompt_opened", callback_id) + + +def test_add_event_handler_user_prompt_closed(driver, pages): + """Test adding event handler for user_prompt_closed event.""" + events_received = [] + + def on_user_prompt_closed(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("user_prompt_closed", on_user_prompt_closed) + assert callback_id is not None + + create_prompt_page(driver, pages) + driver.execute_script("prompt('Enter something')") + WebDriverWait(driver, 5).until(EC.alert_is_present()) + + driver.browsing_context.handle_user_prompt( + context=driver.current_window_handle, accept=True, user_text="test input" + ) + + assert len(events_received) == 1 + assert events_received[0].accepted is True + assert events_received[0].user_text == "test input" + + driver.browsing_context.remove_event_handler("user_prompt_closed", callback_id) + + +@pytest.mark.xfail_chrome +@pytest.mark.xfail_firefox +@pytest.mark.xfail_edge +def test_add_event_handler_history_updated(driver, pages): + """Test adding event handler for history_updated event.""" + events_received = [] + + def on_history_updated(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("history_updated", on_history_updated) + assert callback_id is not None + + # Navigate to a page and use history API to trigger the event + context_id = driver.current_window_handle + url = pages.url("simpleTest.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + # Use history.pushState to trigger history updated event + driver.execute_script("history.pushState({}, '', '/new-path');") + + assert len(events_received) == 1 + assert any("/new-path" in event.url for event in events_received) + + driver.browsing_context.remove_event_handler("history_updated", callback_id) + + +@pytest.mark.xfail_firefox +def test_add_event_handler_download_will_begin(driver, pages): + """Test adding event handler for download_will_begin event.""" + events_received = [] + + def on_download_will_begin(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("download_will_begin", on_download_will_begin) + assert callback_id is not None + + # click on a download link to trigger the event + context_id = driver.current_window_handle + url = pages.url("downloads/download.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + download_xpath_file_1_txt = '//*[@id="file-1"]' + driver.find_element(By.XPATH, download_xpath_file_1_txt).click() + WebDriverWait(driver, 5).until(lambda d: len(events_received) > 0) + + assert len(events_received) == 1 + assert events_received[0].suggested_filename == "file_1.txt" + + driver.browsing_context.remove_event_handler("download_will_begin", callback_id) + + +def test_add_event_handler_with_specific_contexts(driver): + """Test adding event handler with specific browsing contexts.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Add event handler for specific context + callback_id = driver.browsing_context.add_event_handler( + "context_created", on_context_created, contexts=[context_id] + ) + assert callback_id is not None + + # Create another context (should trigger event) + new_context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + assert len(events_received) >= 1 + + driver.browsing_context.close(context_id) + driver.browsing_context.close(new_context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id) + + +def test_remove_event_handler(driver): + """Test removing event handler.""" + events_received = [] + + def on_context_created(info): + events_received.append(info) + + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + + # Create a context to trigger the event + context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) + + initial_events = len(events_received) + + # Remove the event handler + driver.browsing_context.remove_event_handler("context_created", callback_id) + + # Create another context (should not trigger event after removal) + context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify no new events were received after removal + assert len(events_received) == initial_events + + driver.browsing_context.close(context_id_1) + driver.browsing_context.close(context_id_2) + + +def test_multiple_event_handlers_same_event(driver): + """Test adding multiple event handlers for the same event.""" + events_received_1 = [] + events_received_2 = [] + + def on_context_created_1(info): + events_received_1.append(info) + + def on_context_created_2(info): + events_received_2.append(info) + + # Add multiple event handlers for the same event + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + + # Create a context to trigger both handlers + context_id = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify both handlers received the event + assert len(events_received_1) >= 1 + assert len(events_received_2) >= 1 + # Check any of the events has the required context ID + assert any(event.context == context_id for event in events_received_1) + assert any(event.context == context_id for event in events_received_2) + + driver.browsing_context.close(context_id) + driver.browsing_context.remove_event_handler("context_created", callback_id_1) + driver.browsing_context.remove_event_handler("context_created", callback_id_2) + + +def test_remove_specific_event_handler_multiple_handlers(driver): + """Test removing a specific event handler when multiple handlers exist.""" + events_received_1 = [] + events_received_2 = [] + + def on_context_created_1(info): + events_received_1.append(info) + + def on_context_created_2(info): + events_received_2.append(info) + + # Add multiple event handlers + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + + # Create a context to trigger both handlers + context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify both handlers received the event + assert len(events_received_1) >= 1 + assert len(events_received_2) >= 1 + + # store the initial event counts + initial_count_1 = len(events_received_1) + initial_count_2 = len(events_received_2) + + # Remove only the first handler + driver.browsing_context.remove_event_handler("context_created", callback_id_1) + + # Create another context + context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB) + + # Verify only the second handler received the new event + assert len(events_received_1) == initial_count_1 # No new events + assert len(events_received_2) == initial_count_2 + 1 # 1 new event + + driver.browsing_context.close(context_id_1) + driver.browsing_context.close(context_id_2) + driver.browsing_context.remove_event_handler("context_created", callback_id_2) + + +class _EventHandlerTestHelper: + def __init__(self, driver): + self.driver = driver + self.events_received = [] + self.context_counts = {} + self.event_type_counts = {} + self.processing_times = [] + self.consistency_errors = [] + self.thread_errors = [] + self.callback_ids = [] + self.data_lock = threading.Lock() + self.registration_complete = threading.Event() + + def make_callback(self): + def callback(info): + start_time = time.time() + time.sleep(0.02) # Simulate race window + + with self.data_lock: + initial_event_count = len(self.events_received) + + self.events_received.append(info) + + context_id = info.context + self.context_counts.setdefault(context_id, 0) + self.context_counts[context_id] += 1 + + event_type = info.__class__.__name__ + self.event_type_counts.setdefault(event_type, 0) + self.event_type_counts[event_type] += 1 + + processing_time = time.time() - start_time + self.processing_times.append(processing_time) + + final_event_count = len(self.events_received) + final_context_total = sum(self.context_counts.values()) + final_type_total = sum(self.event_type_counts.values()) + final_processing_count = len(self.processing_times) + + expected_count = initial_event_count + 1 + if not ( + final_event_count + == final_context_total + == final_type_total + == final_processing_count + == expected_count + ): + self.consistency_errors.append("Data consistency error") + + return callback + + def register_handler(self, thread_id): + try: + callback = self.make_callback() + callback_id = self.driver.browsing_context.add_event_handler("context_created", callback) + with self.data_lock: + self.callback_ids.append(callback_id) + if len(self.callback_ids) == 5: + self.registration_complete.set() + return callback_id + except Exception as e: + with self.data_lock: + self.thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") + return None + + def remove_handler(self, callback_id, thread_id): + try: + self.driver.browsing_context.remove_event_handler("context_created", callback_id) + except Exception as e: + with self.data_lock: + self.thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") + + +def test_concurrent_event_handler_registration(driver): + helper = _EventHandlerTestHelper(driver) + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(helper.register_handler, f"reg-{i}") for i in range(5)] + for future in futures: + future.result(timeout=15) + + helper.registration_complete.wait(timeout=5) + assert len(helper.callback_ids) == 5, f"Expected 5 handlers, got {len(helper.callback_ids)}" + assert not helper.thread_errors, "Errors during registration: \n" + "\n".join(helper.thread_errors) + + +def test_event_callback_data_consistency(driver): + helper = _EventHandlerTestHelper(driver) + + for i in range(5): + helper.register_handler(f"reg-{i}") + + test_contexts = [] + for _ in range(3): + context = driver.browsing_context.create(type=WindowTypes.TAB) + test_contexts.append(context) + + for ctx in test_contexts: + driver.browsing_context.close(ctx) + + with helper.data_lock: + assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors) + assert len(helper.events_received) > 0, "No events received" + assert len(helper.events_received) == sum(helper.context_counts.values()) + assert len(helper.events_received) == sum(helper.event_type_counts.values()) + assert len(helper.events_received) == len(helper.processing_times) + + +def test_concurrent_event_handler_removal(driver): + helper = _EventHandlerTestHelper(driver) + + for i in range(5): + helper.register_handler(f"reg-{i}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit(helper.remove_handler, callback_id, f"rem-{i}") + for i, callback_id in enumerate(helper.callback_ids) + ] + for future in futures: + future.result(timeout=15) + + assert not helper.thread_errors, "Errors during removal: \n" + "\n".join(helper.thread_errors) + + +def test_no_event_after_handler_removal(driver): + helper = _EventHandlerTestHelper(driver) + + for i in range(5): + helper.register_handler(f"reg-{i}") + + context = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(context) + + events_before = len(helper.events_received) + + for i, callback_id in enumerate(helper.callback_ids): + helper.remove_handler(callback_id, f"rem-{i}") + + post_context = driver.browsing_context.create(type=WindowTypes.TAB) + driver.browsing_context.close(post_context) + + with helper.data_lock: + new_events = len(helper.events_received) - events_before + + assert new_events == 0, f"Expected 0 new events after removal, got {new_events}"