diff --git a/py/selenium/webdriver/chromium/service.py b/py/selenium/webdriver/chromium/service.py index 9f50e21e8e36f..74045feb2521d 100644 --- a/py/selenium/webdriver/chromium/service.py +++ b/py/selenium/webdriver/chromium/service.py @@ -49,7 +49,7 @@ def __init__( if isinstance(log_output, str): self.service_args.append(f"--log-path={log_output}") - self.log_output: Optional[IOBase] = None + self.log_output: cast(IOBase, None) elif isinstance(log_output, IOBase): self.log_output = log_output else: diff --git a/py/selenium/webdriver/chromium/webdriver.py b/py/selenium/webdriver/chromium/webdriver.py index 4dcf8d73e8fe0..f525d34317d6e 100644 --- a/py/selenium/webdriver/chromium/webdriver.py +++ b/py/selenium/webdriver/chromium/webdriver.py @@ -33,10 +33,12 @@ def __init__( self, browser_name: Optional[str] = None, vendor_prefix: Optional[str] = None, - options: ArgOptions = ArgOptions(), + options: ArgOptions = None, service: Optional[Service] = None, keep_alive: bool = True, ) -> None: + if options is None: + options = ArgOptions() """Creates a new WebDriver instance of the ChromiumDriver. Starts the service and then creates new WebDriver instance of ChromiumDriver. @@ -49,6 +51,9 @@ def __init__( """ self.service = service + if self.service is None: + raise ValueError("Service must be provided and cannot be None") + finder = DriverFinder(self.service, options) if finder.get_browser_path(): options.binary_location = finder.get_browser_path() @@ -59,8 +64,8 @@ def __init__( executor = ChromiumRemoteConnection( remote_server_addr=self.service.service_url, - browser_name=browser_name, - vendor_prefix=vendor_prefix, + browser_name=browser_name or "", + vendor_prefix=vendor_prefix or "", keep_alive=keep_alive, ignore_proxy=options._ignore_local_proxy, ) @@ -221,7 +226,8 @@ def quit(self) -> None: # We don't care about the message because something probably has gone wrong pass finally: - self.service.stop() + if self.service is not None: + self.service.stop() def download_file(self, *args, **kwargs): raise NotImplementedError diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index ce697051bef44..52229037d8de7 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -125,13 +125,13 @@ def from_dict(cls, data: dict) -> "ClientWindowInfo": ClientWindowInfo: A new instance of ClientWindowInfo. """ return cls( - client_window=data.get("clientWindow"), - state=data.get("state"), - width=data.get("width"), - height=data.get("height"), - x=data.get("x"), - y=data.get("y"), - active=data.get("active"), + client_window=str(data.get("clientWindow")), + state=str(data.get("state")), + width=int(data.get("width") or 0), + height=int(data.get("height") or 0), + x=int(data.get("x") or 0), + y=int(data.get("y") or 0), + active=bool(data.get("active")), ) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index e843752b7a27f..55e62aaf9f153 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, Union +from typing import Optional, Union, Callable from selenium.webdriver.common.bidi.common import command_builder @@ -67,10 +67,10 @@ def from_json(cls, json: dict) -> "NavigationInfo": NavigationInfo: A new instance of NavigationInfo. """ return cls( - context=json.get("context"), + context=str(json.get("context")), navigation=json.get("navigation"), - timestamp=json.get("timestamp"), - url=json.get("url"), + timestamp=int(json.get("timestamp") or 0), + url=str(json.get("url")), ) @@ -109,11 +109,11 @@ def from_json(cls, json: dict) -> "BrowsingContextInfo": """ children = None if json.get("children") is not None: - children = [BrowsingContextInfo.from_json(child) for child in json.get("children")] + children = [BrowsingContextInfo.from_json(child) for child in json.get("children", [])] return cls( - context=json.get("context"), - url=json.get("url"), + context=str(json.get("context")), + url=str(json.get("url")), children=children, parent=json.get("parent"), user_context=json.get("userContext"), @@ -149,11 +149,11 @@ def from_json(cls, json: dict) -> "DownloadWillBeginParams": DownloadWillBeginParams: A new instance of DownloadWillBeginParams. """ return cls( - context=json.get("context"), + context=str(json.get("context")), navigation=json.get("navigation"), - timestamp=json.get("timestamp"), - url=json.get("url"), - suggested_filename=json.get("suggestedFilename"), + timestamp=int(json.get("timestamp") or 0), + url=str(json.get("url")), + suggested_filename=str(json.get("suggestedFilename")), ) @@ -187,10 +187,10 @@ def from_json(cls, json: dict) -> "UserPromptOpenedParams": UserPromptOpenedParams: A new instance of UserPromptOpenedParams. """ return cls( - context=json.get("context"), - handler=json.get("handler"), - message=json.get("message"), - type=json.get("type"), + context=str(json.get("context")), + handler=str(json.get("handler")), + message=str(json.get("message")), + type=str(json.get("type")), default_value=json.get("defaultValue"), ) @@ -223,9 +223,9 @@ def from_json(cls, json: dict) -> "UserPromptClosedParams": UserPromptClosedParams: A new instance of UserPromptClosedParams. """ return cls( - context=json.get("context"), - accepted=json.get("accepted"), - type=json.get("type"), + context=str(json.get("context")), + accepted=bool(json.get("accepted")), + type=str(json.get("type")), user_text=json.get("userText"), ) @@ -254,8 +254,8 @@ def from_json(cls, json: dict) -> "HistoryUpdatedParams": HistoryUpdatedParams: A new instance of HistoryUpdatedParams. """ return cls( - context=json.get("context"), - url=json.get("url"), + context=str(json.get("context")), + url=str(json.get("url")), ) @@ -278,7 +278,7 @@ def from_json(cls, json: dict) -> "BrowsingContextEvent": ------- BrowsingContextEvent: A new instance of BrowsingContextEvent. """ - return cls(event_class=json.get("event_class"), **json) + return cls(event_class=str(json.get("event_class")), **json) class BrowsingContext: @@ -341,9 +341,9 @@ def capture_screenshot( """ params = {"context": context, "origin": origin} if format is not None: - params["format"] = format + params["format"] = str(format) if clip is not None: - params["clip"] = clip + params["clip"] = str(clip) result = self.conn.execute(command_builder("browsingContext.captureScreenshot", params)) return result["data"] @@ -387,7 +387,7 @@ def create( if reference_context is not None: params["referenceContext"] = reference_context if background is not None: - params["background"] = background + params["background"] = str(background) if user_context is not None: params["userContext"] = user_context @@ -415,7 +415,7 @@ def get_tree( if max_depth is not None: params["maxDepth"] = max_depth if root is not None: - params["root"] = root + params["root"] = int(root or 0) result = self.conn.execute(command_builder("browsingContext.getTree", params)) return [BrowsingContextInfo.from_json(context) for context in result["contexts"]] @@ -436,7 +436,7 @@ def handle_user_prompt( """ params = {"context": context} if accept is not None: - params["accept"] = accept + params["accept"] = str(accept) if user_text is not None: params["userText"] = user_text @@ -466,7 +466,7 @@ def locate_nodes( """ params = {"context": context, "locator": locator} if max_node_count is not None: - params["maxNodeCount"] = max_node_count + params["maxNodeCount"] = [int(max_node_count or 0)] if serialization_options is not None: params["serializationOptions"] = serialization_options if start_nodes is not None: @@ -566,7 +566,7 @@ def reload( """ params = {"context": context} if ignore_cache is not None: - params["ignoreCache"] = ignore_cache + params["ignoreCache"] = str(ignore_cache) if wait is not None: params["wait"] = wait @@ -597,11 +597,11 @@ def set_viewport( if context is not None: params["context"] = context if viewport is not None: - params["viewport"] = viewport + params["viewport"] = str(viewport) if device_pixel_ratio is not None: - params["devicePixelRatio"] = device_pixel_ratio + params["devicePixelRatio"] = str(device_pixel_ratio) if user_contexts is not None: - params["userContexts"] = user_contexts + params["userContexts"] = str(user_contexts) self.conn.execute(command_builder("browsingContext.setViewport", params)) @@ -621,7 +621,7 @@ def traverse_history(self, context: str, delta: int) -> dict: result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) return result - def _on_event(self, event_name: str, callback: callable) -> int: + def _on_event(self, event_name: str, callback: Callable) -> int: """Set a callback function to subscribe to a browsing context event. Parameters: @@ -665,7 +665,7 @@ def _callback(event_data): return callback_id - def add_event_handler(self, event: str, callback: callable, contexts: Optional[list[str]] = None) -> int: + def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: """Add an event handler to the browsing context. Parameters: @@ -710,7 +710,7 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: except KeyError: raise Exception(f"Event {event} not found") - event = BrowsingContextEvent(event_name) + event = str(BrowsingContextEvent(event_name)) self.conn.remove_callback(event, callback_id) self.subscriptions[event_name].remove(callback_id) diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py index 6ab3d3b012299..13d5628bbaa15 100644 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -427,6 +427,10 @@ async def connect_session(self, target_id) -> "CdpSession": """Returns a new :class:`CdpSession` connected to the specified target.""" global devtools + if devtools and devtools.target: + session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + else: + raise RuntimeError("devtools.target is not available.") session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) session = CdpSession(self.ws, session_id, target_id) self.sessions[session_id] = session diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 620477aeae6da..5a25f9ccceff7 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -88,9 +88,9 @@ def from_dict(cls, data: dict) -> "Cookie": value = BytesValue(data.get("value", {}).get("type"), data.get("value", {}).get("value")) return cls( - name=data.get("name"), + name=str(data.get("name")), value=value, - domain=data.get("domain"), + domain=str(data.get("domain")), path=data.get("path"), size=data.get("size"), http_only=data.get("httpOnly"), @@ -136,21 +136,21 @@ def to_dict(self) -> dict: if self.name is not None: result["name"] = self.name if self.value is not None: - result["value"] = self.value.to_dict() + result["value"] = str(self.value.to_dict()) if self.domain is not None: result["domain"] = self.domain if self.path is not None: result["path"] = self.path if self.size is not None: - result["size"] = self.size + result["size"] = str(self.size) if self.http_only is not None: - result["httpOnly"] = self.http_only + result["httpOnly"] = str(self.http_only) if self.secure is not None: - result["secure"] = self.secure + result["secure"] = str(self.secure) if self.same_site is not None: result["sameSite"] = self.same_site if self.expiry is not None: - result["expiry"] = self.expiry + result["expiry"] = str(self.expiry) return result @@ -257,13 +257,13 @@ def to_dict(self) -> dict: if self.path is not None: result["path"] = self.path if self.http_only is not None: - result["httpOnly"] = self.http_only + result["httpOnly"] = [self.http_only] if self.secure is not None: - result["secure"] = self.secure + result["secure"] = [self.secure] if self.same_site is not None: result["sameSite"] = self.same_site if self.expiry is not None: - result["expiry"] = self.expiry + result["expiry"] = [self.expiry] return result diff --git a/py/selenium/webdriver/common/options.py b/py/selenium/webdriver/common/options.py index 67e5765645133..21d459e42e0af 100644 --- a/py/selenium/webdriver/common/options.py +++ b/py/selenium/webdriver/common/options.py @@ -422,7 +422,7 @@ def __init__(self) -> None: self._caps = self.default_capabilities self._proxy = None self.set_capability("pageLoadStrategy", PageLoadStrategy.normal) - self.mobile_options = None + self.mobile_options: Optional[dict[str, str]] = None self._ignore_local_proxy = False @property @@ -475,6 +475,7 @@ class ArgOptions(BaseOptions): def __init__(self) -> None: super().__init__() self._arguments: list[str] = [] + self.binary_location: Optional[str] = None @property def arguments(self): diff --git a/py/selenium/webdriver/common/service.py b/py/selenium/webdriver/common/service.py index e03adb6202f84..3e210b66148a3 100644 --- a/py/selenium/webdriver/common/service.py +++ b/py/selenium/webdriver/common/service.py @@ -58,12 +58,12 @@ def __init__( ) -> None: if isinstance(log_output, str): self.log_output = cast(IOBase, open(log_output, "a+", encoding="utf-8")) - elif log_output == subprocess.STDOUT: - self.log_output = cast(Optional[Union[int, IOBase]], None) - elif log_output is None or log_output == subprocess.DEVNULL: - self.log_output = cast(Optional[Union[int, IOBase]], subprocess.DEVNULL) - else: + elif log_output in {subprocess.STDOUT, subprocess.DEVNULL, None}: + self.log_output = cast(IOBase, subprocess.DEVNULL) + elif isinstance(log_output, IOBase): self.log_output = log_output + else: + raise TypeError("log_output must be a string, IOBase, or a valid subprocess constant") self.port = port or utils.free_port() # Default value for every python subprocess: subprocess.Popen(..., creationflags=0) diff --git a/py/selenium/webdriver/common/utils.py b/py/selenium/webdriver/common/utils.py index b04e2b0e40c30..11f60d59375b2 100644 --- a/py/selenium/webdriver/common/utils.py +++ b/py/selenium/webdriver/common/utils.py @@ -64,13 +64,13 @@ def find_connectable_ip(host: Union[str, bytes, bytearray, None], port: Optional for family, _, _, _, sockaddr in addrinfos: connectable = True if port: - connectable = is_connectable(port, sockaddr[0]) + connectable = is_connectable(port, str(sockaddr[0])) if connectable and family == socket.AF_INET: - return sockaddr[0] + return str(sockaddr[0]) if connectable and not ip and family == socket.AF_INET6: ip = sockaddr[0] - return ip + return str(ip) if ip else None def join_host_port(host: str, port: int) -> str: @@ -132,7 +132,7 @@ def keys_to_typing(value: Iterable[AnyKey]) -> list[str]: for val in value: if isinstance(val, Keys): # Todo: Does this even work? - characters.append(val) + characters.append(str(val)) elif isinstance(val, (int, float)): characters.extend(str(val)) else: diff --git a/py/selenium/webdriver/common/virtual_authenticator.py b/py/selenium/webdriver/common/virtual_authenticator.py index a434de83741df..19b0ec615e7c5 100644 --- a/py/selenium/webdriver/common/virtual_authenticator.py +++ b/py/selenium/webdriver/common/virtual_authenticator.py @@ -24,17 +24,17 @@ class Protocol(str, Enum): """Protocol to communicate with the authenticator.""" - CTAP2: str = "ctap2" - U2F: str = "ctap1/u2f" + CTAP2 = "ctap2" + U2F = "ctap1/u2f" class Transport(str, Enum): """Transport method to communicate with the authenticator.""" - BLE: str = "ble" - USB: str = "usb" - NFC: str = "nfc" - INTERNAL: str = "internal" + BLE = "ble" + USB = "usb" + NFC = "nfc" + INTERNAL = "internal" class VirtualAuthenticatorOptions: diff --git a/py/selenium/webdriver/firefox/remote_connection.py b/py/selenium/webdriver/firefox/remote_connection.py index a749cce37dc62..8f8019f223613 100644 --- a/py/selenium/webdriver/firefox/remote_connection.py +++ b/py/selenium/webdriver/firefox/remote_connection.py @@ -23,7 +23,7 @@ class FirefoxRemoteConnection(RemoteConnection): - browser_name = DesiredCapabilities.FIREFOX["browserName"] + browser_name: str = str(DesiredCapabilities.FIREFOX["browserName"]) def __init__( self, diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index 7ed47b56275b3..11cbc91e7ac2d 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -23,6 +23,7 @@ from typing import Optional from urllib import parse from urllib.parse import urlparse +from typing import Any import urllib3 @@ -368,7 +369,7 @@ def __init__( self._conn = self._get_connection_manager() self._commands = remote_commands - extra_commands = {} + extra_commands: dict[str, Any] = {} def add_command(self, name, method, url): """Register a new command.""" diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 149f12d8fe1a0..302ef3fd36595 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -30,7 +30,7 @@ from base64 import b64decode, urlsafe_b64encode from contextlib import asynccontextmanager, contextmanager from importlib import import_module -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from selenium.common.exceptions import ( InvalidArgumentException, @@ -124,7 +124,7 @@ def get_remote_connection( handler = next((c for c in candidates if c.browser_name == capabilities.get("browserName")), RemoteConnection) return handler( - remote_server_addr=command_executor, + remote_server_addr=str(command_executor), keep_alive=keep_alive, ignore_proxy=ignore_local_proxy, client_config=client_config, @@ -132,7 +132,7 @@ def get_remote_connection( def create_matches(options: list[BaseOptions]) -> dict: - capabilities = {"capabilities": {}} + capabilities: dict[str, Any] = {"capabilities": {}} opts = [] for opt in options: opts.append(opt.to_capabilities()) @@ -156,8 +156,10 @@ def create_matches(options: list[BaseOptions]) -> dict: always[k] = v for i in opts: - for k in always: - del i[k] + # Ensure `i` is a dictionary before we delete from it + if isinstance(i, dict): + for k in always: + del i[k] capabilities["capabilities"]["alwaysMatch"] = always capabilities["capabilities"]["firstMatch"] = opts @@ -242,16 +244,19 @@ def __init__( ignore_local_proxy=_ignore_local_proxy, client_config=client_config, ) + if web_element_cls and not issubclass(web_element_cls, WebElement): + raise TypeError("web_element_cls must be a subclass of WebElement") + self._is_remote = True self.session_id = None - self.caps = {} - self.pinned_scripts = {} + self.caps: dict[str, Any] = {} + self.pinned_scripts: dict[str, Any] = {} self.error_handler = ErrorHandler() self._switch_to = SwitchTo(self) self._mobile = Mobile(self) self.file_detector = file_detector or LocalFileDetector() self.locator_converter = locator_converter or LocatorConverter() - self._web_element_cls = web_element_cls or self._web_element_cls + self._web_element_cls = cast(type[WebElement], web_element_cls) if web_element_cls else WebElement self._authenticator_id = None self.start_client() self.start_session(capabilities) @@ -442,7 +447,13 @@ def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None) elif "sessionId" not in params: params["sessionId"] = self.session_id - response = self.command_executor.execute(driver_command, params) + # Ensure `self.command_executor` is an instance of `RemoteConnection` + # before attempting to call its `execute` method. + if isinstance(self.command_executor, RemoteConnection): + response = self.command_executor.execute(driver_command, params) + else: + raise TypeError("command_executor must be an instance of RemoteConnection") + if response: self.error_handler.check_response(response) response["value"] = self._unwrap_value(response.get("value", None)) @@ -605,7 +616,8 @@ def quit(self) -> None: self.execute(Command.QUIT) finally: self.stop_client() - self.command_executor.close() + if isinstance(self.command_executor, RemoteConnection): + self.command_executor.close() @property def current_window_handle(self) -> str: @@ -660,9 +672,9 @@ def print_page(self, print_options: Optional[PrintOptions] = None) -> str: -------- >>> driver.print_page() """ - options = {} + options: dict[str, Any] = {} if print_options: - options = print_options.to_dict() + options = cast(dict[str, Any], print_options.to_dict()) return self.execute(Command.PRINT_PAGE, options)["value"] @@ -943,8 +955,10 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> list[WebElemen if isinstance(by, RelativeBy): _pkg = ".".join(__name__.split(".")[:-1]) - raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8") - find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" + raw_function = pkgutil.get_data(_pkg, "findElements.js") + if raw_function is None: + raise ValueError("Failed to load 'findElements.js' using pkgutil.get_data.") + find_element_js = f"/* findElements */return ({raw_function.decode('utf8')}).apply(null, arguments);" return self.execute_script(find_element_js, by.to_dict()) # Return empty list if driver returns null @@ -1404,7 +1418,7 @@ def virtual_authenticator_id(self) -> str: -------- >>> print(driver.virtual_authenticator_id) """ - return self._authenticator_id + return self._authenticator_id or "" @required_virtual_authenticator def remove_virtual_authenticator(self) -> None: diff --git a/py/selenium/webdriver/webkitgtk/service.py b/py/selenium/webdriver/webkitgtk/service.py index e7f300019c838..bffc9418a4aab 100644 --- a/py/selenium/webdriver/webkitgtk/service.py +++ b/py/selenium/webdriver/webkitgtk/service.py @@ -21,7 +21,7 @@ from selenium.webdriver.common import service -DEFAULT_EXECUTABLE_PATH: str = shutil.which("WebKitWebDriver") +DEFAULT_EXECUTABLE_PATH: Optional[str] = shutil.which("WebKitWebDriver") class Service(service.Service): diff --git a/py/selenium/webdriver/wpewebkit/service.py b/py/selenium/webdriver/wpewebkit/service.py index 1f2b244807583..a8c0ef3d54c87 100644 --- a/py/selenium/webdriver/wpewebkit/service.py +++ b/py/selenium/webdriver/wpewebkit/service.py @@ -38,7 +38,7 @@ class Service(service.Service): def __init__( self, - executable_path: str = DEFAULT_EXECUTABLE_PATH, + executable_path: Optional[str] = DEFAULT_EXECUTABLE_PATH, port: int = 0, log_output: Optional[str] = None, service_args: Optional[list[str]] = None,