Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from base64 import b64decode, urlsafe_b64encode
from contextlib import asynccontextmanager, contextmanager
from importlib import import_module
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Type, Union, cast

from selenium.common.exceptions import (
InvalidArgumentException,
Expand Down Expand Up @@ -122,16 +122,21 @@ def get_remote_connection(
candidates = [ChromeRemoteConnection, EdgeRemoteConnection, SafariRemoteConnection, FirefoxRemoteConnection]
handler = next((c for c in candidates if c.browser_name == capabilities.get("browserName")), RemoteConnection)

if hasattr(command_executor, "client_config") and command_executor.client_config:
remote_server_addr = command_executor.client_config.remote_server_addr
else:
remote_server_addr = command_executor

return handler(
remote_server_addr=command_executor,
remote_server_addr=remote_server_addr,
keep_alive=keep_alive,
ignore_proxy=ignore_local_proxy,
client_config=client_config,
)


def create_matches(options: list[BaseOptions]) -> dict:
capabilities = {"capabilities": {}}
capabilities: Dict[str, Any] = {"capabilities": {}}
opts = []
for opt in options:
opts.append(opt.to_capabilities())
Expand All @@ -154,9 +159,9 @@ def create_matches(options: list[BaseOptions]) -> dict:
for k, v in samesies.items():
always[k] = v

for i in opts:
for opt_dict in opts:
for k in always:
del i[k]
del opt_dict[k]

capabilities["capabilities"]["alwaysMatch"] = always
capabilities["capabilities"]["firstMatch"] = opts
Expand Down Expand Up @@ -196,7 +201,7 @@ def __init__(
file_detector: Optional[FileDetector] = None,
options: Optional[Union[BaseOptions, list[BaseOptions]]] = None,
locator_converter: Optional[LocatorConverter] = None,
web_element_cls: Optional[type] = None,
web_element_cls: Optional[Type[WebElement]] = None,
client_config: Optional[ClientConfig] = None,
) -> None:
"""Create a new driver that will issue commands using the wire
Expand Down Expand Up @@ -242,9 +247,9 @@ def __init__(
client_config=client_config,
)
self._is_remote = True
self.session_id = None
self.caps = {}
self.pinned_scripts = {}
self.session_id: Optional[str] = None
self.caps: Dict[str, Any] = {}
self.pinned_scripts: Dict[str, Any] = {}
self.error_handler = ErrorHandler()
self._switch_to = SwitchTo(self)
self._mobile = Mobile(self)
Expand Down Expand Up @@ -443,7 +448,8 @@ def execute(self, driver_command: str, params: Optional[dict[str, Any]] = None)
elif "sessionId" not in params:
params["sessionId"] = self.session_id

response = self.command_executor.execute(driver_command, params)
response = cast(RemoteConnection, self.command_executor).execute(driver_command, params)

if response:
self.error_handler.check_response(response)
response["value"] = self._unwrap_value(response.get("value", None))
Expand Down Expand Up @@ -606,7 +612,8 @@ def quit(self) -> None:
self.execute(Command.QUIT)
finally:
self.stop_client()
self.command_executor.close()
executor = cast(RemoteConnection, self.command_executor)
executor.close()

@property
def current_window_handle(self) -> str:
Expand Down Expand Up @@ -661,7 +668,7 @@ def print_page(self, print_options: Optional[PrintOptions] = None) -> str:
--------
>>> driver.print_page()
"""
options = {}
options: Union[Dict[str, Any], Any] = {}
if print_options:
options = print_options.to_dict()

Expand Down Expand Up @@ -944,7 +951,10 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> list[WebElemen

if isinstance(by, RelativeBy):
_pkg = ".".join(__name__.split(".")[:-1])
raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8")
raw_data = pkgutil.get_data(_pkg, "findElements.js")
if raw_data is None:
raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}")
raw_function = raw_data.decode("utf8")
find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);"
return self.execute_script(find_element_js, by.to_dict())

Expand Down Expand Up @@ -1416,7 +1426,7 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non
self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"]

@property
def virtual_authenticator_id(self) -> str:
def virtual_authenticator_id(self) -> Optional[str]:
"""Returns the id of the virtual authenticator.

Example:
Expand Down
Loading