diff --git a/py/selenium/webdriver/common/options.py b/py/selenium/webdriver/common/options.py index c43e0fb0a912e..d45804858f8d6 100644 --- a/py/selenium/webdriver/common/options.py +++ b/py/selenium/webdriver/common/options.py @@ -19,10 +19,11 @@ from abc import ABCMeta from abc import abstractmethod from enum import Enum + +from typing import Any from typing import List from typing import Optional -from selenium.common.exceptions import InvalidArgumentException from selenium.webdriver.common.proxy import Proxy @@ -43,23 +44,28 @@ class PageLoadStrategy(str, Enum): class _BaseOptionsDescriptor: - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __get__(self, obj, cls): + def __get__(self, obj: object, cls: type[object]): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") if self.name == "enableBidi": # whether BiDi is or will be enabled - value = obj._caps.get("webSocketUrl") + value = obj.capabilities.get("webSocketUrl") return value is True or isinstance(value, str) if self.name == "webSocketUrl": # Return socket url or None if not created yet - value = obj._caps.get(self.name) + value = obj.capabilities.get(self.name) return None if not isinstance(value, str) else value if self.name in ("acceptInsecureCerts", "strictFileInteractability", "setWindowRect", "se:downloadsEnabled"): - return obj._caps.get(self.name, False) - return obj._caps.get(self.name) + return obj.capabilities.get(self.name, False) + return obj.capabilities.get(self.name) + + def __set__(self, obj: object, value: Any): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") - def __set__(self, obj, value): if self.name == "enableBidi": obj.set_capability("webSocketUrl", value) else: @@ -73,17 +79,20 @@ class _PageLoadStrategyDescriptor: :param strategy: the strategy corresponding to a document readiness state """ - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __get__(self, obj, cls): - return obj._caps.get(self.name) + def __get__(self, obj: object, cls: type[object]): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") + return obj.capabilities.get(self.name) - def __set__(self, obj, value): - if value in ("normal", "eager", "none"): - obj.set_capability(self.name, value) - else: + def __set__(self, obj: object, value: str): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") + if value not in ("normal", "eager", "none"): raise ValueError("Strategy can only be one of the following: normal, eager, none") + obj.set_capability(self.name, value) class _UnHandledPromptBehaviorDescriptor: @@ -96,20 +105,23 @@ class _UnHandledPromptBehaviorDescriptor: :returns: Values for implicit timeout, pageLoad timeout and script timeout if set (in milliseconds) """ - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __get__(self, obj, cls): - return obj._caps.get(self.name) + def __get__(self, obj: object, cls: type[object]): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") + return obj.capabilities.get(self.name) - def __set__(self, obj, value): - if value in ("dismiss", "accept", "dismiss and notify", "accept and notify", "ignore"): - obj.set_capability(self.name, value) - else: + def __set__(self, obj: object, value: str): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") + if value not in ("dismiss", "accept", "dismiss and notify", "accept and notify", "ignore"): raise ValueError( "Behavior can only be one of the following: dismiss, accept, dismiss and notify, " "accept and notify, ignore" ) + obj.set_capability(self.name, value) class _TimeoutsDescriptor: @@ -121,13 +133,17 @@ class _TimeoutsDescriptor: :returns: Values for implicit timeout, pageLoad timeout and script timeout if set (in milliseconds) """ - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __get__(self, obj, cls): - return obj._caps.get(self.name) + def __get__(self, obj: object, cls: type[object]): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") + return obj.capabilities.get(self.name) - def __set__(self, obj, value): + def __set__(self, obj: object, value: dict[str, Any]): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") if all(x in ("implicit", "pageLoad", "script") for x in value.keys()): obj.set_capability(self.name, value) else: @@ -137,17 +153,19 @@ def __set__(self, obj, value): class _ProxyDescriptor: """:Returns: Proxy if set, otherwise None.""" - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __get__(self, obj, cls): + def __get__(self, obj: object, cls: type[object]): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") return obj._proxy - def __set__(self, obj, value): - if not isinstance(value, Proxy): - raise InvalidArgumentException("Only Proxy objects can be passed in.") + def __set__(self, obj: object, value: Proxy): + if not isinstance(obj, BaseOptions): + raise ValueError("Invalid object: Expected an instance of BaseOptions.") obj._proxy = value - obj._caps[self.name] = value.to_capabilities() + obj.capabilities[self.name] = value.to_capabilities() class BaseOptions(metaclass=ABCMeta): @@ -422,7 +440,7 @@ class BaseOptions(metaclass=ABCMeta): def __init__(self) -> None: super().__init__() self._caps = self.default_capabilities - self._proxy = None + self._proxy: Proxy | None = None self.set_capability("pageLoadStrategy", PageLoadStrategy.normal) self.mobile_options = None self._ignore_local_proxy = False @@ -431,7 +449,7 @@ def __init__(self) -> None: def capabilities(self): return self._caps - def set_capability(self, name, value) -> None: + def set_capability(self, name: Any, value: Any) -> None: """Sets a capability.""" self._caps[name] = value @@ -455,12 +473,12 @@ def enable_mobile( self.mobile_options["androidDeviceSerial"] = device_serial @abstractmethod - def to_capabilities(self): + def to_capabilities(self) -> dict[Any, Any]: """Convert options into capabilities dictionary.""" @property @abstractmethod - def default_capabilities(self): + def default_capabilities(self) -> dict[Any, Any]: """Return minimal capabilities necessary as a dictionary.""" def ignore_local_proxy_environment_variables(self) -> None: @@ -476,8 +494,9 @@ class ArgOptions(BaseOptions): def __init__(self) -> None: super().__init__() - self._arguments: List[str] = [] + self._arguments: List[str] = [] + @property def arguments(self): """:Returns: A list of arguments needed for the browser.""" @@ -512,5 +531,5 @@ def to_capabilities(self): return self._caps @property - def default_capabilities(self): + def default_capabilities(self) -> dict[Any, Any]: return {} diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index 4997b3b7c0bf6..704af91c151cd 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -20,7 +20,9 @@ import string import warnings from base64 import b64encode +from typing import Any from typing import Optional +from typing import TypeVar from urllib import parse from urllib.parse import urlparse @@ -35,6 +37,11 @@ LOGGER = logging.getLogger(__name__) +# TODO: Replace with 'Self' when Python 3.11+ is supported. +# from typing import Self + +RemoteConnectionType = TypeVar("RemoteConnectionType", bound="RemoteConnection") + remote_commands = { Command.NEW_SESSION: ("POST", "/session"), Command.QUIT: ("DELETE", "/session/$sessionId"), @@ -154,6 +161,7 @@ class RemoteConnection: _timeout = socket.getdefaulttimeout() _ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where() + _client_config: Optional[ClientConfig] = None system = platform.system().lower() @@ -169,7 +177,7 @@ def client_config(self): return self._client_config @classmethod - def get_timeout(cls): + def get_timeout(cls) -> float | int | None: """:Returns: Timeout value in seconds for all http requests made to the @@ -183,7 +191,7 @@ def get_timeout(cls): return cls._client_config.timeout @classmethod - def set_timeout(cls, timeout): + def set_timeout(cls, timeout: int | float): """Override the default timeout. :Args: @@ -207,7 +215,7 @@ def reset_timeout(cls): cls._client_config.reset_timeout() @classmethod - def get_certificate_bundle_path(cls): + def get_certificate_bundle_path(cls) -> str: """:Returns: Paths of the .pem encoded certificate to verify connection to @@ -222,7 +230,7 @@ def get_certificate_bundle_path(cls): return cls._client_config.ca_certs @classmethod - def set_certificate_bundle_path(cls, path): + def set_certificate_bundle_path(cls, path: str): """Set the path to the certificate bundle to verify connection to command executor. Can also be set to None to disable certificate validation. @@ -238,7 +246,7 @@ def set_certificate_bundle_path(cls, path): cls._client_config.ca_certs = path @classmethod - def get_remote_connection_headers(cls, parsed_url, keep_alive=False): + def get_remote_connection_headers(cls, parsed_url: str, keep_alive: bool = False) -> dict[str, Any]: """Get headers for remote request. :Args: @@ -309,7 +317,7 @@ def __init__( keep_alive: Optional[bool] = True, ignore_proxy: Optional[bool] = False, ignore_certificates: Optional[bool] = False, - init_args_for_pool_manager: Optional[dict] = None, + init_args_for_pool_manager: Optional[dict[Any, Any]] = None, client_config: Optional[ClientConfig] = None, ): self._client_config = client_config or ClientConfig( @@ -370,7 +378,7 @@ def __init__( extra_commands = {} - def add_command(self, name, method, url): + def add_command(self, name: str, method: str, url: str): """Register a new command.""" self._commands[name] = (method, url) @@ -378,7 +386,7 @@ def get_command(self, name: str): """Retrieve a command if it exists.""" return self._commands.get(name) - def execute(self, command, params): + def execute(self, command: str, params: dict[Any, Any]) -> dict[str, Any]: """Send a command to the remote server. Any path substitutions required for the URL mapped to the command should be @@ -403,7 +411,7 @@ def execute(self, command, params): LOGGER.debug("%s %s %s", command_info[0], url, str(trimmed)) return self._request(command_info[0], url, body=data) - def _request(self, method, url, body=None): + def _request(self, method: str, url: str, body: str | None = None) -> dict[Any, Any]: """Send an HTTP request to the remote server. :Args: @@ -470,7 +478,7 @@ def close(self): if hasattr(self, "_conn"): self._conn.clear() - def _trim_large_entries(self, input_dict, max_length=100): + def _trim_large_entries(self, input_dict: dict[Any, Any], max_length: int = 100) -> dict[str, str]: """Truncate string values in a dictionary if they exceed max_length. :param dict: Dictionary with potentially large values diff --git a/py/selenium/webdriver/remote/script_key.py b/py/selenium/webdriver/remote/script_key.py index 930b699c7d79b..4633b06cba0b6 100644 --- a/py/selenium/webdriver/remote/script_key.py +++ b/py/selenium/webdriver/remote/script_key.py @@ -19,14 +19,14 @@ class ScriptKey: - def __init__(self, id=None): + def __init__(self, id: uuid.UUID | str | None = None): self._id = id or uuid.uuid4() @property def id(self): return self._id - def __eq__(self, other): + def __eq__(self, other: object): return self._id == other def __repr__(self) -> str: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 336b05ea0107f..c4ef086d8abc4 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -17,6 +17,8 @@ """The WebDriver implementation.""" +from __future__ import annotations + import base64 import contextlib import copy @@ -37,26 +39,25 @@ from typing import List from typing import Optional from typing import Type -from typing import Union +from typing import TypeVar +from uuid import UUID from selenium.common.exceptions import InvalidArgumentException from selenium.common.exceptions import JavascriptException from selenium.common.exceptions import NoSuchCookieException from selenium.common.exceptions import NoSuchElementException from selenium.common.exceptions import WebDriverException -from selenium.webdriver.common.bidi.browser import Browser -from selenium.webdriver.common.bidi.browsing_context import BrowsingContext -from selenium.webdriver.common.bidi.network import Network from selenium.webdriver.common.bidi.script import Script -from selenium.webdriver.common.bidi.session import Session -from selenium.webdriver.common.bidi.storage import Storage from selenium.webdriver.common.by import By +from selenium.webdriver.common.by import ByType from selenium.webdriver.common.options import ArgOptions from selenium.webdriver.common.options import BaseOptions from selenium.webdriver.common.print_page_options import PrintOptions from selenium.webdriver.common.timeouts import Timeouts from selenium.webdriver.common.virtual_authenticator import Credential -from selenium.webdriver.common.virtual_authenticator import VirtualAuthenticatorOptions +from selenium.webdriver.common.virtual_authenticator import ( + VirtualAuthenticatorOptions, +) from selenium.webdriver.common.virtual_authenticator import ( required_virtual_authenticator, ) @@ -81,6 +82,7 @@ cdp = None devtools = None +_TValue = TypeVar("_TValue") def import_cdp(): @@ -89,7 +91,7 @@ def import_cdp(): cdp = import_module("selenium.webdriver.common.bidi.cdp") -def _create_caps(caps): +def _create_caps(caps: dict[Any, Any]) -> dict[Any, Any]: """Makes a W3C alwaysMatch capabilities object. Filters out capability names that are not in the W3C spec. Spec-compliant @@ -111,23 +113,43 @@ def _create_caps(caps): def get_remote_connection( - capabilities: dict, - command_executor: Union[str, RemoteConnection], + capabilities: dict[Any, Any], + command_executor: str | RemoteConnection, keep_alive: bool, ignore_local_proxy: bool, client_config: Optional[ClientConfig] = None, ) -> RemoteConnection: if isinstance(command_executor, str): - client_config = client_config or ClientConfig(remote_server_addr=command_executor) + client_config = client_config or ClientConfig( + remote_server_addr=command_executor + ) client_config.remote_server_addr = command_executor command_executor = RemoteConnection(client_config=client_config) - from selenium.webdriver.chrome.remote_connection import ChromeRemoteConnection + from selenium.webdriver.chrome.remote_connection import ( + ChromeRemoteConnection, + ) from selenium.webdriver.edge.remote_connection import EdgeRemoteConnection - from selenium.webdriver.firefox.remote_connection import FirefoxRemoteConnection - from selenium.webdriver.safari.remote_connection import SafariRemoteConnection + from selenium.webdriver.firefox.remote_connection import ( + FirefoxRemoteConnection, + ) + from selenium.webdriver.safari.remote_connection import ( + SafariRemoteConnection, + ) - candidates = [ChromeRemoteConnection, EdgeRemoteConnection, SafariRemoteConnection, FirefoxRemoteConnection] - handler = next((c for c in candidates if c.browser_name == capabilities.get("browserName")), RemoteConnection) + candidates: list[type[RemoteConnection]] = [ + ChromeRemoteConnection, + EdgeRemoteConnection, + SafariRemoteConnection, + FirefoxRemoteConnection, + ] + handler = next( + ( + c + for c in candidates + if c.browser_name == capabilities.get("browserName") + ), + RemoteConnection, + ) return handler( remote_server_addr=command_executor, @@ -137,7 +159,7 @@ def get_remote_connection( ) -def create_matches(options: List[BaseOptions]) -> Dict: +def create_matches(options: List[BaseOptions]) -> dict[Any, Any]: capabilities = {"capabilities": {}} opts = [] for opt in options: @@ -198,10 +220,10 @@ class WebDriver(BaseWebDriver): def __init__( self, - command_executor: Union[str, RemoteConnection] = "http://127.0.0.1:4444", + command_executor: str | RemoteConnection = "http://127.0.0.1:4444", keep_alive: bool = True, file_detector: Optional[FileDetector] = None, - options: Optional[Union[BaseOptions, List[BaseOptions]]] = None, + options: BaseOptions | List[BaseOptions] | None = None, locator_converter: Optional[LocatorConverter] = None, web_element_cls: Optional[type] = None, client_config: Optional[ClientConfig] = None, @@ -285,7 +307,9 @@ def __exit__( self.quit() @contextmanager - def file_detector_context(self, file_detector_class, *args, **kwargs): + def file_detector_context( + self, file_detector_class: type[object], *args: Any, **kwargs: Any + ): """Overrides the current file detector (if necessary) in limited context. Ensures the original file detector is set afterwards. @@ -346,7 +370,7 @@ def stop_client(self): """ pass - def start_session(self, capabilities: dict) -> None: + def start_session(self, capabilities: dict[Any, Any]) -> None: """Creates a new session with the desired capabilities. Parameters: @@ -365,10 +389,13 @@ def start_session(self, capabilities: dict) -> None: self.service.stop() raise - def _wrap_value(self, value): + def _wrap_value( + self, value: _TValue + ) -> list[Any] | dict[Any, Any] | _TValue: if isinstance(value, dict): - converted = {} - for key, val in value.items(): + converted: dict[Any, Any] = {} + value_dict: dict[Any, Any] = value + for key, val in value_dict.items(): converted[key] = self._wrap_value(val) return converted if isinstance(value, self._web_element_cls): @@ -376,27 +403,36 @@ def _wrap_value(self, value): if isinstance(value, self._shadowroot_cls): return {"shadow-6066-11e4-a52e-4f735466cecf": value.id} if isinstance(value, list): - return list(self._wrap_value(item) for item in value) + value_list: list[Any] = value + return list(self._wrap_value(item) for item in value_list) return value def create_web_element(self, element_id: str) -> WebElement: """Creates a web element with the specified `element_id`.""" return self._web_element_cls(self, element_id) - def _unwrap_value(self, value): + def _unwrap_value( + self, value: _TValue + ) -> WebElement | ShadowRoot | dict[Any, Any] | list[Any] | _TValue: if isinstance(value, dict): + value_dict: dict[Any, Any] = value if "element-6066-11e4-a52e-4f735466cecf" in value: - return self.create_web_element(value["element-6066-11e4-a52e-4f735466cecf"]) + return self.create_web_element( + value_dict["element-6066-11e4-a52e-4f735466cecf"] + ) if "shadow-6066-11e4-a52e-4f735466cecf" in value: - return self._shadowroot_cls(self, value["shadow-6066-11e4-a52e-4f735466cecf"]) - for key, val in value.items(): + return self._shadowroot_cls( + self, value_dict["shadow-6066-11e4-a52e-4f735466cecf"] + ) + for key, val in value_dict.items(): value[key] = self._unwrap_value(val) - return value + return value_dict if isinstance(value, list): - return list(self._unwrap_value(item) for item in value) + value_list: list[Any] = value + return list(self._unwrap_value(item) for item in value_list) return value - def execute_cdp_cmd(self, cmd: str, cmd_args: dict): + def execute_cdp_cmd(self, cmd: str, cmd_args: dict[Any, Any]): """Execute Chrome Devtools Protocol command and get returned result The command and command args should follow chrome devtools protocol domains/commands, refer to link @@ -421,7 +457,9 @@ def execute_cdp_cmd(self, cmd: str, cmd_args: dict): >>> driver.execute_cdp_cmd('Network.getResponseBody', {'requestId': requestId}) """ - return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})["value"] + return self.execute( + "executeCdpCommand", {"cmd": cmd, "params": cmd_args} + )["value"] def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None) -> dict[str, Any]: """Sends a command to be executed by a command.CommandExecutor. @@ -449,7 +487,9 @@ def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None) response = self.command_executor.execute(driver_command, params) if response: self.error_handler.check_response(response) - response["value"] = self._unwrap_value(response.get("value", None)) + response["value"] = self._unwrap_value( + response.get("value", None) + ) return response # If the server doesn't send a response, assume the command was # a success @@ -486,7 +526,9 @@ def title(self) -> str: """ return self.execute(Command.GET_TITLE).get("value", "") - def pin_script(self, script: str, script_key=None) -> ScriptKey: + def pin_script( + self, script: str, script_key: ScriptKey | UUID | str | None = None + ) -> ScriptKey: """Store common javascript scripts to be executed later by a unique hashable ID. @@ -494,7 +536,11 @@ def pin_script(self, script: str, script_key=None) -> ScriptKey: -------- >>> script = "return document.getElementById('foo').value" """ - script_key_instance = ScriptKey(script_key) + if isinstance(script_key, ScriptKey): + script_key_instance = script_key + else: + script_key_instance = ScriptKey(script_key) + self.pinned_scripts[script_key_instance.id] = script return script_key_instance @@ -508,7 +554,9 @@ def unpin(self, script_key: ScriptKey) -> None: try: self.pinned_scripts.pop(script_key.id) except KeyError: - raise KeyError(f"No script with key: {script_key} existed in {self.pinned_scripts}") from None + raise KeyError( + f"No script with key: {script_key} existed in {self.pinned_scripts}" + ) from None def get_pinned_scripts(self) -> List[str]: """Return a list of all pinned scripts. @@ -519,7 +567,7 @@ def get_pinned_scripts(self) -> List[str]: """ return list(self.pinned_scripts) - def execute_script(self, script: str, *args): + def execute_script(self, script: str, *args: Any): """Synchronously Executes JavaScript in the current window/frame. Parameters: @@ -547,9 +595,11 @@ def execute_script(self, script: str, *args): converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute( + command, {"script": script, "args": converted_args} + )["value"] - def execute_async_script(self, script: str, *args): + def execute_async_script(self, script: str, *args: Any): """Asynchronously Executes JavaScript in the current window/frame. Parameters: @@ -569,7 +619,9 @@ def execute_async_script(self, script: str, *args): converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT_ASYNC - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute( + command, {"script": script, "args": converted_args} + )["value"] @property def current_url(self) -> str: @@ -624,7 +676,7 @@ def current_window_handle(self) -> str: return self.execute(Command.W3C_GET_CURRENT_WINDOW_HANDLE)["value"] @property - def window_handles(self) -> List[str]: + def window_handles(self) -> list[str]: """Returns the handles of all windows within the current session. Example: @@ -722,7 +774,7 @@ def refresh(self) -> None: self.execute(Command.REFRESH) # Options - def get_cookies(self) -> List[dict]: + def get_cookies(self) -> list[dict[Any, Any]]: """Returns a set of dictionaries, corresponding to cookies visible in the current session. @@ -736,7 +788,7 @@ def get_cookies(self) -> List[dict]: """ return self.execute(Command.GET_ALL_COOKIES)["value"] - def get_cookie(self, name) -> Optional[Dict]: + def get_cookie(self, name: str) -> Optional[dict[Any, Any]]: """Get a single cookie by name. Raises ValueError if the name is empty or whitespace. Returns the cookie if found, None if not. @@ -752,7 +804,7 @@ def get_cookie(self, name) -> Optional[Dict]: return None - def delete_cookie(self, name) -> None: + def delete_cookie(self, name: str) -> None: """Deletes a single cookie with the given name. Raises ValueError if the name is empty or whitespace. @@ -776,7 +828,7 @@ def delete_all_cookies(self) -> None: """ self.execute(Command.DELETE_ALL_COOKIES) - def add_cookie(self, cookie_dict) -> None: + def add_cookie(self, cookie_dict: dict[Any, Any]) -> None: """Adds a cookie to your current session. Parameters: @@ -814,7 +866,10 @@ def implicitly_wait(self, time_to_wait: float) -> None: -------- >>> driver.implicitly_wait(30) """ - self.execute(Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, + {"implicit": int(float(time_to_wait) * 1000)}, + ) def set_script_timeout(self, time_to_wait: float) -> None: """Set the amount of time that the script should wait during an @@ -829,7 +884,9 @@ def set_script_timeout(self, time_to_wait: float) -> None: -------- >>> driver.set_script_timeout(30) """ - self.execute(Command.SET_TIMEOUTS, {"script": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"script": int(float(time_to_wait) * 1000)} + ) def set_page_load_timeout(self, time_to_wait: float) -> None: """Set the amount of time to wait for a page load to complete before @@ -845,9 +902,15 @@ def set_page_load_timeout(self, time_to_wait: float) -> None: >>> driver.set_page_load_timeout(30) """ try: - self.execute(Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, + {"pageLoad": int(float(time_to_wait) * 1000)}, + ) except WebDriverException: - self.execute(Command.SET_TIMEOUTS, {"ms": float(time_to_wait) * 1000, "type": "page load"}) + self.execute( + Command.SET_TIMEOUTS, + {"ms": float(time_to_wait) * 1000, "type": "page load"}, + ) @property def timeouts(self) -> Timeouts: @@ -871,7 +934,7 @@ def timeouts(self) -> Timeouts: return Timeouts(**timeouts) @timeouts.setter - def timeouts(self, timeouts) -> None: + def timeouts(self, timeouts: Timeouts) -> None: """Set all timeouts for the session. This will override any previously set timeouts. @@ -883,7 +946,9 @@ def timeouts(self, timeouts) -> None: """ _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] - def find_element(self, by=By.ID, value: Optional[str] = None) -> WebElement: + def find_element( + self, by: ByType = By.ID, value: Optional[str] = None + ) -> WebElement: """Find an element given a By strategy and locator. Parameters: @@ -914,12 +979,18 @@ def find_element(self, by=By.ID, value: Optional[str] = None) -> WebElement: if isinstance(by, RelativeBy): elements = self.find_elements(by=by, value=value) if not elements: - raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") + raise NoSuchElementException( + f"Cannot locate relative element with: {by.root}" + ) return elements[0] - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] + return self.execute( + Command.FIND_ELEMENT, {"using": by, "value": value} + )["value"] - def find_elements(self, by=By.ID, value: Optional[str] = None) -> List[WebElement]: + def find_elements( + self, by: ByType = By.ID, value: Optional[str] = None + ) -> List[WebElement]: """Find elements given a By strategy and locator. Parameters: @@ -949,16 +1020,23 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> List[WebElemen if isinstance(by, RelativeBy): _pkg = ".".join(__name__.split(".")[:-1]) - raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8") + raw_function = pkgutil.get_data(_pkg, "findElements.js").decode( + "utf8" + ) find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" return self.execute_script(find_element_js, by.to_dict()) # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 - return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] + return ( + self.execute( + Command.FIND_ELEMENTS, {"using": by, "value": value} + )["value"] + or [] + ) @property - def capabilities(self) -> dict: + def capabilities(self) -> dict[Any, Any]: """Returns the drivers current capabilities being used. Example: @@ -967,7 +1045,7 @@ def capabilities(self) -> dict: """ return self.caps - def get_screenshot_as_file(self, filename) -> bool: + def get_screenshot_as_file(self, filename: str) -> bool: """Saves a screenshot of the current window to a PNG image file. Returns False if there is any IOError, else returns True. Use full paths in your filename. @@ -998,7 +1076,7 @@ def get_screenshot_as_file(self, filename) -> bool: del png return True - def save_screenshot(self, filename) -> bool: + def save_screenshot(self, filename: str) -> bool: """Saves a screenshot of the current window to a PNG image file. Returns False if there is any IOError, else returns True. Use full paths in your filename. @@ -1034,7 +1112,9 @@ def get_screenshot_as_base64(self) -> str: """ return self.execute(Command.SCREENSHOT)["value"] - def set_window_size(self, width, height, windowHandle: str = "current") -> None: + def set_window_size( + self, width: int, height: int, windowHandle: str = "current" + ) -> None: """Sets the width and height of the current window. (window.resizeTo) Parameters: @@ -1052,7 +1132,9 @@ def set_window_size(self, width, height, windowHandle: str = "current") -> None: self._check_if_window_handle_is_current(windowHandle) self.set_window_rect(width=int(width), height=int(height)) - def get_window_size(self, windowHandle: str = "current") -> dict: + def get_window_size( + self, windowHandle: str = "current" + ) -> dict[str, int]: """Gets the width and height of the current window. Example: @@ -1068,7 +1150,9 @@ def get_window_size(self, windowHandle: str = "current") -> dict: return {k: size[k] for k in ("width", "height")} - def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict: + def set_window_position( + self, x: float, y: float, windowHandle: str = "current" + ) -> dict[str, int | float]: """Sets the x,y position of the current window. (window.moveTo) Parameters: @@ -1086,7 +1170,9 @@ def set_window_position(self, x: float, y: float, windowHandle: str = "current") self._check_if_window_handle_is_current(windowHandle) return self.set_window_rect(x=int(x), y=int(y)) - def get_window_position(self, windowHandle="current") -> dict: + def get_window_position( + self, windowHandle: str = "current" + ) -> dict[str, int | float]: """Gets the x,y position of the current window. Example: @@ -1102,9 +1188,12 @@ def get_window_position(self, windowHandle="current") -> dict: def _check_if_window_handle_is_current(self, windowHandle: str) -> None: """Warns if the window handle is not equal to `current`.""" if windowHandle != "current": - warnings.warn("Only 'current' window is supported for W3C compatible browsers.", stacklevel=2) + warnings.warn( + "Only 'current' window is supported for W3C compatible browsers.", + stacklevel=2, + ) - def get_window_rect(self) -> dict: + def get_window_rect(self) -> dict[str, int | float]: """Gets the x, y coordinates of the window as well as height and width of the current window. @@ -1114,7 +1203,13 @@ def get_window_rect(self) -> dict: """ return self.execute(Command.GET_WINDOW_RECT)["value"] - def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: + def set_window_rect( + self, + x: Optional[float] = None, + y: Optional[float] = None, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> dict[str, int | float]: """Sets the x, y coordinates of the window as well as height and width of the current window. This method is only supported for W3C compatible browsers; other browsers should use `set_window_position` and @@ -1128,16 +1223,21 @@ def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: """ if (x is None and y is None) and (not height and not width): - raise InvalidArgumentException("x and y or height and width need values") + raise InvalidArgumentException( + "x and y or height and width need values" + ) - return self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"] + return self.execute( + Command.SET_WINDOW_RECT, + {"x": x, "y": y, "width": width, "height": height}, + )["value"] @property def file_detector(self) -> FileDetector: return self._file_detector @file_detector.setter - def file_detector(self, detector) -> None: + def file_detector(self, detector: Any) -> None: """Set the file detector to be used when sending keyboard input. By default, this is set to a file detector that does nothing. @@ -1151,9 +1251,13 @@ def file_detector(self, detector) -> None: - The detector to use. Must not be None. """ if not detector: - raise WebDriverException("You may not set a file detector that is null") + raise WebDriverException( + "You may not set a file detector that is null" + ) if not isinstance(detector, FileDetector): - raise WebDriverException("Detector has to be instance of FileDetector") + raise WebDriverException( + "Detector has to be instance of FileDetector" + ) self._file_detector = detector @property @@ -1167,7 +1271,7 @@ def orientation(self): return self.execute(Command.GET_SCREEN_ORIENTATION)["value"] @orientation.setter - def orientation(self, value) -> None: + def orientation(self, value: str) -> None: """Sets the current orientation of the device. Parameters: @@ -1181,9 +1285,41 @@ def orientation(self, value) -> None: """ allowed_values = ["LANDSCAPE", "PORTRAIT"] if value.upper() in allowed_values: - self.execute(Command.SET_SCREEN_ORIENTATION, {"orientation": value}) + self.execute( + Command.SET_SCREEN_ORIENTATION, {"orientation": value} + ) else: - raise WebDriverException("You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'") + raise WebDriverException( + "You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'" + ) + + @property + def log_types(self): + """Gets a list of the available log types. This only works with w3c + compliant browsers. + + Example: + -------- + >>> driver.log_types + """ + return self.execute(Command.GET_AVAILABLE_LOG_TYPES)["value"] + + def get_log(self, log_type: str): + """Gets the log for a given log type. + + Parameters: + ---------- + log_type : str + - Type of log that which will be returned + + Example: + -------- + >>> driver.get_log('browser') + >>> driver.get_log('driver') + >>> driver.get_log('client') + >>> driver.get_log('server') + """ + return self.execute(Command.GET_LOG, {"type": log_type})["value"] def start_devtools(self): global devtools @@ -1201,15 +1337,23 @@ def start_devtools(self): version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) devtools = cdp.import_devtools(version) if self.caps["browserName"].lower() == "firefox": - raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.") + raise RuntimeError( + "CDP support for Firefox has been removed. Please switch to WebDriver BiDi." + ) self._websocket_connection = WebSocketConnection(ws_url) - targets = self._websocket_connection.execute(devtools.target.get_targets()) + targets = self._websocket_connection.execute( + devtools.target.get_targets() + ) target_id = targets[0].target_id - session = self._websocket_connection.execute(devtools.target.attach_to_target(target_id, True)) + session = self._websocket_connection.execute( + devtools.target.attach_to_target(target_id, True) + ) self._websocket_connection.session_id = session return devtools, self._websocket_connection @@ -1224,7 +1368,9 @@ async def bidi_connection(self): version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) devtools = cdp.import_devtools(version) async with cdp.open_cdp(ws_url) as conn: @@ -1247,102 +1393,12 @@ def _start_bidi(self): if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") else: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) self._websocket_connection = WebSocketConnection(ws_url) - @property - def network(self): - if not self._websocket_connection: - self._start_bidi() - - if not hasattr(self, "_network") or self._network is None: - self._network = Network(self._websocket_connection) - - return self._network - - @property - def browser(self): - """Returns a browser module object for BiDi browser commands. - - Returns: - -------- - Browser: an object containing access to BiDi browser commands. - - Examples: - --------- - >>> user_context = driver.browser.create_user_context() - >>> user_contexts = driver.browser.get_user_contexts() - >>> client_windows = driver.browser.get_client_windows() - >>> driver.browser.remove_user_context(user_context) - """ - if not self._websocket_connection: - self._start_bidi() - - if self._browser is None: - self._browser = Browser(self._websocket_connection) - - return self._browser - - @property - def _session(self): - """ - Returns the BiDi session object for the current WebDriver session. - """ - if not self._websocket_connection: - self._start_bidi() - - if self._bidi_session is None: - self._bidi_session = Session(self._websocket_connection) - - return self._bidi_session - - @property - def browsing_context(self): - """Returns a browsing context module object for BiDi browsing context commands. - - Returns: - -------- - BrowsingContext: an object containing access to BiDi browsing context commands. - - Examples: - --------- - >>> context_id = driver.browsing_context.create(type="tab") - >>> driver.browsing_context.navigate(context=context_id, url="https://www.selenium.dev") - >>> driver.browsing_context.capture_screenshot(context=context_id) - >>> driver.browsing_context.close(context_id) - """ - if not self._websocket_connection: - self._start_bidi() - - if self._browsing_context is None: - self._browsing_context = BrowsingContext(self._websocket_connection) - - return self._browsing_context - - @property - def storage(self): - """Returns a storage module object for BiDi storage commands. - - Returns: - -------- - Storage: an object containing access to BiDi storage commands. - - Examples: - --------- - >>> cookie_filter = CookieFilter(name="example") - >>> result = driver.storage.get_cookies(filter=cookie_filter) - >>> driver.storage.set_cookie(cookie=PartialCookie("name", BytesValue(BytesValue.TYPE_STRING, "value"), "domain")) - >>> driver.storage.delete_cookies(filter=CookieFilter(name="example")) - """ - if not self._websocket_connection: - self._start_bidi() - - if self._storage is None: - self._storage = Storage(self._websocket_connection) - - return self._storage - def _get_cdp_details(self): import json @@ -1350,9 +1406,13 @@ def _get_cdp_details(self): http = urllib3.PoolManager() if self.caps.get("browserName") == "chrome": - debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") + debugger_address = self.caps.get("goog:chromeOptions").get( + "debuggerAddress" + ) elif self.caps.get("browserName") == "MicrosoftEdge": - debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") + debugger_address = self.caps.get("ms:edgeOptions").get( + "debuggerAddress" + ) res = http.request("GET", f"http://{debugger_address}/json/version") data = json.loads(res.data) @@ -1367,7 +1427,9 @@ def _get_cdp_details(self): return version, websocket_url # Virtual Authenticator Methods - def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> None: + def add_virtual_authenticator( + self, options: VirtualAuthenticatorOptions + ) -> None: """Adds a virtual authenticator with the given options. Example: @@ -1376,7 +1438,9 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non >>> options = VirtualAuthenticatorOptions(protocol="u2f", transport="usb", device_id="myDevice123") >>> driver.add_virtual_authenticator(options) """ - self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] + self._authenticator_id = self.execute( + Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict() + )["value"] @property def virtual_authenticator_id(self) -> str: @@ -1399,7 +1463,10 @@ def remove_virtual_authenticator(self) -> None: -------- >>> driver.remove_virtual_authenticator() """ - self.execute(Command.REMOVE_VIRTUAL_AUTHENTICATOR, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_VIRTUAL_AUTHENTICATOR, + {"authenticatorId": self._authenticator_id}, + ) self._authenticator_id = None @required_virtual_authenticator @@ -1412,7 +1479,13 @@ def add_credential(self, credential: Credential) -> None: >>> credential = Credential(id="user@example.com", password="aPassword") >>> driver.add_credential(credential) """ - self.execute(Command.ADD_CREDENTIAL, {**credential.to_dict(), "authenticatorId": self._authenticator_id}) + self.execute( + Command.ADD_CREDENTIAL, + { + **credential.to_dict(), + "authenticatorId": self._authenticator_id, + }, + ) @required_virtual_authenticator def get_credentials(self) -> List[Credential]: @@ -1422,11 +1495,17 @@ def get_credentials(self) -> List[Credential]: -------- >>> credentials = driver.get_credentials() """ - credential_data = self.execute(Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id}) - return [Credential.from_dict(credential) for credential in credential_data["value"]] + credential_data = self.execute( + Command.GET_CREDENTIALS, + {"authenticatorId": self._authenticator_id}, + ) + return [ + Credential.from_dict(credential) + for credential in credential_data["value"] + ] @required_virtual_authenticator - def remove_credential(self, credential_id: Union[str, bytearray]) -> None: + def remove_credential(self, credential_id: str | bytearray) -> None: """Removes a credential from the authenticator. Example: @@ -1439,7 +1518,11 @@ def remove_credential(self, credential_id: Union[str, bytearray]) -> None: credential_id = urlsafe_b64encode(credential_id).decode() self.execute( - Command.REMOVE_CREDENTIAL, {"credentialId": credential_id, "authenticatorId": self._authenticator_id} + Command.REMOVE_CREDENTIAL, + { + "credentialId": credential_id, + "authenticatorId": self._authenticator_id, + }, ) @required_virtual_authenticator @@ -1450,7 +1533,10 @@ def remove_all_credentials(self) -> None: -------- >>> driver.remove_all_credentials() """ - self.execute(Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_ALL_CREDENTIALS, + {"authenticatorId": self._authenticator_id}, + ) @required_virtual_authenticator def set_user_verified(self, verified: bool) -> None: @@ -1465,9 +1551,15 @@ def set_user_verified(self, verified: bool) -> None: -------- >>> driver.set_user_verified(True) """ - self.execute(Command.SET_USER_VERIFIED, {"authenticatorId": self._authenticator_id, "isUserVerified": verified}) + self.execute( + Command.SET_USER_VERIFIED, + { + "authenticatorId": self._authenticator_id, + "isUserVerified": verified, + }, + ) - def get_downloadable_files(self) -> list: + def get_downloadable_files(self) -> list[str]: """Retrieves the downloadable files as a list of file names. Example: @@ -1475,7 +1567,9 @@ def get_downloadable_files(self) -> list: >>> files = driver.get_downloadable_files() """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"] @@ -1496,12 +1590,16 @@ def download_file(self, file_name: str, target_directory: str) -> None: >>> driver.download_file("example.zip", "/path/to/directory") """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) if not os.path.exists(target_directory): os.makedirs(target_directory) - contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"]["contents"] + contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})[ + "value" + ]["contents"] with tempfile.TemporaryDirectory() as tmp_dir: zip_file = os.path.join(tmp_dir, file_name + ".zip") @@ -1519,7 +1617,9 @@ def delete_downloadable_files(self) -> None: >>> driver.delete_downloadable_files() """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) self.execute(Command.DELETE_DOWNLOADABLE_FILES) @@ -1576,7 +1676,12 @@ def dialog(self): self._require_fedcm_support() return Dialog(self) - def fedcm_dialog(self, timeout=5, poll_frequency=0.5, ignored_exceptions=None): + def fedcm_dialog( + self, + timeout: int = 5, + poll_frequency: float = 0.5, + ignored_exceptions: Any = None, + ): """Waits for and returns the FedCM dialog. Parameters: @@ -1614,5 +1719,10 @@ def _check_fedcm(): except NoAlertPresentException: return None - wait = WebDriverWait(self, timeout, poll_frequency=poll_frequency, ignored_exceptions=ignored_exceptions) + wait = WebDriverWait( + self, + timeout, + poll_frequency=poll_frequency, + ignored_exceptions=ignored_exceptions, + ) return wait.until(lambda _: _check_fedcm()) diff --git a/py/selenium/webdriver/remote/webelement.py b/py/selenium/webdriver/remote/webelement.py index b8d8a32c3f285..6a0a9e3bb8f04 100644 --- a/py/selenium/webdriver/remote/webelement.py +++ b/py/selenium/webdriver/remote/webelement.py @@ -25,12 +25,15 @@ from base64 import encodebytes from hashlib import md5 as md5_hash from io import BytesIO +from typing import Any from typing import List from selenium.common.exceptions import JavascriptException from selenium.common.exceptions import WebDriverException from selenium.webdriver.common.by import By +from selenium.webdriver.common.by import ByType from selenium.webdriver.common.utils import keys_to_typing +from selenium.webdriver.remote.webdriver import WebDriver from .command import Command from .shadowroot import ShadowRoot @@ -72,7 +75,7 @@ class WebElement(BaseWebElement): instance will fail. """ - def __init__(self, parent, id_) -> None: + def __init__(self, parent: WebDriver, id_: str) -> None: self._parent = parent self._id = id_ @@ -157,7 +160,7 @@ def clear(self) -> None: """ self._execute(Command.CLEAR_ELEMENT) - def get_property(self, name) -> str | bool | WebElement | dict: + def get_property(self, name: str) -> str | bool | WebElement | dict[Any, Any]: """Gets the given property of the element. Parameters: @@ -179,7 +182,7 @@ def get_property(self, name) -> str | bool | WebElement | dict: # if we hit an end point that doesn't understand getElementProperty lets fake it return self.parent.execute_script("return arguments[0][arguments[1]]", self, name) - def get_dom_attribute(self, name) -> str: + def get_dom_attribute(self, name: str) -> str: """Gets the given attribute of the element. Unlike :func:`~selenium.webdriver.remote.BaseWebElement.get_attribute`, this method only returns attributes declared in the element's HTML markup. @@ -199,7 +202,7 @@ def get_dom_attribute(self, name) -> str: """ return self._execute(Command.GET_ELEMENT_ATTRIBUTE, {"name": name})["value"] - def get_attribute(self, name) -> str | None: + def get_attribute(self, name: str) -> str | None: """Gets the given attribute or property of the element. This method will first try to return the value of a property with the @@ -344,7 +347,7 @@ def is_displayed(self) -> bool: return self.parent.execute_script(f"/* isDisplayed */return ({isDisplayed_js}).apply(null, arguments);", self) @property - def location_once_scrolled_into_view(self) -> dict: + def location_once_scrolled_into_view(self) -> dict[Any, Any]: """THIS PROPERTY MAY CHANGE WITHOUT WARNING. Use this to discover where on the screen an element is so that we can click it. This method should cause the element to be scrolled into view. @@ -368,7 +371,7 @@ def location_once_scrolled_into_view(self) -> dict: return {"x": round(old_loc["x"]), "y": round(old_loc["y"])} @property - def size(self) -> dict: + def size(self) -> dict[str, int | float]: """The size of the element. Returns: @@ -383,7 +386,7 @@ def size(self) -> dict: new_size = {"height": size["height"], "width": size["width"]} return new_size - def value_of_css_property(self, property_name) -> str: + def value_of_css_property(self, property_name: str) -> str: """The value of a CSS property. Parameters: @@ -402,7 +405,7 @@ def value_of_css_property(self, property_name) -> str: return self._execute(Command.GET_ELEMENT_VALUE_OF_CSS_PROPERTY, {"propertyName": property_name})["value"] @property - def location(self) -> dict: + def location(self) -> dict[str, int | float]: """The location of the element in the renderable canvas. Returns: @@ -418,7 +421,7 @@ def location(self) -> dict: return new_loc @property - def rect(self) -> dict: + def rect(self) -> dict[str, Any]: """A dictionary with the size and location of the element. Returns: @@ -488,7 +491,7 @@ def screenshot_as_png(self) -> bytes: """ return b64decode(self.screenshot_as_base64.encode("ascii")) - def screenshot(self, filename) -> bool: + def screenshot(self, filename: str) -> bool: """Saves a screenshot of the current element to a PNG image file. Returns False if there is any IOError, else returns True. Use full paths in your filename. @@ -548,14 +551,17 @@ def id(self) -> str: """ return self._id - def __eq__(self, element): + def __eq__(self, element: object): + if not isinstance(element, WebElement): + return False + return hasattr(element, "id") and self._id == element.id - def __ne__(self, element): + def __ne__(self, element: object): return not self.__eq__(element) # Private Methods - def _execute(self, command, params=None): + def _execute(self, command: Any, params: dict[Any, Any] | None = None): """Executes a command against the underlying HTML element. Parameters: @@ -575,7 +581,7 @@ def _execute(self, command, params=None): params["id"] = self._id return self._parent.execute(command, params) - def find_element(self, by=By.ID, value=None) -> WebElement: + def find_element(self, by: ByType = By.ID, value: str | None = None) -> WebElement: """Find an element given a By strategy and locator. Parameters: @@ -604,7 +610,7 @@ def find_element(self, by=By.ID, value=None) -> WebElement: by, value = self._parent.locator_converter.convert(by, value) return self._execute(Command.FIND_CHILD_ELEMENT, {"using": by, "value": value})["value"] - def find_elements(self, by=By.ID, value=None) -> List[WebElement]: + def find_elements(self, by: ByType = By.ID, value: str | None = None) -> List[WebElement]: """Find elements given a By strategy and locator. Parameters: @@ -636,7 +642,7 @@ def find_elements(self, by=By.ID, value=None) -> List[WebElement]: def __hash__(self) -> int: return int(md5_hash(self._id.encode("utf-8")).hexdigest(), 16) - def _upload(self, filename): + def _upload(self, filename: str): fp = BytesIO() zipped = zipfile.ZipFile(fp, "w", zipfile.ZIP_DEFLATED) zipped.write(filename, os.path.split(filename)[1])