Skip to content
Closed
2 changes: 1 addition & 1 deletion py/selenium/webdriver/common/action_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def pause(self, seconds: float | int) -> ActionChains:
"""Pause all inputs for the specified duration in seconds."""

self.w3c_actions.pointer_action.pause(seconds)
self.w3c_actions.key_action.pause(seconds)
self.w3c_actions.key_action.pause(int(seconds))

return self

Expand Down
186 changes: 141 additions & 45 deletions py/selenium/webdriver/common/bidi/browsing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.

from typing import Optional, Union
import warnings
from typing import Any, Callable, Optional, Union

from selenium.webdriver.common.bidi.common import command_builder

Expand Down Expand Up @@ -66,12 +67,23 @@ def from_json(cls, json: dict) -> "NavigationInfo":
-------
NavigationInfo: A new instance of NavigationInfo.
"""
return cls(
context=json.get("context"),
navigation=json.get("navigation"),
timestamp=json.get("timestamp"),
url=json.get("url"),
)
context = json.get("context")
if context is None or not isinstance(context, str):
raise ValueError("context is required and must be a string")

navigation = json.get("navigation")
if navigation is not None and not isinstance(navigation, str):
raise ValueError("navigation must be a string")

timestamp = json.get("timestamp")
if timestamp is None or not isinstance(timestamp, int):
raise ValueError("timestamp is required and must be an integer")

url = json.get("url")
if url is None or not isinstance(url, str):
raise ValueError("url is required and must be a string")

return cls(context, navigation, timestamp, url)


class BrowsingContextInfo:
Expand Down Expand Up @@ -108,12 +120,25 @@ def from_json(cls, json: dict) -> "BrowsingContextInfo":
BrowsingContextInfo: A new instance of BrowsingContextInfo.
"""
children = None
if json.get("children") is not None:
children = [BrowsingContextInfo.from_json(child) for child in json.get("children")]
raw_children = json.get("children")
if raw_children is not None and isinstance(raw_children, list):
children = []
for child in raw_children:
if isinstance(child, dict):
children.append(BrowsingContextInfo.from_json(child))
else:
warnings.warn(f"Unexpected child type in browsing context: {type(child)}")
context = json.get("context")
if context is None or not isinstance(context, str):
raise ValueError("context is required and must be a string")

url = json.get("url")
if url is None or not isinstance(url, str):
raise ValueError("url is required and must be a string")

