diff --git a/py/selenium/webdriver/common/action_chains.py b/py/selenium/webdriver/common/action_chains.py index 7dcf5cae7f143..17aa1f3fba7a4 100644 --- a/py/selenium/webdriver/common/action_chains.py +++ b/py/selenium/webdriver/common/action_chains.py @@ -273,7 +273,7 @@ def pause(self, seconds: float | int) -> ActionChains: """Pause all inputs for the specified duration in seconds.""" self.w3c_actions.pointer_action.pause(seconds) - self.w3c_actions.key_action.pause(seconds) + self.w3c_actions.key_action.pause(int(seconds)) return self diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index e843752b7a27f..deb0c2b534f2c 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, Union +from typing import Any, Callable, Optional, Union from selenium.webdriver.common.bidi.common import command_builder @@ -66,12 +66,23 @@ def from_json(cls, json: dict) -> "NavigationInfo": ------- NavigationInfo: A new instance of NavigationInfo. """ - return cls( - context=json.get("context"), - navigation=json.get("navigation"), - timestamp=json.get("timestamp"), - url=json.get("url"), - ) + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + navigation = json.get("navigation") + if navigation is not None and not isinstance(navigation, str): + raise ValueError("navigation must be a string") + + timestamp = json.get("timestamp") + if timestamp is None or not isinstance(timestamp, int) or timestamp < 0: + raise ValueError("timestamp is required and must be a non-negative integer") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") + + return cls(context, navigation, timestamp, url) class BrowsingContextInfo: @@ -82,10 +93,10 @@ def __init__( context: str, url: str, children: Optional[list["BrowsingContextInfo"]], + client_window: str, + user_context: str, parent: Optional[str] = None, - user_context: Optional[str] = None, original_opener: Optional[str] = None, - client_window: Optional[str] = None, ): self.context = context self.url = url @@ -108,17 +119,49 @@ def from_json(cls, json: dict) -> "BrowsingContextInfo": BrowsingContextInfo: A new instance of BrowsingContextInfo. """ children = None - if json.get("children") is not None: - children = [BrowsingContextInfo.from_json(child) for child in json.get("children")] + raw_children = json.get("children") + if raw_children is not None: + if not isinstance(raw_children, list): + raise ValueError("children must be a list if provided") + + children = [] + for child in raw_children: + if not isinstance(child, dict): + raise ValueError(f"Each child must be a dictionary, got {type(child)}") + children.append(BrowsingContextInfo.from_json(child)) + + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") + + parent = json.get("parent") + if parent is not None and not isinstance(parent, str): + raise ValueError("parent must be a string if provided") + + user_context = json.get("userContext") + if user_context is None or not isinstance(user_context, str): + raise ValueError("userContext is required and must be a string") + + original_opener = json.get("originalOpener") + if original_opener is not None and not isinstance(original_opener, str): + raise ValueError("originalOpener must be a string if provided") + + client_window = json.get("clientWindow") + if client_window is None or not isinstance(client_window, str): + raise ValueError("clientWindow is required and must be a string") return cls( - context=json.get("context"), - url=json.get("url"), + context=context, + url=url, children=children, - parent=json.get("parent"), - user_context=json.get("userContext"), - original_opener=json.get("originalOpener"), - client_window=json.get("clientWindow"), + client_window=client_window, + user_context=user_context, + parent=parent, + original_opener=original_opener, ) @@ -148,12 +191,32 @@ def from_json(cls, json: dict) -> "DownloadWillBeginParams": ------- DownloadWillBeginParams: A new instance of DownloadWillBeginParams. """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + navigation = json.get("navigation") + if navigation is not None and not isinstance(navigation, str): + raise ValueError("navigation must be a string") + + timestamp = json.get("timestamp") + if timestamp is None or not isinstance(timestamp, int) or timestamp < 0: + raise ValueError("timestamp is required and must be a non-negative integer") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") + + suggested_filename = json.get("suggestedFilename") + if suggested_filename is None or not isinstance(suggested_filename, str): + raise ValueError("suggestedFilename is required and must be a string") + return cls( - context=json.get("context"), - navigation=json.get("navigation"), - timestamp=json.get("timestamp"), - url=json.get("url"), - suggested_filename=json.get("suggestedFilename"), + context=context, + navigation=navigation, + timestamp=timestamp, + url=url, + suggested_filename=suggested_filename, ) @@ -186,12 +249,32 @@ def from_json(cls, json: dict) -> "UserPromptOpenedParams": ------- UserPromptOpenedParams: A new instance of UserPromptOpenedParams. """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + handler = json.get("handler") + if handler is None or not isinstance(handler, str): + raise ValueError("handler is required and must be a string") + + message = json.get("message") + if message is None or not isinstance(message, str): + raise ValueError("message is required and must be a string") + + type_value = json.get("type") + if type_value is None or not isinstance(type_value, str): + raise ValueError("type is required and must be a string") + + default_value = json.get("defaultValue") + if default_value is not None and not isinstance(default_value, str): + raise ValueError("defaultValue must be a string if provided") + return cls( - context=json.get("context"), - handler=json.get("handler"), - message=json.get("message"), - type=json.get("type"), - default_value=json.get("defaultValue"), + context=context, + handler=handler, + message=message, + type=type_value, + default_value=default_value, ) @@ -222,11 +305,27 @@ def from_json(cls, json: dict) -> "UserPromptClosedParams": ------- UserPromptClosedParams: A new instance of UserPromptClosedParams. """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + accepted = json.get("accepted") + if accepted is None or not isinstance(accepted, bool): + raise ValueError("accepted is required and must be a boolean") + + type_value = json.get("type") + if type_value is None or not isinstance(type_value, str): + raise ValueError("type is required and must be a string") + + user_text = json.get("userText") + if user_text is not None and not isinstance(user_text, str): + raise ValueError("userText must be a string if provided") + return cls( - context=json.get("context"), - accepted=json.get("accepted"), - type=json.get("type"), - user_text=json.get("userText"), + context=context, + accepted=accepted, + type=type_value, + user_text=user_text, ) @@ -253,9 +352,17 @@ def from_json(cls, json: dict) -> "HistoryUpdatedParams": ------- HistoryUpdatedParams: A new instance of HistoryUpdatedParams. """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") + return cls( - context=json.get("context"), - url=json.get("url"), + context=context, + url=url, ) @@ -278,7 +385,11 @@ def from_json(cls, json: dict) -> "BrowsingContextEvent": ------- BrowsingContextEvent: A new instance of BrowsingContextEvent. """ - return cls(event_class=json.get("event_class"), **json) + event_class = json.get("event_class") + if event_class is None or not isinstance(event_class, str): + raise ValueError("event_class is required and must be a string") + + return cls(event_class=event_class, **json) class BrowsingContext: @@ -339,7 +450,7 @@ def capture_screenshot( ------- str: The Base64-encoded screenshot. """ - params = {"context": context, "origin": origin} + params: dict[str, Any] = {"context": context, "origin": origin} if format is not None: params["format"] = format if clip is not None: @@ -383,7 +494,7 @@ def create( ------- str: The browsing context ID of the created navigable. """ - params = {"type": type} + params: dict[str, Any] = {"type": type} if reference_context is not None: params["referenceContext"] = reference_context if background is not None: @@ -411,7 +522,7 @@ def get_tree( ------- List[BrowsingContextInfo]: A list of browsing context information. """ - params = {} + params: dict[str, Any] = {} if max_depth is not None: params["maxDepth"] = max_depth if root is not None: @@ -434,7 +545,7 @@ def handle_user_prompt( accept: Whether to accept the prompt. user_text: The text to enter in the prompt. """ - params = {"context": context} + params: dict[str, Any] = {"context": context} if accept is not None: params["accept"] = accept if user_text is not None: @@ -464,7 +575,7 @@ def locate_nodes( ------- List[Dict]: A list of nodes. """ - params = {"context": context, "locator": locator} + params: dict[str, Any] = {"context": context, "locator": locator} if max_node_count is not None: params["maxNodeCount"] = max_node_count if serialization_options is not None: @@ -564,7 +675,7 @@ def reload( ------- Dict: A dictionary containing the navigation result. """ - params = {"context": context} + params: dict[str, Any] = {"context": context} if ignore_cache is not None: params["ignoreCache"] = ignore_cache if wait is not None: @@ -593,7 +704,7 @@ def set_viewport( ------ Exception: If the browsing context is not a top-level traversable. """ - params = {} + params: dict[str, Any] = {} if context is not None: params["context"] = context if viewport is not None: @@ -621,7 +732,7 @@ def traverse_history(self, context: str, delta: int) -> dict: result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) return result - def _on_event(self, event_name: str, callback: callable) -> int: + def _on_event(self, event_name: str, callback: Callable) -> int: """Set a callback function to subscribe to a browsing context event. Parameters: @@ -665,7 +776,7 @@ def _callback(event_data): return callback_id - def add_event_handler(self, event: str, callback: callable, contexts: Optional[list[str]] = None) -> int: + def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: """Add an event handler to the browsing context. Parameters: @@ -710,15 +821,18 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: except KeyError: raise Exception(f"Event {event} not found") - event = BrowsingContextEvent(event_name) + event_obj = BrowsingContextEvent(event_name) - self.conn.remove_callback(event, callback_id) - self.subscriptions[event_name].remove(callback_id) - if len(self.subscriptions[event_name]) == 0: - params = {"events": [event_name]} - session = Session(self.conn) - self.conn.execute(session.unsubscribe(**params)) - del self.subscriptions[event_name] + self.conn.remove_callback(event_obj, callback_id) + if event_name in self.subscriptions: + callbacks = self.subscriptions[event_name] + if callback_id in callbacks: + callbacks.remove(callback_id) + if not callbacks: + params = {"events": [event_name]} + session = Session(self.conn) + self.conn.execute(session.unsubscribe(**params)) + del self.subscriptions[event_name] def clear_event_handlers(self) -> None: """Clear all event handlers from the browsing context.""" diff --git a/py/selenium/webdriver/common/utils.py b/py/selenium/webdriver/common/utils.py index b04e2b0e40c30..f022b617ec59f 100644 --- a/py/selenium/webdriver/common/utils.py +++ b/py/selenium/webdriver/common/utils.py @@ -64,12 +64,12 @@ def find_connectable_ip(host: Union[str, bytes, bytearray, None], port: Optional for family, _, _, _, sockaddr in addrinfos: connectable = True if port: - connectable = is_connectable(port, sockaddr[0]) + connectable = is_connectable(port, str(sockaddr[0])) if connectable and family == socket.AF_INET: - return sockaddr[0] + return str(sockaddr[0]) if connectable and not ip and family == socket.AF_INET6: - ip = sockaddr[0] + ip = str(sockaddr[0]) return ip @@ -132,7 +132,7 @@ def keys_to_typing(value: Iterable[AnyKey]) -> list[str]: for val in value: if isinstance(val, Keys): # Todo: Does this even work? - characters.append(val) + characters.append(str(val)) elif isinstance(val, (int, float)): characters.extend(str(val)) else: