Skip to content

Commit 4af29a0

Browse files
committed
Fix mypy type errors in remote/webdriver.py
1 parent f52bb20 commit 4af29a0

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

py/selenium/webdriver/remote/webdriver.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from base64 import b64decode, urlsafe_b64encode
3131
from contextlib import asynccontextmanager, contextmanager
3232
from importlib import import_module
33-
from typing import Any, Optional, Union
33+
from typing import Any, Dict, Optional, Type, Union, cast
3434

3535
from selenium.common.exceptions import (
3636
InvalidArgumentException,
@@ -123,16 +123,21 @@ def get_remote_connection(
123123
candidates = [ChromeRemoteConnection, EdgeRemoteConnection, SafariRemoteConnection, FirefoxRemoteConnection]
124124
handler = next((c for c in candidates if c.browser_name == capabilities.get("browserName")), RemoteConnection)
125125

126+
if hasattr(command_executor, "client_config") and command_executor.client_config:
127+
remote_server_addr = command_executor.client_config.remote_server_addr
128+
else:
129+
None
130+
126131
return handler(
127-
remote_server_addr=command_executor,
132+
remote_server_addr=remote_server_addr,
128133
keep_alive=keep_alive,
129134
ignore_proxy=ignore_local_proxy,
130135
client_config=client_config,
131136
)
132137

133138

134139
def create_matches(options: list[BaseOptions]) -> dict:
135-
capabilities = {"capabilities": {}}
140+
capabilities: Dict[str, Any] = {"capabilities": {}}
136141
opts = []
137142
for opt in options:
138143
opts.append(opt.to_capabilities())
@@ -155,9 +160,9 @@ def create_matches(options: list[BaseOptions]) -> dict:
155160
for k, v in samesies.items():
156161
always[k] = v
157162

158-
for i in opts:
163+
for opt_dict in opts:
159164
for k in always:
160-
del i[k]
165+
del opt_dict[k]
161166

162167
capabilities["capabilities"]["alwaysMatch"] = always
163168
capabilities["capabilities"]["firstMatch"] = opts
@@ -197,7 +202,7 @@ def __init__(
197202
file_detector: Optional[FileDetector] = None,
198203
options: Optional[Union[BaseOptions, list[BaseOptions]]] = None,
199204
locator_converter: Optional[LocatorConverter] = None,
200-
web_element_cls: Optional[type] = None,
205+
web_element_cls: Optional[Type[WebElement]] = None,
201206
client_config: Optional[ClientConfig] = None,
202207
) -> None:
203208
"""Create a new driver that will issue commands using the wire
@@ -243,9 +248,9 @@ def __init__(
243248
client_config=client_config,
244249
)
245250
self._is_remote = True
246-
self.session_id = None
247-
self.caps = {}
248-
self.pinned_scripts = {}
251+
self.session_id: Optional[str] = None
252+
self.caps: Dict[str, Any] = {}
253+
self.pinned_scripts: Dict[str, Any] = {}
249254
self.error_handler = ErrorHandler()
250255
self._switch_to = SwitchTo(self)
251256
self._mobile = Mobile(self)
@@ -442,7 +447,9 @@ def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None)
442447
elif "sessionId" not in params:
443448
params["sessionId"] = self.session_id
444449

445-
response = self.command_executor.execute(driver_command, params)
450+
executor = cast(RemoteConnection, self.command_executor)
451+
response = executor.execute(driver_command, params)
452+
446453
if response:
447454
self.error_handler.check_response(response)
448455
response["value"] = self._unwrap_value(response.get("value", None))
@@ -605,7 +612,8 @@ def quit(self) -> None:
605612
self.execute(Command.QUIT)
606613
finally:
607614
self.stop_client()
608-
self.command_executor.close()
615+
executor = cast(RemoteConnection, self.command_executor)
616+
executor.close()
609617

610618
@property
611619
def current_window_handle(self) -> str:
@@ -660,7 +668,7 @@ def print_page(self, print_options: Optional[PrintOptions] = None) -> str:
660668
--------
661669
>>> driver.print_page()
662670
"""
663-
options = {}
671+
options: Union[Dict[str, Any], Any] = {}
664672
if print_options:
665673
options = print_options.to_dict()
666674

@@ -943,7 +951,10 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> list[WebElemen
943951

944952
if isinstance(by, RelativeBy):
945953
_pkg = ".".join(__name__.split(".")[:-1])
946-
raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8")
954+
raw_data = pkgutil.get_data(_pkg, "findElements.js")
955+
if raw_data is None:
956+
raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}")
957+
raw_function = raw_data.decode("utf8")
947958
find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);"
948959
return self.execute_script(find_element_js, by.to_dict())
949960

@@ -1397,7 +1408,7 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non
13971408
self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"]
13981409

13991410
@property
1400-
def virtual_authenticator_id(self) -> str:
1411+
def virtual_authenticator_id(self) -> Optional[str]:
14011412
"""Returns the id of the virtual authenticator.
14021413
14031414
Example:

0 commit comments

Comments
 (0)