Skip to content

Commit 40077d1

Browse files
committed
add high level API for script module - pin, unpin and execute
1 parent 47af349 commit 40077d1

File tree

2 files changed

+97
-2
lines changed

2 files changed

+97
-2
lines changed

py/selenium/webdriver/common/bidi/script.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dataclasses import dataclass
1919
from typing import Any, Optional
2020

21+
from selenium.common.exceptions import WebDriverException
2122
from selenium.webdriver.common.bidi.common import command_builder
2223

2324
from .log import LogEntryAdded
@@ -238,12 +239,15 @@ class Script:
238239
"realm_destroyed": "script.realmDestroyed",
239240
}
240241

241-
def __init__(self, conn):
242+
def __init__(self, conn, driver=None):
242243
self.conn = conn
244+
self.driver = driver
243245
self.log_entry_subscribed = False
244246
self.subscriptions = {}
245247
self.callbacks = {}
246248

249+
# High-level APIs for SCRIPT module
250+
247251
def add_console_message_handler(self, handler):
248252
self._subscribe_to_log_entries()
249253
return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler))
@@ -258,6 +262,97 @@ def remove_console_message_handler(self, id):
258262

259263
remove_javascript_error_handler = remove_console_message_handler
260264

265+
def pin(self, script: str) -> str:
266+
"""Pins a script to the current browsing context.
267+
268+
Parameters:
269+
-----------
270+
script: The script to pin.
271+
272+
Returns:
273+
-------
274+
str: The ID of the pinned script.
275+
"""
276+
return self._add_preload_script(script)
277+
278+
def unpin(self, script_id: str) -> None:
279+
"""Unpins a script from the current browsing context.
280+
281+
Parameters:
282+
-----------
283+
script_id: The ID of the pinned script to unpin.
284+
"""
285+
self._remove_preload_script(script_id)
286+
287+
def execute(self, script: str, *args) -> dict:
288+
"""Executes a script in the current browsing context.
289+
290+
Parameters:
291+
-----------
292+
script: The script function to execute.
293+
*args: Arguments to pass to the script function.
294+
295+
Returns:
296+
-------
297+
dict: The result value from the script execution.
298+
299+
Raises:
300+
------
301+
WebDriverException: If the script execution fails.
302+
"""
303+
304+
if self.driver is None:
305+
raise WebDriverException("Driver reference is required for script execution")
306+
browsing_context_id = self.driver.current_window_handle
307+
308+
# Convert arguments to the format expected by BiDi call_function (LocalValue Type)
309+
arguments = []
310+
for arg in args:
311+
arguments.append(self.__convert_to_local_value(arg))
312+
313+
target = {"context": browsing_context_id}
314+
315+
result = self._call_function(
316+
function_declaration=script, await_promise=True, target=target, arguments=arguments if arguments else None
317+
)
318+
319+
if result.type == "success":
320+
return result.result
321+
else:
322+
error_message = "Error while executing script"
323+
if result.exception_details:
324+
if "text" in result.exception_details:
325+
error_message += f": {result.exception_details['text']}"
326+
elif "message" in result.exception_details:
327+
error_message += f": {result.exception_details['message']}"
328+
329+
raise WebDriverException(error_message)
330+
331+
def __convert_to_local_value(self, value) -> dict:
332+
"""
333+
Converts a Python value to BiDi LocalValue format.
334+
"""
335+
if value is None:
336+
return {"type": "undefined"}
337+
elif isinstance(value, bool):
338+
return {"type": "boolean", "value": value}
339+
elif isinstance(value, (int, float)):
340+
return {"type": "number", "value": value}
341+
elif isinstance(value, str):
342+
return {"type": "string", "value": value}
343+
elif isinstance(value, (list, tuple)):
344+
return {"type": "array", "value": [self.__convert_to_local_value(item) for item in value]}
345+
elif isinstance(value, dict):
346+
return {
347+
"type": "object",
348+
"value": [
349+
[self.__convert_to_local_value(k), self.__convert_to_local_value(v)] for k, v in value.items()
350+
],
351+
}
352+
else:
353+
# For other types, convert to string
354+
return {"type": "string", "value": str(value)}
355+
261356
# low-level APIs for script module
262357
def _add_preload_script(
263358
self,

py/selenium/webdriver/remote/webdriver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1240,7 +1240,7 @@ def script(self):
12401240
self._start_bidi()
12411241

12421242
if not self._script:
1243-
self._script = Script(self._websocket_connection)
1243+
self._script = Script(self._websocket_connection, self)
12441244

12451245
return self._script
12461246

0 commit comments

Comments
 (0)