diff --git a/py/selenium/webdriver/common/action_chains.py b/py/selenium/webdriver/common/action_chains.py index 7dcf5cae7f143..17aa1f3fba7a4 100644 --- a/py/selenium/webdriver/common/action_chains.py +++ b/py/selenium/webdriver/common/action_chains.py @@ -273,7 +273,7 @@ def pause(self, seconds: float | int) -> ActionChains: """Pause all inputs for the specified duration in seconds.""" self.w3c_actions.pointer_action.pause(seconds) - self.w3c_actions.key_action.pause(seconds) + self.w3c_actions.key_action.pause(int(seconds)) return self diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index e843752b7a27f..1b35fc871d629 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, Union +import warnings +from typing import Any, Callable, Optional, Union from selenium.webdriver.common.bidi.common import command_builder @@ -66,12 +67,23 @@ def from_json(cls, json: dict) -> "NavigationInfo": ------- NavigationInfo: A new instance of NavigationInfo. """ - return cls( - context=json.get("context"), - navigation=json.get("navigation"), - timestamp=json.get("timestamp"), - url=json.get("url"), - ) + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + navigation = json.get("navigation") + if navigation is not None and not isinstance(navigation, str): + raise ValueError("navigation must be a string") + + timestamp = json.get("timestamp") + if timestamp is None or not isinstance(timestamp, int): + raise ValueError("timestamp is required and must be an integer") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") + + return cls(context, navigation, timestamp, url) class BrowsingContextInfo: @@ -108,12 +120,25 @@ def from_json(cls, json: dict) -> "BrowsingContextInfo": BrowsingContextInfo: A new instance of BrowsingContextInfo. """ children = None - if json.get("children") is not None: - children = [BrowsingContextInfo.from_json(child) for child in json.get("children")] + raw_children = json.get("children") + if raw_children is not None and isinstance(raw_children, list): + children = [] + for child in raw_children: + if isinstance(child, dict): + children.append(BrowsingContextInfo.from_json(child)) + else: + warnings.warn(f"Unexpected child type in browsing context: {type(child)}") + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") return cls( - context=json.get("context"), - url=json.get("url"), + context=context, + url=url, children=children, parent=json.get("parent"), user_context=json.get("userContext"), @@ -148,12 +173,32 @@ def from_json(cls, json: dict) -> "DownloadWillBeginParams": ------- DownloadWillBeginParams: A new instance of DownloadWillBeginParams. """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + navigation = json.get("navigation") + if navigation is not None and not isinstance(navigation, str): + raise ValueError("navigation must be a string") + + timestamp = json.get("timestamp") + if timestamp is None or not isinstance(timestamp, int): + raise ValueError("timestamp is required and must be an integer") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") + + suggested_filename = json.get("suggestedFilename") + if suggested_filename is None or not isinstance(suggested_filename, str): + raise ValueError("suggestedFilename is required and must be a string") + return cls( - context=json.get("context"), - navigation=json.get("navigation"), - timestamp=json.get("timestamp"), - url=json.get("url"), - suggested_filename=json.get("suggestedFilename"), + context=context, + navigation=navigation, + timestamp=timestamp, + url=url, + suggested_filename=suggested_filename, ) @@ -186,12 +231,32 @@ def from_json(cls, json: dict) -> "UserPromptOpenedParams": ------- UserPromptOpenedParams: A new instance of UserPromptOpenedParams. """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + handler = json.get("handler") + if handler is None or not isinstance(handler, str): + raise ValueError("handler is required and must be a string") + + message = json.get("message") + if message is None or not isinstance(message, str): + raise ValueError("message is required and must be a string") + + type_value = json.get("type") + if type_value is None or not isinstance(type_value, str): + raise ValueError("type is required and must be a string") + + default_value = json.get("defaultValue") + if default_value is not None and not isinstance(default_value, str): + raise ValueError("defaultValue must be a string if provided") + return cls( - context=json.get("context"), - handler=json.get("handler"), - message=json.get("message"), - type=json.get("type"), - default_value=json.get("defaultValue"), + context=context, + handler=handler, + message=message, + type=type_value, + default_value=default_value, ) @@ -222,11 +287,27 @@ def from_json(cls, json: dict) -> "UserPromptClosedParams": ------- UserPromptClosedParams: A new instance of UserPromptClosedParams. """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + accepted = json.get("accepted") + if accepted is None or not isinstance(accepted, bool): + raise ValueError("accepted is required and must be a boolean") + + type_value = json.get("type") + if type_value is None or not isinstance(type_value, str): + raise ValueError("type is required and must be a string") + + user_text = json.get("userText") + if user_text is not None and not isinstance(user_text, str): + raise ValueError("userText must be a string if provided") + return cls( - context=json.get("context"), - accepted=json.get("accepted"), - type=json.get("type"), - user_text=json.get("userText"), + context=context, + accepted=accepted, + type=type_value, + user_text=user_text, ) @@ -253,9 +334,17 @@ def from_json(cls, json: dict) -> "HistoryUpdatedParams": ------- HistoryUpdatedParams: A new instance of HistoryUpdatedParams. """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") + return cls( - context=json.get("context"), - url=json.get("url"), + context=context, + url=url, ) @@ -278,7 +367,11 @@ def from_json(cls, json: dict) -> "BrowsingContextEvent": ------- BrowsingContextEvent: A new instance of BrowsingContextEvent. """ - return cls(event_class=json.get("event_class"), **json) + event_class = json.get("event_class") + if event_class is None or not isinstance(event_class, str): + raise ValueError("event_class is required and must be a string") + + return cls(event_class=event_class, **json) class BrowsingContext: @@ -339,7 +432,7 @@ def capture_screenshot( ------- str: The Base64-encoded screenshot. """ - params = {"context": context, "origin": origin} + params: dict[str, Any] = {"context": context, "origin": origin} if format is not None: params["format"] = format if clip is not None: @@ -383,7 +476,7 @@ def create( ------- str: The browsing context ID of the created navigable. """ - params = {"type": type} + params: dict[str, Any] = {"type": type} if reference_context is not None: params["referenceContext"] = reference_context if background is not None: @@ -411,7 +504,7 @@ def get_tree( ------- List[BrowsingContextInfo]: A list of browsing context information. """ - params = {} + params: dict[str, Any] = {} if max_depth is not None: params["maxDepth"] = max_depth if root is not None: @@ -434,7 +527,7 @@ def handle_user_prompt( accept: Whether to accept the prompt. user_text: The text to enter in the prompt. """ - params = {"context": context} + params: dict[str, Any] = {"context": context} if accept is not None: params["accept"] = accept if user_text is not None: @@ -464,7 +557,7 @@ def locate_nodes( ------- List[Dict]: A list of nodes. """ - params = {"context": context, "locator": locator} + params: dict[str, Any] = {"context": context, "locator": locator} if max_node_count is not None: params["maxNodeCount"] = max_node_count if serialization_options is not None: @@ -564,7 +657,7 @@ def reload( ------- Dict: A dictionary containing the navigation result. """ - params = {"context": context} + params: dict[str, Any] = {"context": context} if ignore_cache is not None: params["ignoreCache"] = ignore_cache if wait is not None: @@ -593,7 +686,7 @@ def set_viewport( ------ Exception: If the browsing context is not a top-level traversable. """ - params = {} + params: dict[str, Any] = {} if context is not None: params["context"] = context if viewport is not None: @@ -621,7 +714,7 @@ def traverse_history(self, context: str, delta: int) -> dict: result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) return result - def _on_event(self, event_name: str, callback: callable) -> int: + def _on_event(self, event_name: str, callback: Callable) -> int: """Set a callback function to subscribe to a browsing context event. Parameters: @@ -665,7 +758,7 @@ def _callback(event_data): return callback_id - def add_event_handler(self, event: str, callback: callable, contexts: Optional[list[str]] = None) -> int: + def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int: """Add an event handler to the browsing context. Parameters: @@ -710,15 +803,18 @@ def remove_event_handler(self, event: str, callback_id: int) -> None: except KeyError: raise Exception(f"Event {event} not found") - event = BrowsingContextEvent(event_name) + event_obj = BrowsingContextEvent(event_name) - self.conn.remove_callback(event, callback_id) - self.subscriptions[event_name].remove(callback_id) - if len(self.subscriptions[event_name]) == 0: - params = {"events": [event_name]} - session = Session(self.conn) - self.conn.execute(session.unsubscribe(**params)) - del self.subscriptions[event_name] + self.conn.remove_callback(event_obj, callback_id) + if event_name in self.subscriptions: + callbacks = self.subscriptions[event_name] + if callback_id in callbacks: + callbacks.remove(callback_id) + if not callbacks: + params = {"events": [event_name]} + session = Session(self.conn) + self.conn.execute(session.unsubscribe(**params)) + del self.subscriptions[event_name] def clear_event_handlers(self) -> None: """Clear all event handlers from the browsing context.""" diff --git a/py/selenium/webdriver/common/utils.py b/py/selenium/webdriver/common/utils.py index b04e2b0e40c30..7fade1fb45e3c 100644 --- a/py/selenium/webdriver/common/utils.py +++ b/py/selenium/webdriver/common/utils.py @@ -64,12 +64,12 @@ def find_connectable_ip(host: Union[str, bytes, bytearray, None], port: Optional for family, _, _, _, sockaddr in addrinfos: connectable = True if port: - connectable = is_connectable(port, sockaddr[0]) + connectable = is_connectable(port, str(sockaddr[0])) if connectable and family == socket.AF_INET: - return sockaddr[0] + return str(sockaddr[0]) if connectable and not ip and family == socket.AF_INET6: - ip = sockaddr[0] + ip = str(sockaddr[0]) return ip @@ -131,8 +131,7 @@ def keys_to_typing(value: Iterable[AnyKey]) -> list[str]: characters: list[str] = [] for val in value: if isinstance(val, Keys): - # Todo: Does this even work? - characters.append(val) + characters.append(str(val)) elif isinstance(val, (int, float)): characters.extend(str(val)) else: diff --git a/py/selenium/webdriver/remote/shadowroot.py b/py/selenium/webdriver/remote/shadowroot.py index 2d81f17e01426..e3603797e838c 100644 --- a/py/selenium/webdriver/remote/shadowroot.py +++ b/py/selenium/webdriver/remote/shadowroot.py @@ -16,8 +16,11 @@ # under the License. from hashlib import md5 as md5_hash +from typing import Union, TYPE_CHECKING +if TYPE_CHECKING: + from selenium.webdriver.support.relative_locator import RelativeBy -from ..common.by import By +from ..common.by import By, ByType from .command import Command @@ -43,7 +46,7 @@ def __repr__(self) -> str: def id(self) -> str: return self._id - def find_element(self, by: str = By.ID, value: str = None): + def find_element(self, by: "Union[ByType, RelativeBy]" = By.ID, value: str = None): """Find an element inside a shadow root given a By strategy and locator. @@ -82,7 +85,7 @@ def find_element(self, by: str = By.ID, value: str = None): return self._execute(Command.FIND_ELEMENT_FROM_SHADOW_ROOT, {"using": by, "value": value})["value"] - def find_elements(self, by: str = By.ID, value: str = None): + def find_elements(self, by: "Union[ByType, RelativeBy]" = By.ID, value: str = None): """Find elements inside a shadow root given a By strategy and locator. Parameters: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 149f12d8fe1a0..8ad8a182c8a99 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -30,7 +30,10 @@ from base64 import b64decode, urlsafe_b64encode from contextlib import asynccontextmanager, contextmanager from importlib import import_module -from typing import Any, Optional, Union +from typing import Any, Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from selenium.webdriver.support.relative_locator import RelativeBy from selenium.common.exceptions import ( InvalidArgumentException, @@ -46,7 +49,7 @@ from selenium.webdriver.common.bidi.session import Session from selenium.webdriver.common.bidi.storage import Storage from selenium.webdriver.common.bidi.webextension import WebExtension -from selenium.webdriver.common.by import By +from selenium.webdriver.common.by import By, ByType from selenium.webdriver.common.options import ArgOptions, BaseOptions from selenium.webdriver.common.print_page_options import PrintOptions from selenium.webdriver.common.timeouts import Timeouts @@ -55,8 +58,6 @@ VirtualAuthenticatorOptions, required_virtual_authenticator, ) -from selenium.webdriver.support.relative_locator import RelativeBy - from ..common.fedcm.dialog import Dialog from .bidi_connection import BidiConnection from .client_config import ClientConfig @@ -877,7 +878,7 @@ def timeouts(self, timeouts) -> None: """ _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] - def find_element(self, by=By.ID, value: Optional[str] = None) -> WebElement: + def find_element(self, by: "Union[ByType, RelativeBy]" = By.ID, value: Optional[str] = None) -> WebElement: """Find an element given a By strategy and locator. Parameters: @@ -913,7 +914,7 @@ def find_element(self, by=By.ID, value: Optional[str] = None) -> WebElement: return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] - def find_elements(self, by=By.ID, value: Optional[str] = None) -> list[WebElement]: + def find_elements(self, by: "Union[ByType, RelativeBy]" = By.ID, value: Optional[str] = None) -> list[WebElement]: """Find elements given a By strategy and locator. Parameters: diff --git a/py/selenium/webdriver/remote/webelement.py b/py/selenium/webdriver/remote/webelement.py index 0e5754d82b314..37043653c7a98 100644 --- a/py/selenium/webdriver/remote/webelement.py +++ b/py/selenium/webdriver/remote/webelement.py @@ -24,9 +24,14 @@ from base64 import b64decode, encodebytes from hashlib import md5 as md5_hash from io import BytesIO +from typing import Union, TYPE_CHECKING + +if TYPE_CHECKING: + from selenium.webdriver.support.relative_locator import RelativeBy + from selenium.common.exceptions import JavascriptException, WebDriverException -from selenium.webdriver.common.by import By +from selenium.webdriver.common.by import By, ByType from selenium.webdriver.common.utils import keys_to_typing from .command import Command @@ -572,7 +577,7 @@ def _execute(self, command, params=None): params["id"] = self._id return self._parent.execute(command, params) - def find_element(self, by=By.ID, value=None) -> WebElement: + def find_element(self, by: "Union[ByType, RelativeBy]" = By.ID, value=None) -> WebElement: """Find an element given a By strategy and locator. Parameters: @@ -601,7 +606,7 @@ def find_element(self, by=By.ID, value=None) -> WebElement: by, value = self._parent.locator_converter.convert(by, value) return self._execute(Command.FIND_CHILD_ELEMENT, {"using": by, "value": value})["value"] - def find_elements(self, by=By.ID, value=None) -> list[WebElement]: + def find_elements(self, by: "Union[ByType, RelativeBy]" = By.ID, value=None) -> list[WebElement]: """Find elements given a By strategy and locator. Parameters: diff --git a/py/selenium/webdriver/support/expected_conditions.py b/py/selenium/webdriver/support/expected_conditions.py index 06a5f36d9e7c6..bd4fdf4f7d6c4 100644 --- a/py/selenium/webdriver/support/expected_conditions.py +++ b/py/selenium/webdriver/support/expected_conditions.py @@ -17,7 +17,10 @@ import re from collections.abc import Iterable -from typing import Any, Callable, Literal, TypeVar, Union +from typing import Any, Callable, Literal, TypeVar, Union, Tuple + +from selenium.webdriver.common.by import ByType +from selenium.webdriver.support.relative_locator import RelativeBy from selenium.common.exceptions import ( NoAlertPresentException, @@ -38,6 +41,7 @@ T = TypeVar("T") WebDriverOrWebElement = Union[WebDriver, WebElement] +LocatorType = Union[Tuple[ByType, str], Tuple[RelativeBy, None]] def title_is(title: str) -> Callable[[WebDriver], bool]: @@ -79,7 +83,7 @@ def _predicate(driver: WebDriver): return _predicate -def presence_of_element_located(locator: tuple[str, str]) -> Callable[[WebDriverOrWebElement], WebElement]: +def presence_of_element_located(locator: LocatorType) -> Callable[[WebDriverOrWebElement], WebElement]: """An expectation for checking that an element is present on the DOM of a page. This does not necessarily mean that the element is visible. @@ -189,7 +193,7 @@ def _predicate(driver: WebDriver): def visibility_of_element_located( - locator: tuple[str, str], + locator: LocatorType, ) -> Callable[[WebDriverOrWebElement], Union[Literal[False], WebElement]]: """An expectation for checking that an element is present on the DOM of a page and visible. Visibility means that the element is not only displayed @@ -272,7 +276,7 @@ def _element_if_visible(element: WebElement, visibility: bool = True) -> Union[L return element if element.is_displayed() == visibility else False -def presence_of_all_elements_located(locator: tuple[str, str]) -> Callable[[WebDriverOrWebElement], list[WebElement]]: +def presence_of_all_elements_located(locator: LocatorType) -> Callable[[WebDriverOrWebElement], list[WebElement]]: """An expectation for checking that there is at least one element present on a web page. @@ -299,7 +303,7 @@ def _predicate(driver: WebDriverOrWebElement): return _predicate -def visibility_of_any_elements_located(locator: tuple[str, str]) -> Callable[[WebDriverOrWebElement], list[WebElement]]: +def visibility_of_any_elements_located(locator: LocatorType) -> Callable[[WebDriverOrWebElement], list[WebElement]]: """An expectation for checking that there is at least one element visible on a web page. @@ -327,7 +331,7 @@ def _predicate(driver: WebDriverOrWebElement): def visibility_of_all_elements_located( - locator: tuple[str, str], + locator: LocatorType, ) -> Callable[[WebDriverOrWebElement], Union[list[WebElement], Literal[False]]]: """An expectation for checking that all elements are present on the DOM of a page and visible. Visibility means that the elements are not only @@ -363,7 +367,7 @@ def _predicate(driver: WebDriverOrWebElement): return _predicate -def text_to_be_present_in_element(locator: tuple[str, str], text_: str) -> Callable[[WebDriverOrWebElement], bool]: +def text_to_be_present_in_element(locator: LocatorType, text_: str) -> Callable[[WebDriverOrWebElement], bool]: """An expectation for checking if the given text is present in the specified element. @@ -399,7 +403,7 @@ def _predicate(driver: WebDriverOrWebElement): def text_to_be_present_in_element_value( - locator: tuple[str, str], text_: str + locator: LocatorType, text_: str ) -> Callable[[WebDriverOrWebElement], bool]: """An expectation for checking if the given text is present in the element's value. @@ -436,7 +440,7 @@ def _predicate(driver: WebDriverOrWebElement): def text_to_be_present_in_element_attribute( - locator: tuple[str, str], attribute_: str, text_: str + locator: LocatorType, attribute_: str, text_: str ) -> Callable[[WebDriverOrWebElement], bool]: """An expectation for checking if the given text is present in the element's attribute. @@ -687,7 +691,7 @@ def _predicate(_): return _predicate -def element_located_to_be_selected(locator: tuple[str, str]) -> Callable[[WebDriverOrWebElement], bool]: +def element_located_to_be_selected(locator: LocatorType) -> Callable[[WebDriverOrWebElement], bool]: """An expectation for the element to be located is selected. Parameters: @@ -743,7 +747,7 @@ def _predicate(_): def element_located_selection_state_to_be( - locator: tuple[str, str], is_selected: bool + locator: LocatorType, is_selected: bool ) -> Callable[[WebDriverOrWebElement], bool]: """An expectation to locate an element and check if the selection state specified is in that state. @@ -858,7 +862,7 @@ def _predicate(driver: WebDriver): return _predicate -def element_attribute_to_include(locator: tuple[str, str], attribute_: str) -> Callable[[WebDriverOrWebElement], bool]: +def element_attribute_to_include(locator: LocatorType, attribute_: str) -> Callable[[WebDriverOrWebElement], bool]: """An expectation for checking if the given attribute is included in the specified element. diff --git a/py/test/unit/selenium/webdriver/support/test_expected_conditions_relative_by.py b/py/test/unit/selenium/webdriver/support/test_expected_conditions_relative_by.py new file mode 100644 index 0000000000000..9a6dd2eabdfde --- /dev/null +++ b/py/test/unit/selenium/webdriver/support/test_expected_conditions_relative_by.py @@ -0,0 +1,57 @@ +import pytest +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.relative_locator import locate_with +from selenium.webdriver.common.by import By +from unittest.mock import Mock + + +class TestExpectedConditionsRelativeBy: + """Test that expected conditions accept RelativeBy in type annotations""" + + def test_presence_of_element_located_accepts_relative_by(self): + """Test presence_of_element_located accepts RelativeBy""" + relative_by = locate_with(By.TAG_NAME, "div").above({By.ID: "footer"}) + condition = EC.presence_of_element_located(relative_by) + assert condition is not None + + def test_visibility_of_element_located_accepts_relative_by(self): + """Test visibility_of_element_located accepts RelativeBy""" + relative_by = locate_with(By.TAG_NAME, "button").near({By.CLASS_NAME: "submit"}) + condition = EC.visibility_of_element_located(relative_by) + assert condition is not None + + def test_presence_of_all_elements_located_accepts_relative_by(self): + """Test presence_of_all_elements_located accepts RelativeBy""" + relative_by = locate_with(By.CSS_SELECTOR, ".item").below({By.ID: "header"}) + condition = EC.presence_of_all_elements_located(relative_by) + assert condition is not None + + def test_visibility_of_any_elements_located_accepts_relative_by(self): + """Test visibility_of_any_elements_located accepts RelativeBy""" + relative_by = locate_with(By.TAG_NAME, "span").to_left_of({By.ID: "sidebar"}) + condition = EC.visibility_of_any_elements_located(relative_by) + assert condition is not None + + def test_text_to_be_present_in_element_accepts_relative_by(self): + """Test text_to_be_present_in_element accepts RelativeBy""" + relative_by = locate_with(By.TAG_NAME, "p").above({By.CLASS_NAME: "footer"}) + condition = EC.text_to_be_present_in_element(relative_by, "Hello") + assert condition is not None + + def test_element_to_be_clickable_accepts_relative_by(self): + """Test element_to_be_clickable accepts RelativeBy""" + relative_by = locate_with(By.TAG_NAME, "button").near({By.ID: "form"}) + condition = EC.element_to_be_clickable(relative_by) + assert condition is not None + + def test_invisibility_of_element_located_accepts_relative_by(self): + """Test invisibility_of_element_located accepts RelativeBy""" + relative_by = locate_with(By.CSS_SELECTOR, ".loading").above({By.ID: "content"}) + condition = EC.invisibility_of_element_located(relative_by) + assert condition is not None + + def test_element_located_to_be_selected_accepts_relative_by(self): + """Test element_located_to_be_selected accepts RelativeBy""" + relative_by = locate_with(By.TAG_NAME, "input").near({By.ID: "terms-label"}) + condition = EC.element_located_to_be_selected(relative_by) + assert condition is not None \ No newline at end of file diff --git a/py/test/unit/selenium/webdriver/test_relative_by_annotations.py b/py/test/unit/selenium/webdriver/test_relative_by_annotations.py new file mode 100644 index 0000000000000..0e149e7b86d18 --- /dev/null +++ b/py/test/unit/selenium/webdriver/test_relative_by_annotations.py @@ -0,0 +1,65 @@ +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.support.relative_locator import RelativeBy, locate_with +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.remote.webelement import WebElement +from selenium.webdriver.remote.shadowroot import ShadowRoot +from unittest.mock import Mock, MagicMock + + +class TestRelativeByAnnotations: + """Test that RelativeBy is properly accepted in type annotations""" + + def test_webdriver_find_element_accepts_relative_by(self): + """Test WebDriver.find_element accepts RelativeBy""" + driver = Mock(spec=WebDriver) + relative_by = locate_with(By.TAG_NAME, "div").above({By.ID: "footer"}) + + # This should not raise type checking errors + driver.find_element(by=relative_by) + driver.find_element(relative_by) + + def test_webdriver_find_elements_accepts_relative_by(self): + """Test WebDriver.find_elements accepts RelativeBy""" + driver = Mock(spec=WebDriver) + relative_by = locate_with(By.TAG_NAME, "div").below({By.ID: "header"}) + + # This should not raise type checking errors + driver.find_elements(by=relative_by) + driver.find_elements(relative_by) + + def test_webelement_find_element_accepts_relative_by(self): + """Test WebElement.find_element accepts RelativeBy""" + element = Mock(spec=WebElement) + relative_by = locate_with(By.TAG_NAME, "span").near({By.CLASS_NAME: "button"}) + + # This should not raise type checking errors + element.find_element(by=relative_by) + element.find_element(relative_by) + + def test_webelement_find_elements_accepts_relative_by(self): + """Test WebElement.find_elements accepts RelativeBy""" + element = Mock(spec=WebElement) + relative_by = locate_with(By.TAG_NAME, "input").to_left_of({By.ID: "submit"}) + + # This should not raise type checking errors + element.find_elements(by=relative_by) + element.find_elements(relative_by) + + def test_shadowroot_find_element_accepts_relative_by(self): + """Test ShadowRoot.find_element accepts RelativeBy""" + shadow_root = Mock(spec=ShadowRoot) + relative_by = locate_with(By.TAG_NAME, "button").to_right_of({By.ID: "cancel"}) + + # This should not raise type checking errors + shadow_root.find_element(by=relative_by) + shadow_root.find_element(relative_by) + + def test_shadowroot_find_elements_accepts_relative_by(self): + """Test ShadowRoot.find_elements accepts RelativeBy""" + shadow_root = Mock(spec=ShadowRoot) + relative_by = locate_with(By.CSS_SELECTOR, ".item").above({By.ID: "footer"}) + + # This should not raise type checking errors + shadow_root.find_elements(by=relative_by) + shadow_root.find_elements(relative_by) \ No newline at end of file