Skip to content
124 changes: 123 additions & 1 deletion py/selenium/webdriver/common/bidi/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
# specific language governing permissions and limitations
# under the License.

import datetime
import math
from dataclasses import dataclass
from typing import Any, Optional

from selenium.common.exceptions import WebDriverException
from selenium.webdriver.common.bidi.common import command_builder

from .log import LogEntryAdded
Expand Down Expand Up @@ -238,12 +241,15 @@ class Script:
"realm_destroyed": "script.realmDestroyed",
}

def __init__(self, conn):
def __init__(self, conn, driver=None):
self.conn = conn
self.driver = driver
self.log_entry_subscribed = False
self.subscriptions = {}
self.callbacks = {}

# High-level APIs for SCRIPT module

def add_console_message_handler(self, handler):
self._subscribe_to_log_entries()
return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler))
Expand All @@ -258,6 +264,122 @@ def remove_console_message_handler(self, id):

remove_javascript_error_handler = remove_console_message_handler

def pin(self, script: str) -> str:
"""Pins a script to the current browsing context.

Parameters:
-----------
script: The script to pin.

Returns:
-------
str: The ID of the pinned script.
"""
return self._add_preload_script(script)

def unpin(self, script_id: str) -> None:
"""Unpins a script from the current browsing context.

Parameters:
-----------
script_id: The ID of the pinned script to unpin.
"""
self._remove_preload_script(script_id)

def execute(self, script: str, *args) -> dict:
"""Executes a script in the current browsing context.

Parameters:
-----------
script: The script function to execute.
*args: Arguments to pass to the script function.

Returns:
-------
dict: The result value from the script execution.

Raises:
------
WebDriverException: If the script execution fails.
"""

if self.driver is None:
raise WebDriverException("Driver reference is required for script execution")
browsing_context_id = self.driver.current_window_handle

# Convert arguments to the format expected by BiDi call_function (LocalValue Type)
arguments = []
for arg in args:
arguments.append(self.__convert_to_local_value(arg))

target = {"context": browsing_context_id}

result = self._call_function(
function_declaration=script, await_promise=True, target=target, arguments=arguments if arguments else None
)

if result.type == "success":
return result.result
else:
error_message = "Error while executing script"
if result.exception_details:
if "text" in result.exception_details:
error_message += f": {result.exception_details['text']}"
elif "message" in result.exception_details:
error_message += f": {result.exception_details['message']}"

raise WebDriverException(error_message)

def __convert_to_local_value(self, value) -> dict:
"""
Converts a Python value to BiDi LocalValue format.
"""
if value is None:
return {"type": "null"}
elif isinstance(value, bool):
return {"type": "boolean", "value": value}
elif isinstance(value, (int, float)):
if isinstance(value, float):
if math.isnan(value):
return {"type": "number", "value": "NaN"}
elif math.isinf(value):
if value > 0:
return {"type": "number", "value": "Infinity"}
else:
return {"type": "number", "value": "-Infinity"}
elif value == 0.0 and math.copysign(1.0, value) < 0:
return {"type": "number", "value": "-0"}

JS_MAX_SAFE_INTEGER = 9007199254740991
if isinstance(value, int) and (value > JS_MAX_SAFE_INTEGER or value < -JS_MAX_SAFE_INTEGER):
return {"type": "bigint", "value": str(value)}

return {"type": "number", "value": value}

elif isinstance(value, str):
return {"type": "string", "value": value}
elif isinstance(value, datetime.datetime):
# Convert Python datetime to JavaScript Date (ISO 8601 format)
return {"type": "date", "value": value.isoformat() + "Z" if value.tzinfo is None else value.isoformat()}
elif isinstance(value, datetime.date):
# Convert Python date to JavaScript Date
dt = datetime.datetime.combine(value, datetime.time.min).replace(tzinfo=datetime.timezone.utc)
return {"type": "date", "value": dt.isoformat()}
elif isinstance(value, set):
return {"type": "set", "value": [self.__convert_to_local_value(item) for item in value]}
elif isinstance(value, (list, tuple)):
return {"type": "array", "value": [self.__convert_to_local_value(item) for item in value]}
elif isinstance(value, dict):
return {
"type": "object",
"value": [
[self.__convert_to_local_value(k), self.__convert_to_local_value(v)] for k, v in value.items()
],
}
else:
# For other types, convert to string
return {"type": "string", "value": str(value)}

# low-level APIs for script module
def _add_preload_script(
self,
Expand Down
2 changes: 1 addition & 1 deletion py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ def script(self):
self._start_bidi()

if not self._script:
self._script = Script(self._websocket_connection)
self._script = Script(self._websocket_connection, self)

return self._script

Expand Down
Loading