Skip to content

Commit b73da5e

Browse files
[py] Fix: Mypy type annotation errors in remote/webdriver.py (#15900)
Co-authored-by: Navin Chandra <[email protected]>
1 parent e033b6f commit b73da5e

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

py/selenium/webdriver/remote/webdriver.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from base64 import b64decode, urlsafe_b64encode
3030
from contextlib import asynccontextmanager, contextmanager
3131
from importlib import import_module
32-
from typing import Any, Optional, Union
32+
from typing import Any, Dict, Optional, Type, Union, cast
3333

3434
from selenium.common.exceptions import (
3535
InvalidArgumentException,
@@ -122,16 +122,21 @@ def get_remote_connection(
122122
candidates = [ChromeRemoteConnection, EdgeRemoteConnection, SafariRemoteConnection, FirefoxRemoteConnection]
123123
handler = next((c for c in candidates if c.browser_name == capabilities.get("browserName")), RemoteConnection)
124124

125+
if hasattr(command_executor, "client_config") and command_executor.client_config:
126+
remote_server_addr = command_executor.client_config.remote_server_addr
127+
else:
128+
remote_server_addr = command_executor
129+
125130
return handler(
126-
remote_server_addr=command_executor,
131+
remote_server_addr=remote_server_addr,
127132
keep_alive=keep_alive,
128133
ignore_proxy=ignore_local_proxy,
129134
client_config=client_config,
130135
)
131136

132137

133138
def create_matches(options: list[BaseOptions]) -> dict:
134-
capabilities = {"capabilities": {}}
139+
capabilities: Dict[str, Any] = {"capabilities": {}}
135140
opts = []
136141
for opt in options:
137142
opts.append(opt.to_capabilities())
@@ -154,9 +159,9 @@ def create_matches(options: list[BaseOptions]) -> dict:
154159
for k, v in samesies.items():
155160
always[k] = v
156161

157-
for i in opts:
162+
for opt_dict in opts:
158163
for k in always:
159-
del i[k]
164+
del opt_dict[k]
160165

161166
capabilities["capabilities"]["alwaysMatch"] = always
162167
capabilities["capabilities"]["firstMatch"] = opts
@@ -196,7 +201,7 @@ def __init__(
196201
file_detector: Optional[FileDetector] = None,
197202
options: Optional[Union[BaseOptions, list[BaseOptions]]] = None,
198203
locator_converter: Optional[LocatorConverter] = None,
199-
web_element_cls: Optional[type] = None,
204+
web_element_cls: Optional[Type[WebElement]] = None,
200205
client_config: Optional[ClientConfig] = None,
201206
) -> None:
202207
"""Create a new driver that will issue commands using the wire
@@ -242,9 +247,9 @@ def __init__(
242247
client_config=client_config,
243248
)
244249
self._is_remote = True
245-
self.session_id = None
246-
self.caps = {}
247-
self.pinned_scripts = {}
250+
self.session_id: Optional[str] = None
251+
self.caps: Dict[str, Any] = {}
252+
self.pinned_scripts: Dict[str, Any] = {}
248253
self.error_handler = ErrorHandler()
249254
self._switch_to = SwitchTo(self)
250255
self._mobile = Mobile(self)
@@ -443,7 +448,8 @@ def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None)
443448
elif "sessionId" not in params:
444449
params["sessionId"] = self.session_id
445450

446-
response = self.command_executor.execute(driver_command, params)
451+
response = cast(RemoteConnection, self.command_executor).execute(driver_command, params)
452+
447453
if response:
448454
self.error_handler.check_response(response)
449455
response["value"] = self._unwrap_value(response.get("value", None))
@@ -606,7 +612,8 @@ def quit(self) -> None:
606612
self.execute(Command.QUIT)
607613
finally:
608614
self.stop_client()
609-
self.command_executor.close()
615+
executor = cast(RemoteConnection, self.command_executor)
616+
executor.close()
610617

611618
@property
612619
def current_window_handle(self) -> str:
@@ -661,7 +668,7 @@ def print_page(self, print_options: Optional[PrintOptions] = None) -> str:
661668
--------
662669
>>> driver.print_page()
663670
"""
664-
options = {}
671+
options: Union[Dict[str, Any], Any] = {}
665672
if print_options:
666673
options = print_options.to_dict()
667674

@@ -944,7 +951,10 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> list[WebElemen
944951

945952
if isinstance(by, RelativeBy):
946953
_pkg = ".".join(__name__.split(".")[:-1])
947-
raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8")
954+
raw_data = pkgutil.get_data(_pkg, "findElements.js")
955+
if raw_data is None:
956+
raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}")
957+
raw_function = raw_data.decode("utf8")
948958
find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);"
949959
return self.execute_script(find_element_js, by.to_dict())
950960

@@ -1416,7 +1426,7 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non
14161426
self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"]
14171427

14181428
@property
1419-
def virtual_authenticator_id(self) -> str:
1429+
def virtual_authenticator_id(self) -> Optional[str]:
14201430
"""Returns the id of the virtual authenticator.
14211431
14221432
Example:

0 commit comments

Comments
 (0)