return cls(
context=json.get("context"),
url=json.get("url"),
context=context,
url=url,
children=children,
parent=json.get("parent"),
user_context=json.get("userContext"),
Expand Down Expand Up @@ -148,12 +173,32 @@ def from_json(cls, json: dict) -> "DownloadWillBeginParams":
-------
DownloadWillBeginParams: A new instance of DownloadWillBeginParams.
"""
context = json.get("context")
if context is None or not isinstance(context, str):
raise ValueError("context is required and must be a string")

navigation = json.get("navigation")
if navigation is not None and not isinstance(navigation, str):
raise ValueError("navigation must be a string")

timestamp = json.get("timestamp")
if timestamp is None or not isinstance(timestamp, int):
raise ValueError("timestamp is required and must be an integer")

url = json.get("url")
if url is None or not isinstance(url, str):
raise ValueError("url is required and must be a string")

suggested_filename = json.get("suggestedFilename")
if suggested_filename is None or not isinstance(suggested_filename, str):
raise ValueError("suggestedFilename is required and must be a string")

return cls(
context=json.get("context"),
navigation=json.get("navigation"),
timestamp=json.get("timestamp"),
url=json.get("url"),
suggested_filename=json.get("suggestedFilename"),
context=context,
navigation=navigation,
timestamp=timestamp,
url=url,
suggested_filename=suggested_filename,
)


Expand Down Expand Up @@ -186,12 +231,32 @@ def from_json(cls, json: dict) -> "UserPromptOpenedParams":
-------
UserPromptOpenedParams: A new instance of UserPromptOpenedParams.
"""
context = json.get("context")
if context is None or not isinstance(context, str):
raise ValueError("context is required and must be a string")

handler = json.get("handler")
if handler is None or not isinstance(handler, str):
raise ValueError("handler is required and must be a string")

message = json.get("message")
if message is None or not isinstance(message, str):
raise ValueError("message is required and must be a string")

type_value = json.get("type")
if type_value is None or not isinstance(type_value, str):
raise ValueError("type is required and must be a string")

default_value = json.get("defaultValue")
if default_value is not None and not isinstance(default_value, str):
raise ValueError("defaultValue must be a string if provided")

return cls(
context=json.get("context"),
handler=json.get("handler"),
message=json.get("message"),
type=json.get("type"),
default_value=json.get("defaultValue"),
context=context,
handler=handler,
message=message,
type=type_value,
default_value=default_value,
)


Expand Down Expand Up @@ -222,11 +287,27 @@ def from_json(cls, json: dict) -> "UserPromptClosedParams":
-------
UserPromptClosedParams: A new instance of UserPromptClosedParams.
"""
context = json.get("context")
if context is None or not isinstance(context, str):
raise ValueError("context is required and must be a string")

accepted = json.get("accepted")
if accepted is None or not isinstance(accepted, bool):
raise ValueError("accepted is required and must be a boolean")

type_value = json.get("type")
if type_value is None or not isinstance(type_value, str):
raise ValueError("type is required and must be a string")

user_text = json.get("userText")
if user_text is not None and not isinstance(user_text, str):
raise ValueError("userText must be a string if provided")

return cls(
context=json.get("context"),
accepted=json.get("accepted"),
type=json.get("type"),
user_text=json.get("userText"),
context=context,
accepted=accepted,
type=type_value,
user_text=user_text,
)


Expand All @@ -253,9 +334,17 @@ def from_json(cls, json: dict) -> "HistoryUpdatedParams":
-------
HistoryUpdatedParams: A new instance of HistoryUpdatedParams.
"""
context = json.get("context")
if context is None or not isinstance(context, str):
raise ValueError("context is required and must be a string")

url = json.get("url")
if url is None or not isinstance(url, str):
raise ValueError("url is required and must be a string")

return cls(
context=json.get("context"),
url=json.get("url"),
context=context,
url=url,
)


Expand All @@ -278,7 +367,11 @@ def from_json(cls, json: dict) -> "BrowsingContextEvent":
-------
BrowsingContextEvent: A new instance of BrowsingContextEvent.
"""
return cls(event_class=json.get("event_class"), **json)
event_class = json.get("event_class")
if event_class is None or not isinstance(event_class, str):
raise ValueError("event_class is required and must be a string")

return cls(event_class=event_class, **json)


class BrowsingContext:
Expand Down Expand Up @@ -339,7 +432,7 @@ def capture_screenshot(
-------
str: The Base64-encoded screenshot.
"""
params = {"context": context, "origin": origin}
params: dict[str, Any] = {"context": context, "origin": origin}
if format is not None:
params["format"] = format
if clip is not None:
Expand Down Expand Up @@ -383,7 +476,7 @@ def create(
-------
str: The browsing context ID of the created navigable.
"""
params = {"type": type}
params: dict[str, Any] = {"type": type}
if reference_context is not None:
params["referenceContext"] = reference_context
if background is not None:
Expand Down Expand Up @@ -411,7 +504,7 @@ def get_tree(
-------
List[BrowsingContextInfo]: A list of browsing context information.
"""
params = {}
params: dict[str, Any] = {}
if max_depth is not None:
params["maxDepth"] = max_depth
if root is not None:
Expand All @@ -434,7 +527,7 @@ def handle_user_prompt(
accept: Whether to accept the prompt.
user_text: The text to enter in the prompt.
"""
params = {"context": context}
params: dict[str, Any] = {"context": context}
if accept is not None:
params["accept"] = accept
if user_text is not None:
Expand Down Expand Up @@ -464,7 +557,7 @@ def locate_nodes(
-------
List[Dict]: A list of nodes.
"""
params = {"context": context, "locator": locator}
params: dict[str, Any] = {"context": context, "locator": locator}
if max_node_count is not None:
params["maxNodeCount"] = max_node_count
if serialization_options is not None:
Expand Down Expand Up @@ -564,7 +657,7 @@ def reload(
-------
Dict: A dictionary containing the navigation result.
"""
params = {"context": context}
params: dict[str, Any] = {"context": context}
if ignore_cache is not None:
params["ignoreCache"] = ignore_cache
if wait is not None:
Expand Down Expand Up @@ -593,7 +686,7 @@ def set_viewport(
------
Exception: If the browsing context is not a top-level traversable.
"""
params = {}
params: dict[str, Any] = {}
if context is not None:
params["context"] = context
if viewport is not None:
Expand Down Expand Up @@ -621,7 +714,7 @@ def traverse_history(self, context: str, delta: int) -> dict:
result = self.conn.execute(command_builder("browsingContext.traverseHistory", params))
return result

def _on_event(self, event_name: str, callback: callable) -> int:
def _on_event(self, event_name: str, callback: Callable) -> int:
"""Set a callback function to subscribe to a browsing context event.

Parameters:
Expand Down Expand Up @@ -665,7 +758,7 @@ def _callback(event_data):

return callback_id

def add_event_handler(self, event: str, callback: callable, contexts: Optional[list[str]] = None) -> int:
def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int:
"""Add an event handler to the browsing context.

Parameters:
Expand Down Expand Up @@ -710,15 +803,18 @@ def remove_event_handler(self, event: str, callback_id: int) -> None:
except KeyError:
raise Exception(f"Event {event} not found")

event = BrowsingContextEvent(event_name)
event_obj = BrowsingContextEvent(event_name)

self.conn.remove_callback(event, callback_id)
self.subscriptions[event_name].remove(callback_id)
if len(self.subscriptions[event_name]) == 0:
params = {"events": [event_name]}
session = Session(self.conn)
self.conn.execute(session.unsubscribe(**params))
del self.subscriptions[event_name]
self.conn.remove_callback(event_obj, callback_id)
if event_name in self.subscriptions:
callbacks = self.subscriptions[event_name]
if callback_id in callbacks:
callbacks.remove(callback_id)
if not callbacks:
params = {"events": [event_name]}
session = Session(self.conn)
self.conn.execute(session.unsubscribe(**params))
del self.subscriptions[event_name]

def clear_event_handlers(self) -> None:
"""Clear all event handlers from the browsing context."""
Expand Down
9 changes: 4 additions & 5 deletions py/selenium/webdriver/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def find_connectable_ip(host: Union[str, bytes, bytearray, None], port: Optional
for family, _, _, _, sockaddr in addrinfos:
connectable = True
if port:
connectable = is_connectable(port, sockaddr[0])
connectable = is_connectable(port, str(sockaddr[0]))

if connectable and family == socket.AF_INET:
return sockaddr[0]
return str(sockaddr[0])
if connectable and not ip and family == socket.AF_INET6:
ip = sockaddr[0]
ip = str(sockaddr[0])
return ip


Expand Down Expand Up @@ -131,8 +131,7 @@ def keys_to_typing(value: Iterable[AnyKey]) -> list[str]:
characters: list[str] = []
for val in value:
if isinstance(val, Keys):
# Todo: Does this even work?
characters.append(val)
characters.append(str(val))
elif isinstance(val, (int, float)):
characters.extend(str(val))
else:
Expand Down
9 changes: 6 additions & 3 deletions py/selenium/webdriver/remote/shadowroot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
# under the License.

from hashlib import md5 as md5_hash
from typing import Union, TYPE_CHECKING
if TYPE_CHECKING:
from selenium.webdriver.support.relative_locator import RelativeBy

from ..common.by import By
from ..common.by import By, ByType
from .command import Command


Expand All @@ -43,7 +46,7 @@ def __repr__(self) -> str:
def id(self) -> str:
return self._id

def find_element(self, by: str = By.ID, value: str = None):
def find_element(self, by: "Union[ByType, RelativeBy]" = By.ID, value: str = None):
"""Find an element inside a shadow root given a By strategy and
locator.

Expand Down Expand Up @@ -82,7 +85,7 @@ def find_element(self, by: str = By.ID, value: str = None):

return self._execute(Command.FIND_ELEMENT_FROM_SHADOW_ROOT, {"using": by, "value": value})["value"]

def find_elements(self, by: str = By.ID, value: str = None):
def find_elements(self, by: "Union[ByType, RelativeBy]" = By.ID, value: str = None):
"""Find elements inside a shadow root given a By strategy and locator.

Parameters:
Expand Down
Loading