diff --git a/appium/webdriver/extensions/images_comparison.py b/appium/webdriver/extensions/images_comparison.py index 6f730815..cae80168 100644 --- a/appium/webdriver/extensions/images_comparison.py +++ b/appium/webdriver/extensions/images_comparison.py @@ -18,9 +18,11 @@ from ..mobilecommand import MobileCommand as Command +Base64Payload = Union[str, bytes] + class ImagesComparison(CanExecuteCommands): - def match_images_features(self, base64_image1: bytes, base64_image2: bytes, **opts: Any) -> Dict[str, Any]: + def match_images_features(self, base64_image1: Base64Payload, base64_image2: Base64Payload, **opts: Any) -> Dict[str, Any]: """Performs images matching by features. Read @@ -66,11 +68,16 @@ def match_images_features(self, base64_image1: bytes, base64_image2: bytes, **op rect2 (dict): The bounding rect for the `points2` array or a zero rect if not enough matching points were found. The rect is represented by a dictionary with 'x', 'y', 'width' and 'height' keys """ - options = {'mode': 'matchFeatures', 'firstImage': base64_image1, 'secondImage': base64_image2, 'options': opts} + options = { + 'mode': 'matchFeatures', + 'firstImage': _adjust_image_payload(base64_image1), + 'secondImage': _adjust_image_payload(base64_image2), + 'options': opts, + } return self.execute(Command.COMPARE_IMAGES, options)['value'] def find_image_occurrence( - self, base64_full_image: bytes, base64_partial_image: bytes, **opts: Any + self, base64_full_image: Base64Payload, base64_partial_image: Base64Payload, **opts: Any ) -> Dict[str, Union[bytes, Dict]]: """Performs images matching by template to find possible occurrence of the partial image in the full image. @@ -97,13 +104,15 @@ def find_image_occurrence( """ options = { 'mode': 'matchTemplate', - 'firstImage': base64_full_image, - 'secondImage': base64_partial_image, + 'firstImage': _adjust_image_payload(base64_full_image), + 'secondImage': _adjust_image_payload(base64_partial_image), 'options': opts, } return self.execute(Command.COMPARE_IMAGES, options)['value'] - def get_images_similarity(self, base64_image1: bytes, base64_image2: bytes, **opts: Any) -> Dict[str, Union[bytes, Dict]]: + def get_images_similarity( + self, base64_image1: Base64Payload, base64_image2: Base64Payload, **opts: Any + ) -> Dict[str, Union[bytes, Dict]]: """Performs images matching to calculate the similarity score between them. The flow there is similar to the one used in @@ -125,8 +134,17 @@ def get_images_similarity(self, base64_image1: bytes, base64_image2: bytes, **op score (float): The similarity score as a float number in range [0.0, 1.0]. 1.0 is the highest score (means both images are totally equal). """ - options = {'mode': 'getSimilarity', 'firstImage': base64_image1, 'secondImage': base64_image2, 'options': opts} + options = { + 'mode': 'getSimilarity', + 'firstImage': _adjust_image_payload(base64_image1), + 'secondImage': _adjust_image_payload(base64_image2), + 'options': opts, + } return self.execute(Command.COMPARE_IMAGES, options)['value'] def _add_commands(self) -> None: self.command_executor.add_command(Command.COMPARE_IMAGES, 'POST', '/session/$sessionId/appium/compare_images') + + +def _adjust_image_payload(payload: Base64Payload) -> str: + return payload if isinstance(payload, str) else payload.decode('utf-8')