diff --git a/py/selenium/webdriver/common/actions/interaction.py b/py/selenium/webdriver/common/actions/interaction.py index 6225efcf143e7..413ee773f741c 100644 --- a/py/selenium/webdriver/common/actions/interaction.py +++ b/py/selenium/webdriver/common/actions/interaction.py @@ -16,11 +16,13 @@ # under the License. from typing import Union +from .input_device import InputDevice + KEY = "key" POINTER = "pointer" NONE = "none" WHEEL = "wheel" -SOURCE_TYPES = {KEY, POINTER, NONE} +SOURCE_TYPES = {KEY, POINTER, WHEEL, NONE} POINTER_MOUSE = "mouse" POINTER_TOUCH = "touch" @@ -32,7 +34,7 @@ class Interaction: PAUSE = "pause" - def __init__(self, source: str) -> None: + def __init__(self, source: InputDevice) -> None: self.source = source diff --git a/py/selenium/webdriver/common/actions/key_actions.py b/py/selenium/webdriver/common/actions/key_actions.py index b5052fe5b188e..bc23244478902 100644 --- a/py/selenium/webdriver/common/actions/key_actions.py +++ b/py/selenium/webdriver/common/actions/key_actions.py @@ -18,7 +18,7 @@ from __future__ import annotations from ..utils import keys_to_typing -from .interaction import KEY, POINTER, WHEEL, Interaction +from .interaction import KEY, Interaction from .key_input import KeyInput from .pointer_input import PointerInput from .wheel_input import WheelInput @@ -29,18 +29,7 @@ def __init__(self, source: KeyInput | PointerInput | WheelInput | None = None) - if source is None: source = KeyInput(KEY) self.input_source = source - - # Determine the correct source type string based on the input object - if isinstance(source, KeyInput): - source_type = KEY - elif isinstance(source, PointerInput): - source_type = POINTER - elif isinstance(source, WheelInput): - source_type = WHEEL - else: - source_type = KEY - - super().__init__(source_type) + super().__init__(source) def key_down(self, letter: str) -> KeyActions: return self._key_action("create_key_down", letter) @@ -60,6 +49,6 @@ def send_keys(self, text: str | list) -> KeyActions: return self def _key_action(self, action: str, letter) -> KeyActions: - meth = getattr(self.input_source, action) + meth = getattr(self.source, action) meth(letter) return self diff --git a/py/selenium/webdriver/common/actions/wheel_actions.py b/py/selenium/webdriver/common/actions/wheel_actions.py index f258f293a6d76..16da6bc56b527 100644 --- a/py/selenium/webdriver/common/actions/wheel_actions.py +++ b/py/selenium/webdriver/common/actions/wheel_actions.py @@ -17,14 +17,14 @@ from typing import Optional -from .interaction import Interaction +from .interaction import WHEEL, Interaction from .wheel_input import WheelInput class WheelActions(Interaction): def __init__(self, source: Optional[WheelInput] = None): if source is None: - source = WheelInput("wheel") + source = WheelInput(WHEEL) super().__init__(source) def pause(self, duration: float = 0):