diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 80e9c640d59b2..76f5c25722dda 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -15,14 +15,234 @@ # specific language governing permissions and limitations # under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from selenium.webdriver.common.bidi.common import command_builder + from .log import LogEntryAdded from .session import Session +class ResultOwnership: + """Represents the possible result ownership types.""" + + NONE = "none" + ROOT = "root" + + +class RealmType: + """Represents the possible realm types.""" + + WINDOW = "window" + DEDICATED_WORKER = "dedicated-worker" + SHARED_WORKER = "shared-worker" + SERVICE_WORKER = "service-worker" + WORKER = "worker" + PAINT_WORKLET = "paint-worklet" + AUDIO_WORKLET = "audio-worklet" + WORKLET = "worklet" + + +@dataclass +class RealmInfo: + """Represents information about a realm.""" + + realm: str + origin: str + type: str + context: Optional[str] = None + sandbox: Optional[str] = None + + @classmethod + def from_json(cls, json: Dict[str, Any]) -> "RealmInfo": + """Creates a RealmInfo instance from a dictionary. + + Parameters: + ----------- + json: A dictionary containing the realm information. + + Returns: + ------- + RealmInfo: A new instance of RealmInfo. + """ + if "realm" not in json: + raise ValueError("Missing required field 'realm' in RealmInfo") + if "origin" not in json: + raise ValueError("Missing required field 'origin' in RealmInfo") + if "type" not in json: + raise ValueError("Missing required field 'type' in RealmInfo") + + return cls( + realm=json["realm"], + origin=json["origin"], + type=json["type"], + context=json.get("context"), + sandbox=json.get("sandbox"), + ) + + +@dataclass +class Source: + """Represents the source of a script message.""" + + realm: str + context: Optional[str] = None + + @classmethod + def from_json(cls, json: Dict[str, Any]) -> "Source": + """Creates a Source instance from a dictionary. + + Parameters: + ----------- + json: A dictionary containing the source information. + + Returns: + ------- + Source: A new instance of Source. + """ + if "realm" not in json: + raise ValueError("Missing required field 'realm' in Source") + + return cls( + realm=json["realm"], + context=json.get("context"), + ) + + +@dataclass +class EvaluateResult: + """Represents the result of script evaluation.""" + + type: str + realm: str + result: Optional[dict] = None + exception_details: Optional[dict] = None + + @classmethod + def from_json(cls, json: Dict[str, Any]) -> "EvaluateResult": + """Creates an EvaluateResult instance from a dictionary. + + Parameters: + ----------- + json: A dictionary containing the evaluation result. + + Returns: + ------- + EvaluateResult: A new instance of EvaluateResult. + """ + if "realm" not in json: + raise ValueError("Missing required field 'realm' in EvaluateResult") + if "type" not in json: + raise ValueError("Missing required field 'type' in EvaluateResult") + + return cls( + type=json["type"], + realm=json["realm"], + result=json.get("result"), + exception_details=json.get("exceptionDetails"), + ) + + +class ScriptMessage: + """Represents a script message event.""" + + event_class = "script.message" + + def __init__(self, channel: str, data: dict, source: Source): + self.channel = channel + self.data = data + self.source = source + + @classmethod + def from_json(cls, json: Dict[str, Any]) -> "ScriptMessage": + """Creates a ScriptMessage instance from a dictionary. + + Parameters: + ----------- + json: A dictionary containing the script message. + + Returns: + ------- + ScriptMessage: A new instance of ScriptMessage. + """ + if "channel" not in json: + raise ValueError("Missing required field 'channel' in ScriptMessage") + if "data" not in json: + raise ValueError("Missing required field 'data' in ScriptMessage") + if "source" not in json: + raise ValueError("Missing required field 'source' in ScriptMessage") + + return cls( + channel=json["channel"], + data=json["data"], + source=Source.from_json(json["source"]), + ) + + +class RealmCreated: + """Represents a realm created event.""" + + event_class = "script.realmCreated" + + def __init__(self, realm_info: RealmInfo): + self.realm_info = realm_info + + @classmethod + def from_json(cls, json: Dict[str, Any]) -> "RealmCreated": + """Creates a RealmCreated instance from a dictionary. + + Parameters: + ----------- + json: A dictionary containing the realm created event. + + Returns: + ------- + RealmCreated: A new instance of RealmCreated. + """ + return cls(realm_info=RealmInfo.from_json(json)) + + +class RealmDestroyed: + """Represents a realm destroyed event.""" + + event_class = "script.realmDestroyed" + + def __init__(self, realm: str): + self.realm = realm + + @classmethod + def from_json(cls, json: Dict[str, Any]) -> "RealmDestroyed": + """Creates a RealmDestroyed instance from a dictionary. + + Parameters: + ----------- + json: A dictionary containing the realm destroyed event. + + Returns: + ------- + RealmDestroyed: A new instance of RealmDestroyed. + """ + if "realm" not in json: + raise ValueError("Missing required field 'realm' in RealmDestroyed") + + return cls(realm=json["realm"]) + + class Script: + """BiDi implementation of the script module.""" + + EVENTS = { + "message": "script.message", + "realm_created": "script.realmCreated", + "realm_destroyed": "script.realmDestroyed", + } + def __init__(self, conn): self.conn = conn self.log_entry_subscribed = False + self.subscriptions = {} + self.callbacks = {} def add_console_message_handler(self, handler): self._subscribe_to_log_entries() @@ -38,6 +258,186 @@ def remove_console_message_handler(self, id): remove_javascript_error_handler = remove_console_message_handler + # low-level APIs for script module + def _add_preload_script( + self, + function_declaration: str, + arguments: Optional[List[Dict[str, Any]]] = None, + contexts: Optional[List[str]] = None, + user_contexts: Optional[List[str]] = None, + sandbox: Optional[str] = None, + ) -> str: + """Adds a preload script. + + Parameters: + ----------- + function_declaration: The function declaration to preload. + arguments: The arguments to pass to the function. + contexts: The browsing context IDs to apply the script to. + user_contexts: The user context IDs to apply the script to. + sandbox: The sandbox name to apply the script to. + + Returns: + ------- + str: The preload script ID. + + Raises: + ------ + ValueError: If both contexts and user_contexts are provided. + """ + if contexts is not None and user_contexts is not None: + raise ValueError("Cannot specify both contexts and user_contexts") + + params: Dict[str, Any] = {"functionDeclaration": function_declaration} + + if arguments is not None: + params["arguments"] = arguments + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + if sandbox is not None: + params["sandbox"] = sandbox + + result = self.conn.execute(command_builder("script.addPreloadScript", params)) + return result["script"] + + def _remove_preload_script(self, script_id: str) -> None: + """Removes a preload script. + + Parameters: + ----------- + script_id: The preload script ID to remove. + """ + params = {"script": script_id} + self.conn.execute(command_builder("script.removePreloadScript", params)) + + def _disown(self, handles: List[str], target: dict) -> None: + """Disowns the given handles. + + Parameters: + ----------- + handles: The handles to disown. + target: The target realm or context. + """ + params = { + "handles": handles, + "target": target, + } + self.conn.execute(command_builder("script.disown", params)) + + def _call_function( + self, + function_declaration: str, + await_promise: bool, + target: dict, + arguments: Optional[List[dict]] = None, + result_ownership: Optional[str] = None, + serialization_options: Optional[dict] = None, + this: Optional[dict] = None, + user_activation: bool = False, + ) -> EvaluateResult: + """Calls a provided function with given arguments in a given realm. + + Parameters: + ----------- + function_declaration: The function declaration to call. + await_promise: Whether to await promise resolution. + target: The target realm or context. + arguments: The arguments to pass to the function. + result_ownership: The result ownership type. + serialization_options: The serialization options. + this: The 'this' value for the function call. + user_activation: Whether to trigger user activation. + + Returns: + ------- + EvaluateResult: The result of the function call. + """ + params = { + "functionDeclaration": function_declaration, + "awaitPromise": await_promise, + "target": target, + "userActivation": user_activation, + } + + if arguments is not None: + params["arguments"] = arguments + if result_ownership is not None: + params["resultOwnership"] = result_ownership + if serialization_options is not None: + params["serializationOptions"] = serialization_options + if this is not None: + params["this"] = this + + result = self.conn.execute(command_builder("script.callFunction", params)) + return EvaluateResult.from_json(result) + + def _evaluate( + self, + expression: str, + target: dict, + await_promise: bool, + result_ownership: Optional[str] = None, + serialization_options: Optional[dict] = None, + user_activation: bool = False, + ) -> EvaluateResult: + """Evaluates a provided script in a given realm. + + Parameters: + ----------- + expression: The script expression to evaluate. + target: The target realm or context. + await_promise: Whether to await promise resolution. + result_ownership: The result ownership type. + serialization_options: The serialization options. + user_activation: Whether to trigger user activation. + + Returns: + ------- + EvaluateResult: The result of the script evaluation. + """ + params = { + "expression": expression, + "target": target, + "awaitPromise": await_promise, + "userActivation": user_activation, + } + + if result_ownership is not None: + params["resultOwnership"] = result_ownership + if serialization_options is not None: + params["serializationOptions"] = serialization_options + + result = self.conn.execute(command_builder("script.evaluate", params)) + return EvaluateResult.from_json(result) + + def _get_realms( + self, + context: Optional[str] = None, + type: Optional[str] = None, + ) -> List[RealmInfo]: + """Returns a list of all realms, optionally filtered. + + Parameters: + ----------- + context: The browsing context ID to filter by. + type: The realm type to filter by. + + Returns: + ------- + List[RealmInfo]: A list of realm information. + """ + params = {} + + if context is not None: + params["context"] = context + if type is not None: + params["type"] = type + + result = self.conn.execute(command_builder("script.getRealms", params)) + return [RealmInfo.from_json(realm) for realm in result["realms"]] + def _subscribe_to_log_entries(self): if not self.log_entry_subscribed: session = Session(self.conn) diff --git a/py/test/selenium/webdriver/common/bidi_script_tests.py b/py/test/selenium/webdriver/common/bidi_script_tests.py index 9030227ed196f..8677d2dbae396 100644 --- a/py/test/selenium/webdriver/common/bidi_script_tests.py +++ b/py/test/selenium/webdriver/common/bidi_script_tests.py @@ -15,11 +15,28 @@ # specific language governing permissions and limitations # under the License. +import pytest + from selenium.webdriver.common.bidi.log import LogLevel +from selenium.webdriver.common.bidi.script import RealmType, ResultOwnership from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait +def has_shadow_root(node): + if isinstance(node, dict): + shadow_root = node.get("shadowRoot") + if shadow_root and isinstance(shadow_root, dict): + return True + + children = node.get("children", []) + for child in children: + if "value" in child and has_shadow_root(child["value"]): + return True + + return False + + def test_logs_console_messages(driver, pages): pages.load("bidi/logEntryAdded.html") @@ -127,3 +144,420 @@ def test_removes_javascript_message_handler(driver, pages): WebDriverWait(driver, 5).until(lambda _: len(log_entries2) == 2) assert len(log_entries1) == 1 + + +def test_add_preload_script(driver, pages): + """Test adding a preload script.""" + function_declaration = "() => { window.preloadExecuted = true; }" + + script_id = driver.script._add_preload_script(function_declaration) + assert script_id is not None + assert isinstance(script_id, str) + + # Navigate to a page to trigger the preload script + pages.load("blank.html") + + # Check if the preload script was executed + result = driver.script._evaluate( + "window.preloadExecuted", {"context": driver.current_window_handle}, await_promise=False + ) + assert result.result["value"] is True + + +def test_add_preload_script_with_arguments(driver, pages): + """Test adding a preload script with channel arguments.""" + function_declaration = "(channelFunc) => { channelFunc('test_value'); window.preloadValue = 'received'; }" + + arguments = [{"type": "channel", "value": {"channel": "test-channel", "ownership": "root"}}] + + script_id = driver.script._add_preload_script(function_declaration, arguments=arguments) + assert script_id is not None + + pages.load("blank.html") + + result = driver.script._evaluate( + "window.preloadValue", {"context": driver.current_window_handle}, await_promise=False + ) + assert result.result["value"] == "received" + + +def test_add_preload_script_with_contexts(driver, pages): + """Test adding a preload script with specific contexts.""" + function_declaration = "() => { window.contextSpecific = true; }" + contexts = [driver.current_window_handle] + + script_id = driver.script._add_preload_script(function_declaration, contexts=contexts) + assert script_id is not None + + pages.load("blank.html") + + result = driver.script._evaluate( + "window.contextSpecific", {"context": driver.current_window_handle}, await_promise=False + ) + assert result.result["value"] is True + + +def test_add_preload_script_with_user_contexts(driver, pages): + """Test adding a preload script with user contexts.""" + function_declaration = "() => { window.contextSpecific = true; }" + user_context = driver.browser.create_user_context() + + context1 = driver.browsing_context.create(type="window", user_context=user_context) + driver.switch_to.window(context1) + + user_contexts = [user_context] + + script_id = driver.script._add_preload_script(function_declaration, user_contexts=user_contexts) + assert script_id is not None + + pages.load("blank.html") + + result = driver.script._evaluate( + "window.contextSpecific", {"context": driver.current_window_handle}, await_promise=False + ) + assert result.result["value"] is True + + +def test_add_preload_script_with_sandbox(driver, pages): + """Test adding a preload script with sandbox.""" + function_declaration = "() => { window.sandboxScript = true; }" + + script_id = driver.script._add_preload_script(function_declaration, sandbox="test-sandbox") + assert script_id is not None + + pages.load("blank.html") + + # calling evaluate without sandbox should return undefined + result = driver.script._evaluate( + "window.sandboxScript", {"context": driver.current_window_handle}, await_promise=False + ) + assert result.result["type"] == "undefined" + + # calling evaluate within the sandbox should return True + result = driver.script._evaluate( + "window.sandboxScript", + {"context": driver.current_window_handle, "sandbox": "test-sandbox"}, + await_promise=False, + ) + assert result.result["value"] is True + + +def test_add_preload_script_invalid_arguments(driver): + """Test that providing both contexts and user_contexts raises an error.""" + function_declaration = "() => {}" + + with pytest.raises(ValueError, match="Cannot specify both contexts and user_contexts"): + driver.script._add_preload_script(function_declaration, contexts=["context1"], user_contexts=["user1"]) + + +def test_remove_preload_script(driver, pages): + """Test removing a preload script.""" + function_declaration = "() => { window.removableScript = true; }" + + script_id = driver.script._add_preload_script(function_declaration) + driver.script._remove_preload_script(script_id=script_id) + + # Navigate to a page after removing the script + pages.load("blank.html") + + # The script should not have executed + result = driver.script._evaluate( + "typeof window.removableScript", {"context": driver.current_window_handle}, await_promise=False + ) + assert result.result["value"] == "undefined" + + +def test_evaluate_expression(driver, pages): + """Test evaluating a simple expression.""" + pages.load("blank.html") + + result = driver.script._evaluate("1 + 2", {"context": driver.current_window_handle}, await_promise=False) + + assert result.realm is not None + assert result.result["type"] == "number" + assert result.result["value"] == 3 + assert result.exception_details is None + + +def test_evaluate_with_await_promise(driver, pages): + """Test evaluating an expression that returns a promise.""" + pages.load("blank.html") + + result = driver.script._evaluate( + "Promise.resolve(42)", {"context": driver.current_window_handle}, await_promise=True + ) + + assert result.result["type"] == "number" + assert result.result["value"] == 42 + + +def test_evaluate_with_exception(driver, pages): + """Test evaluating an expression that throws an exception.""" + pages.load("blank.html") + + result = driver.script._evaluate( + "throw new Error('Test error')", {"context": driver.current_window_handle}, await_promise=False + ) + + assert result.exception_details is not None + assert "Test error" in str(result.exception_details) + + +def test_evaluate_with_result_ownership(driver, pages): + """Test evaluating with different result ownership settings.""" + pages.load("blank.html") + + # Test with ROOT ownership + result = driver.script._evaluate( + "({ test: 'value' })", + {"context": driver.current_window_handle}, + await_promise=False, + result_ownership=ResultOwnership.ROOT, + ) + + # ROOT result ownership should return a handle + assert "handle" in result.result + + # Test with NONE ownership + result = driver.script._evaluate( + "({ test: 'value' })", + {"context": driver.current_window_handle}, + await_promise=False, + result_ownership=ResultOwnership.NONE, + ) + + assert "handle" not in result.result + assert result.result is not None + + +def test_evaluate_with_serialization_options(driver, pages): + """Test evaluating with serialization options.""" + pages.load("shadowRootPage.html") + + serialization_options = {"maxDomDepth": 2, "maxObjectDepth": 2, "includeShadowTree": "all"} + + result = driver.script._evaluate( + "document.body", + {"context": driver.current_window_handle}, + await_promise=False, + serialization_options=serialization_options, + ) + root_node = result.result["value"] + + # maxDomDepth will contain a children property + assert "children" in result.result["value"] + # the page will have atleast one shadow root + assert has_shadow_root(root_node) + + +def test_evaluate_with_user_activation(driver, pages): + """Test evaluating with user activation.""" + pages.load("blank.html") + + result = driver.script._evaluate( + "navigator.userActivation ? navigator.userActivation.isActive : false", + {"context": driver.current_window_handle}, + await_promise=False, + user_activation=True, + ) + + # the value should be True if user activation is active + assert result.result["value"] is True + + +def test_call_function(driver, pages): + """Test calling a function.""" + pages.load("blank.html") + + result = driver.script._call_function( + "(a, b) => a + b", + await_promise=False, + target={"context": driver.current_window_handle}, + arguments=[{"type": "number", "value": 5}, {"type": "number", "value": 3}], + ) + + assert result.result["type"] == "number" + assert result.result["value"] == 8 + + +def test_call_function_with_this(driver, pages): + """Test calling a function with a specific 'this' value.""" + pages.load("blank.html") + + # First set up an object + driver.script._evaluate( + "window.testObj = { value: 10 }", {"context": driver.current_window_handle}, await_promise=False + ) + + result = driver.script._call_function( + "function() { return this.value; }", + await_promise=False, + target={"context": driver.current_window_handle}, + this={"type": "object", "value": [["value", {"type": "number", "value": 20}]]}, + ) + + assert result.result["type"] == "number" + assert result.result["value"] == 20 + + +def test_call_function_with_user_activation(driver, pages): + """Test calling a function with user activation.""" + pages.load("blank.html") + + result = driver.script._call_function( + "() => navigator.userActivation ? navigator.userActivation.isActive : false", + await_promise=False, + target={"context": driver.current_window_handle}, + user_activation=True, + ) + + # the value should be True if user activation is active + assert result.result["value"] is True + + +def test_call_function_with_serialization_options(driver, pages): + """Test calling a function with serialization options.""" + pages.load("shadowRootPage.html") + + serialization_options = {"maxDomDepth": 2, "maxObjectDepth": 2, "includeShadowTree": "all"} + + result = driver.script._call_function( + "() => document.body", + await_promise=False, + target={"context": driver.current_window_handle}, + serialization_options=serialization_options, + ) + + root_node = result.result["value"] + + # maxDomDepth will contain a children property + assert "children" in result.result["value"] + # the page will have atleast one shadow root + assert has_shadow_root(root_node) + + +def test_call_function_with_exception(driver, pages): + """Test calling a function that throws an exception.""" + pages.load("blank.html") + + result = driver.script._call_function( + "() => { throw new Error('Function error'); }", + await_promise=False, + target={"context": driver.current_window_handle}, + ) + + assert result.exception_details is not None + assert "Function error" in str(result.exception_details) + + +def test_call_function_with_await_promise(driver, pages): + """Test calling a function that returns a promise.""" + pages.load("blank.html") + + result = driver.script._call_function( + "() => Promise.resolve('async result')", await_promise=True, target={"context": driver.current_window_handle} + ) + + assert result.result["type"] == "string" + assert result.result["value"] == "async result" + + +def test_call_function_with_result_ownership(driver, pages): + """Test calling a function with different result ownership settings.""" + pages.load("blank.html") + + # Call a function that returns an object with ownership "root" + result = driver.script._call_function( + "function() { return { greet: 'Hi', number: 42 }; }", + await_promise=False, + target={"context": driver.current_window_handle}, + result_ownership="root", + ) + + # Verify that a handle is returned + assert result.result["type"] == "object" + assert "handle" in result.result + handle = result.result["handle"] + + # Use the handle in another function call + result2 = driver.script._call_function( + "function() { return this.number + 1; }", + await_promise=False, + target={"context": driver.current_window_handle}, + this={"handle": handle}, + ) + + assert result2.result["type"] == "number" + assert result2.result["value"] == 43 + + +def test_get_realms(driver, pages): + """Test getting all realms.""" + pages.load("blank.html") + + realms = driver.script._get_realms() + + assert len(realms) > 0 + assert all(hasattr(realm, "realm") for realm in realms) + assert all(hasattr(realm, "origin") for realm in realms) + assert all(hasattr(realm, "type") for realm in realms) + + +def test_get_realms_filtered_by_context(driver, pages): + """Test getting realms filtered by context.""" + pages.load("blank.html") + + realms = driver.script._get_realms(context=driver.current_window_handle) + + assert len(realms) > 0 + # All realms should be associated with the specified context + for realm in realms: + if realm.context is not None: + assert realm.context == driver.current_window_handle + + +def test_get_realms_filtered_by_type(driver, pages): + """Test getting realms filtered by type.""" + pages.load("blank.html") + + realms = driver.script._get_realms(type=RealmType.WINDOW) + + assert len(realms) > 0 + # All realms should be of the WINDOW type + for realm in realms: + assert realm.type == RealmType.WINDOW + + +def test_disown_handles(driver, pages): + """Test disowning handles.""" + pages.load("blank.html") + + # Create an object with root ownership (this will return a handle) + result = driver.script._evaluate( + "({foo: 'bar'})", target={"context": driver.current_window_handle}, await_promise=False, result_ownership="root" + ) + + handle = result.result["handle"] + assert handle is not None + + # Use the handle in a function call (this should succeed) + result_before = driver.script._call_function( + "function(obj) { return obj.foo; }", + await_promise=False, + target={"context": driver.current_window_handle}, + arguments=[{"handle": handle}], + ) + + assert result_before.result["value"] == "bar" + + # Disown the handle + driver.script._disown(handles=[handle], target={"context": driver.current_window_handle}) + + # Try using the disowned handle (this should fail) + with pytest.raises(Exception): + driver.script._call_function( + "function(obj) { return obj.foo; }", + await_promise=False, + target={"context": driver.current_window_handle}, + arguments=[{"handle": handle}], + )