diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index b90ae35d703a0..bba1b8112b704 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -29,7 +29,7 @@ from base64 import b64decode, urlsafe_b64encode from contextlib import asynccontextmanager, contextmanager from importlib import import_module -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Type, Union, cast from selenium.common.exceptions import ( InvalidArgumentException, @@ -122,8 +122,13 @@ def get_remote_connection( candidates = [ChromeRemoteConnection, EdgeRemoteConnection, SafariRemoteConnection, FirefoxRemoteConnection] handler = next((c for c in candidates if c.browser_name == capabilities.get("browserName")), RemoteConnection) + if hasattr(command_executor, "client_config") and command_executor.client_config: + remote_server_addr = command_executor.client_config.remote_server_addr + else: + remote_server_addr = command_executor + return handler( - remote_server_addr=command_executor, + remote_server_addr=remote_server_addr, keep_alive=keep_alive, ignore_proxy=ignore_local_proxy, client_config=client_config, @@ -131,7 +136,7 @@ def get_remote_connection( def create_matches(options: list[BaseOptions]) -> dict: - capabilities = {"capabilities": {}} + capabilities: Dict[str, Any] = {"capabilities": {}} opts = [] for opt in options: opts.append(opt.to_capabilities()) @@ -154,9 +159,9 @@ def create_matches(options: list[BaseOptions]) -> dict: for k, v in samesies.items(): always[k] = v - for i in opts: + for opt_dict in opts: for k in always: - del i[k] + del opt_dict[k] capabilities["capabilities"]["alwaysMatch"] = always capabilities["capabilities"]["firstMatch"] = opts @@ -196,7 +201,7 @@ def __init__( file_detector: Optional[FileDetector] = None, options: Optional[Union[BaseOptions, list[BaseOptions]]] = None, locator_converter: Optional[LocatorConverter] = None, - web_element_cls: Optional[type] = None, + web_element_cls: Optional[Type[WebElement]] = None, client_config: Optional[ClientConfig] = None, ) -> None: """Create a new driver that will issue commands using the wire @@ -242,9 +247,9 @@ def __init__( client_config=client_config, ) self._is_remote = True - self.session_id = None - self.caps = {} - self.pinned_scripts = {} + self.session_id: Optional[str] = None + self.caps: Dict[str, Any] = {} + self.pinned_scripts: Dict[str, Any] = {} self.error_handler = ErrorHandler() self._switch_to = SwitchTo(self) self._mobile = Mobile(self) @@ -443,7 +448,8 @@ def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None) elif "sessionId" not in params: params["sessionId"] = self.session_id - response = self.command_executor.execute(driver_command, params) + response = cast(RemoteConnection, self.command_executor).execute(driver_command, params) + if response: self.error_handler.check_response(response) response["value"] = self._unwrap_value(response.get("value", None)) @@ -606,7 +612,8 @@ def quit(self) -> None: self.execute(Command.QUIT) finally: self.stop_client() - self.command_executor.close() + executor = cast(RemoteConnection, self.command_executor) + executor.close() @property def current_window_handle(self) -> str: @@ -661,7 +668,7 @@ def print_page(self, print_options: Optional[PrintOptions] = None) -> str: -------- >>> driver.print_page() """ - options = {} + options: Union[Dict[str, Any], Any] = {} if print_options: options = print_options.to_dict() @@ -944,7 +951,10 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> list[WebElemen if isinstance(by, RelativeBy): _pkg = ".".join(__name__.split(".")[:-1]) - raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8") + raw_data = pkgutil.get_data(_pkg, "findElements.js") + if raw_data is None: + raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") + raw_function = raw_data.decode("utf8") find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" return self.execute_script(find_element_js, by.to_dict()) @@ -1416,7 +1426,7 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] @property - def virtual_authenticator_id(self) -> str: + def virtual_authenticator_id(self) -> Optional[str]: """Returns the id of the virtual authenticator. Example: