From 4af29a0d79a374e58c3597c8c284c93892602b8e Mon Sep 17 00:00:00 2001 From: Shaurya Bisht <87357655+ShauryaDusht@users.noreply.github.com> Date: Sun, 15 Jun 2025 20:56:54 +0530 Subject: [PATCH 1/2] Fix mypy type errors in remote/webdriver.py --- py/selenium/webdriver/remote/webdriver.py | 39 +++++++++++++++-------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 149f12d8fe1a0..985f0720c721d 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -30,7 +30,7 @@ from base64 import b64decode, urlsafe_b64encode from contextlib import asynccontextmanager, contextmanager from importlib import import_module -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Type, Union, cast from selenium.common.exceptions import ( InvalidArgumentException, @@ -123,8 +123,13 @@ def get_remote_connection( candidates = [ChromeRemoteConnection, EdgeRemoteConnection, SafariRemoteConnection, FirefoxRemoteConnection] handler = next((c for c in candidates if c.browser_name == capabilities.get("browserName")), RemoteConnection) + if hasattr(command_executor, "client_config") and command_executor.client_config: + remote_server_addr = command_executor.client_config.remote_server_addr + else: + None + return handler( - remote_server_addr=command_executor, + remote_server_addr=remote_server_addr, keep_alive=keep_alive, ignore_proxy=ignore_local_proxy, client_config=client_config, @@ -132,7 +137,7 @@ def get_remote_connection( def create_matches(options: list[BaseOptions]) -> dict: - capabilities = {"capabilities": {}} + capabilities: Dict[str, Any] = {"capabilities": {}} opts = [] for opt in options: opts.append(opt.to_capabilities()) @@ -155,9 +160,9 @@ def create_matches(options: list[BaseOptions]) -> dict: for k, v in samesies.items(): always[k] = v - for i in opts: + for opt_dict in opts: for k in always: - del i[k] + del opt_dict[k] capabilities["capabilities"]["alwaysMatch"] = always capabilities["capabilities"]["firstMatch"] = opts @@ -197,7 +202,7 @@ def __init__( file_detector: Optional[FileDetector] = None, options: Optional[Union[BaseOptions, list[BaseOptions]]] = None, locator_converter: Optional[LocatorConverter] = None, - web_element_cls: Optional[type] = None, + web_element_cls: Optional[Type[WebElement]] = None, client_config: Optional[ClientConfig] = None, ) -> None: """Create a new driver that will issue commands using the wire @@ -243,9 +248,9 @@ def __init__( client_config=client_config, ) self._is_remote = True - self.session_id = None - self.caps = {} - self.pinned_scripts = {} + self.session_id: Optional[str] = None + self.caps: Dict[str, Any] = {} + self.pinned_scripts: Dict[str, Any] = {} self.error_handler = ErrorHandler() self._switch_to = SwitchTo(self) self._mobile = Mobile(self) @@ -442,7 +447,9 @@ def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None) elif "sessionId" not in params: params["sessionId"] = self.session_id - response = self.command_executor.execute(driver_command, params) + executor = cast(RemoteConnection, self.command_executor) + response = executor.execute(driver_command, params) + if response: self.error_handler.check_response(response) response["value"] = self._unwrap_value(response.get("value", None)) @@ -605,7 +612,8 @@ def quit(self) -> None: self.execute(Command.QUIT) finally: self.stop_client() - self.command_executor.close() + executor = cast(RemoteConnection, self.command_executor) + executor.close() @property def current_window_handle(self) -> str: @@ -660,7 +668,7 @@ def print_page(self, print_options: Optional[PrintOptions] = None) -> str: -------- >>> driver.print_page() """ - options = {} + options: Union[Dict[str, Any], Any] = {} if print_options: options = print_options.to_dict() @@ -943,7 +951,10 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> list[WebElemen if isinstance(by, RelativeBy): _pkg = ".".join(__name__.split(".")[:-1]) - raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8") + raw_data = pkgutil.get_data(_pkg, "findElements.js") + if raw_data is None: + raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") + raw_function = raw_data.decode("utf8") find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" return self.execute_script(find_element_js, by.to_dict()) @@ -1397,7 +1408,7 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] @property - def virtual_authenticator_id(self) -> str: + def virtual_authenticator_id(self) -> Optional[str]: """Returns the id of the virtual authenticator. Example: From 8e3ebd7e305148b7bf84ac92e8343afe90b5c893 Mon Sep 17 00:00:00 2001 From: Shaurya Bisht <87357655+ShauryaDusht@users.noreply.github.com> Date: Sun, 15 Jun 2025 21:12:15 +0530 Subject: [PATCH 2/2] Minor fixes --- py/selenium/webdriver/remote/webdriver.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 985f0720c721d..a41d63ad2a17f 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -126,7 +126,7 @@ def get_remote_connection( if hasattr(command_executor, "client_config") and command_executor.client_config: remote_server_addr = command_executor.client_config.remote_server_addr else: - None + remote_server_addr = command_executor return handler( remote_server_addr=remote_server_addr, @@ -447,8 +447,7 @@ def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None) elif "sessionId" not in params: params["sessionId"] = self.session_id - executor = cast(RemoteConnection, self.command_executor) - response = executor.execute(driver_command, params) + response = cast(RemoteConnection, self.command_executor).execute(driver_command, params) if response: self.error_handler.check_response(response)