diff --git a/examples/python/visualization/README.md b/examples/python/visualization/README.md new file mode 100644 index 00000000..6be5c9df --- /dev/null +++ b/examples/python/visualization/README.md @@ -0,0 +1,15 @@ +# Visualization Example + +This example demonstrates how to use the Visualizer in Model API. + +## Prerequisites + +Install Model API from source. Please refer to the main [README](../../../README.md) for details. + +## Run example + +To run the example, please execute the following command: + +```bash +python run.py --image <path_to_image> --model <path_to_model>.xml --output <path_to_output> +``` diff --git a/examples/python/visualization/run.py b/examples/python/visualization/run.py new file mode 100644 index 00000000..c31c329f --- /dev/null +++ b/examples/python/visualization/run.py @@ -0,0 +1,38 @@ +"""Visualization Example.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +from argparse import Namespace + +import cv2 +import numpy as np +from PIL import Image + +from model_api.models import Model +from model_api.visualizer import Visualizer + + +def main(args: Namespace): + image = Image.open(args.image) + + model = Model.create_model(args.model) + + image_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + predictions = model(image_array) + visualizer = Visualizer() + + if args.output: + visualizer.save(image=image, result=predictions, path=args.output) + else: + visualizer.show(image=image, result=predictions) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image", type=str, required=True) + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--output", type=str, required=False) + args = parser.parse_args() + main(args) diff --git a/src/python/model_api/visualizer/__init__.py b/src/python/model_api/visualizer/__init__.py index 924d4d0e..d4833637 100644 --- a/src/python/model_api/visualizer/__init__.py +++ b/src/python/model_api/visualizer/__init__.py @@ -4,8
+4,19 @@ # SPDX-License-Identifier: Apache-2.0 from .layout import Flatten, HStack, Layout -from .primitive import BoundingBox, Label, Overlay, Polygon +from .primitive import BoundingBox, Keypoint, Label, Overlay, Polygon from .scene import Scene from .visualizer import Visualizer -__all__ = ["BoundingBox", "Label", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"] +__all__ = [ + "BoundingBox", + "Keypoint", + "Label", + "Overlay", + "Polygon", + "Scene", + "Visualizer", + "Layout", + "Flatten", + "HStack", +] diff --git a/src/python/model_api/visualizer/primitive/__init__.py b/src/python/model_api/visualizer/primitive/__init__.py index ba6c135c..42ad6d9f 100644 --- a/src/python/model_api/visualizer/primitive/__init__.py +++ b/src/python/model_api/visualizer/primitive/__init__.py @@ -4,9 +4,10 @@ # SPDX-License-Identifier: Apache-2.0 from .bounding_box import BoundingBox +from .keypoints import Keypoint from .label import Label from .overlay import Overlay from .polygon import Polygon from .primitive import Primitive -__all__ = ["Primitive", "BoundingBox", "Label", "Overlay", "Polygon"] +__all__ = ["Primitive", "BoundingBox", "Label", "Overlay", "Polygon", "Keypoint"] diff --git a/src/python/model_api/visualizer/primitive/keypoints.py b/src/python/model_api/visualizer/primitive/keypoints.py new file mode 100644 index 00000000..66a2a0f4 --- /dev/null +++ b/src/python/model_api/visualizer/primitive/keypoints.py @@ -0,0 +1,65 @@ +"""Keypoints primitive.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from .primitive import Primitive + + +class Keypoint(Primitive): + """Keypoint primitive. + + Args: + keypoints (np.ndarray): Keypoints. Shape: (N, 2) + scores (np.ndarray | None): Scores. Shape: (N,). Defaults to None. + color (str | tuple[int, int, int]): Color of the keypoints. Defaults to "purple". 
+ """ + + def __init__( + self, + keypoints: np.ndarray, + scores: Union[np.ndarray, None] = None, + color: Union[str, tuple[int, int, int]] = "purple", + keypoint_size: int = 3, + ) -> None: + self.keypoints = self._validate_keypoints(keypoints) + self.scores = scores + self.color = color + self.keypoint_size = keypoint_size + + def compute(self, image: Image) -> Image: + """Draw keypoints on the image.""" + draw = ImageDraw.Draw(image) + for keypoint in self.keypoints: + draw.ellipse( + ( + keypoint[0] - self.keypoint_size, + keypoint[1] - self.keypoint_size, + keypoint[0] + self.keypoint_size, + keypoint[1] + self.keypoint_size, + ), + fill=self.color, + ) + + if self.scores is not None: + font = ImageFont.load_default(size=18) + for score, keypoint in zip(self.scores, self.keypoints): + textbox = draw.textbbox((0, 0), f"{score:.2f}", font=font) + draw.text( + (keypoint[0] - textbox[2] // 2, keypoint[1] + self.keypoint_size), + f"{score:.2f}", + font=font, + fill=self.color, + ) + return image + + def _validate_keypoints(self, keypoints: np.ndarray) -> np.ndarray: + if keypoints.shape[1] != 2: + msg = "Keypoints must have shape (N, 2)" + raise ValueError(msg) + return keypoints diff --git a/src/python/model_api/visualizer/primitive/polygon.py b/src/python/model_api/visualizer/primitive/polygon.py index a6ccc99c..74357264 100644 --- a/src/python/model_api/visualizer/primitive/polygon.py +++ b/src/python/model_api/visualizer/primitive/polygon.py @@ -5,16 +5,19 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING import cv2 -from PIL import Image, ImageDraw +from PIL import Image, ImageColor, ImageDraw from .primitive import Primitive if TYPE_CHECKING: import numpy as np +logger = logging.getLogger(__name__) + class Polygon(Primitive): """Polygon primitive. 
@@ -38,9 +41,13 @@ def __init__( points: list[tuple[int, int]] | None = None, mask: np.ndarray | None = None, color: str | tuple[int, int, int] = "blue", + opacity: float = 0.4, + outline_width: int = 2, ) -> None: self.points = self._get_points(points, mask) self.color = color + self.opacity = opacity + self.outline_width = outline_width def _get_points(self, points: list[tuple[int, int]] | None, mask: np.ndarray | None) -> list[tuple[int, int]]: """Get points from either points or mask. @@ -76,6 +83,13 @@ def _get_points_from_mask(self, mask: np.ndarray) -> list[tuple[int, int]]: List of points. """ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + # In case of multiple contours, use the one with the largest area + if len(contours) > 1: + logger.warning("Multiple contours found in the mask. Using the largest one.") + contours = sorted(contours, key=cv2.contourArea, reverse=True) + if len(contours) == 0: + msg = "No contours found in the mask." + raise ValueError(msg) points_ = contours[0].squeeze().tolist() return [tuple(point) for point in points_] @@ -88,6 +102,8 @@ def compute(self, image: Image) -> Image: Args: image (Image): Image to draw the polygon on. Returns: Image with the polygon drawn on it. """ - draw = ImageDraw.Draw(image) - draw.polygon(self.points, fill=self.color) + draw = ImageDraw.Draw(image, "RGBA") + # Draw polygon with darker edge and a semi-transparent fill.
+ ink = ImageColor.getrgb(self.color) + draw.polygon(self.points, fill=(*ink, int(255 * self.opacity)), outline=self.color, width=self.outline_width) return image diff --git a/src/python/model_api/visualizer/scene/__init__.py b/src/python/model_api/visualizer/scene/__init__.py index 84469928..77d6d309 100644 --- a/src/python/model_api/visualizer/scene/__init__.py +++ b/src/python/model_api/visualizer/scene/__init__.py @@ -8,13 +8,14 @@ from .detection import DetectionScene from .keypoint import KeypointScene from .scene import Scene -from .segmentation import SegmentationScene +from .segmentation import InstanceSegmentationScene, SegmentationScene from .visual_prompting import VisualPromptingScene __all__ = [ "AnomalyScene", "ClassificationScene", "DetectionScene", + "InstanceSegmentationScene", "KeypointScene", "Scene", "SegmentationScene", diff --git a/src/python/model_api/visualizer/scene/detection.py b/src/python/model_api/visualizer/scene/detection.py index aee838c3..7ffa3e97 100644 --- a/src/python/model_api/visualizer/scene/detection.py +++ b/src/python/model_api/visualizer/scene/detection.py @@ -32,8 +32,10 @@ def _get_overlays(self, result: DetectionResult) -> list[Overlay]: label_index_mapping = dict(zip(result.labels, result.label_names)) for label_index, label_name in label_index_mapping.items(): # Index 0 as it assumes only one batch - saliency_map = cv2.applyColorMap(result.saliency_map[0][label_index], cv2.COLORMAP_JET) - overlays.append(Overlay(saliency_map, label=label_name.title())) + if result.saliency_map is not None and result.saliency_map.size > 0: + saliency_map = cv2.applyColorMap(result.saliency_map[0][label_index], cv2.COLORMAP_JET) + saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB) + overlays.append(Overlay(saliency_map, label=label_name.title())) return overlays def _get_bounding_boxes(self, result: DetectionResult) -> list[BoundingBox]: diff --git a/src/python/model_api/visualizer/scene/keypoint.py 
b/src/python/model_api/visualizer/scene/keypoint.py index 3e34711c..0dd7ac99 100644 --- a/src/python/model_api/visualizer/scene/keypoint.py +++ b/src/python/model_api/visualizer/scene/keypoint.py @@ -3,9 +3,13 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from typing import Union + +from PIL import Image + from model_api.models.result import DetectedKeypoints from model_api.visualizer.layout import Flatten, Layout -from model_api.visualizer.primitive import Overlay +from model_api.visualizer.primitive import Keypoint from .scene import Scene @@ -13,9 +17,16 @@ class KeypointScene(Scene): """Keypoint Scene.""" - def __init__(self, result: DetectedKeypoints) -> None: - self.result = result + def __init__(self, image: Image, result: DetectedKeypoints, layout: Union[Layout, None] = None) -> None: + super().__init__( + base=image, + keypoints=self._get_keypoints(result), + layout=layout, + ) + + def _get_keypoints(self, result: DetectedKeypoints) -> list[Keypoint]: + return [Keypoint(result.keypoints, result.scores)] @property def default_layout(self) -> Layout: - return Flatten(Overlay) + return Flatten(Keypoint) diff --git a/src/python/model_api/visualizer/scene/scene.py b/src/python/model_api/visualizer/scene/scene.py index d798a0da..b8c22318 100644 --- a/src/python/model_api/visualizer/scene/scene.py +++ b/src/python/model_api/visualizer/scene/scene.py @@ -10,7 +10,7 @@ import numpy as np from PIL import Image -from model_api.visualizer.primitive import BoundingBox, Label, Overlay, Polygon, Primitive +from model_api.visualizer.primitive import BoundingBox, Keypoint, Label, Overlay, Polygon, Primitive if TYPE_CHECKING: from pathlib import Path @@ -31,6 +31,7 @@ def __init__( label: Label | list[Label] | None = None, overlay: Overlay | list[Overlay] | np.ndarray | None = None, polygon: Polygon | list[Polygon] | None = None, + keypoints: Keypoint | list[Keypoint] | np.ndarray | None = None, layout: Layout | None = None, ) -> None: 
self.base = base @@ -38,6 +39,7 @@ def __init__( self.bounding_box = self._to_bounding_box(bounding_box) self.label = self._to_label(label) self.polygon = self._to_polygon(polygon) + self.keypoints = self._to_keypoints(keypoints) self.layout = layout def show(self) -> None: @@ -60,6 +62,8 @@ def has_primitives(self, primitive: type[Primitive]) -> bool: return bool(self.label) if primitive == Polygon: return bool(self.polygon) + if primitive == Keypoint: + return bool(self.keypoints) return False def get_primitives(self, primitive: type[Primitive]) -> list[Primitive]: @@ -86,6 +90,8 @@ def get_primitives(self, primitive: type[Primitive]) -> list[Primitive]: primitives = cast("list[Primitive]", self.label) elif primitive == Polygon: primitives = cast("list[Primitive]", self.polygon) + elif primitive == Keypoint: + primitives = cast("list[Primitive]", self.keypoints) else: msg = f"Primitive {primitive} not found" raise ValueError(msg) @@ -119,3 +125,10 @@ def _to_polygon(self, polygon: Polygon | list[Polygon] | None) -> list[Polygon] if isinstance(polygon, Polygon): return [polygon] return polygon + + def _to_keypoints(self, keypoints: Keypoint | list[Keypoint] | np.ndarray | None) -> list[Keypoint] | None: + if isinstance(keypoints, Keypoint): + return [keypoints] + if isinstance(keypoints, np.ndarray): + return [Keypoint(keypoints)] + return keypoints diff --git a/src/python/model_api/visualizer/scene/segmentation.py b/src/python/model_api/visualizer/scene/segmentation.py deleted file mode 100644 index e666804e..00000000 --- a/src/python/model_api/visualizer/scene/segmentation.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Segmentation Scene.""" - -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -from model_api.models.result import InstanceSegmentationResult - -from .scene import Scene - - -class SegmentationScene(Scene): - """Segmentation Scene.""" - - def __init__(self, result: InstanceSegmentationResult) -> None: - self.result = result diff 
--git a/src/python/model_api/visualizer/scene/segmentation/__init__.py b/src/python/model_api/visualizer/scene/segmentation/__init__.py new file mode 100644 index 00000000..e55807b2 --- /dev/null +++ b/src/python/model_api/visualizer/scene/segmentation/__init__.py @@ -0,0 +1,12 @@ +"""Segmentation Scene.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .instance_segmentation import InstanceSegmentationScene +from .segmentation import SegmentationScene + +__all__ = [ + "InstanceSegmentationScene", + "SegmentationScene", +] diff --git a/src/python/model_api/visualizer/scene/segmentation/instance_segmentation.py b/src/python/model_api/visualizer/scene/segmentation/instance_segmentation.py new file mode 100644 index 00000000..f3dbcc8a --- /dev/null +++ b/src/python/model_api/visualizer/scene/segmentation/instance_segmentation.py @@ -0,0 +1,69 @@ +"""Instance Segmentation Scene.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import random +from typing import Union + +import cv2 +from PIL import Image + +from model_api.models.result import InstanceSegmentationResult +from model_api.visualizer.layout import Flatten, HStack, Layout +from model_api.visualizer.primitive import BoundingBox, Label, Overlay, Polygon +from model_api.visualizer.scene import Scene + + +class InstanceSegmentationScene(Scene): + """Instance Segmentation Scene.""" + + def __init__(self, image: Image, result: InstanceSegmentationResult, layout: Union[Layout, None] = None) -> None: + # nosec as random is used for color generation + self.color_per_label = {label: f"#{random.randint(0, 0xFFFFFF):06x}" for label in set(result.label_names)} # noqa: S311 + super().__init__( + base=image, + label=self._get_labels(result), + overlay=self._get_overlays(result), + polygon=self._get_polygons(result), + layout=layout, + ) + + def _get_labels(self, result: InstanceSegmentationResult) -> list[Label]: + # add only unique labels + 
labels = [] + for label_name in set(result.label_names): + labels.append(Label(label=label_name, bg_color=self.color_per_label[label_name])) + return labels + + def _get_polygons(self, result: InstanceSegmentationResult) -> list[Polygon]: + polygons = [] + for mask, label_name in zip(result.masks, result.label_names): + polygons.append(Polygon(mask=mask, color=self.color_per_label[label_name])) + return polygons + + def _get_bounding_boxes(self, result: InstanceSegmentationResult) -> list[BoundingBox]: + bounding_boxes = [] + for bbox, label_name, score in zip(result.bboxes, result.label_names, result.scores): + x1, y1, x2, y2 = bbox + label = f"{label_name} ({score:.2f})" + bounding_boxes.append( + BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2, label=label, color=self.color_per_label[label_name]), + ) + return bounding_boxes + + def _get_overlays(self, result: InstanceSegmentationResult) -> list[Overlay]: + overlays = [] + if len(result.saliency_map) > 0: + labels_label_names_mapping = dict(zip(result.labels, result.label_names)) + for label, label_name in labels_label_names_mapping.items(): + saliency_map = result.saliency_map[label - 1] + saliency_map = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET) + saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB) + overlays.append(Overlay(saliency_map, label=f"{label_name.title()} Saliency Map")) + return overlays + + @property + def default_layout(self) -> Layout: + # by default bounding box is not shown. 
+ return HStack(Flatten(Label, Polygon), Overlay) diff --git a/src/python/model_api/visualizer/scene/segmentation/segmentation.py b/src/python/model_api/visualizer/scene/segmentation/segmentation.py new file mode 100644 index 00000000..5129b3f9 --- /dev/null +++ b/src/python/model_api/visualizer/scene/segmentation/segmentation.py @@ -0,0 +1,48 @@ +"""Segmentation Scene.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import cv2 +import numpy as np +from PIL import Image + +from model_api.models.result import ImageResultWithSoftPrediction +from model_api.visualizer.layout import HStack, Layout +from model_api.visualizer.primitive import Overlay +from model_api.visualizer.scene import Scene + + +class SegmentationScene(Scene): + """Segmentation Scene.""" + + def __init__(self, image: Image, result: ImageResultWithSoftPrediction, layout: Union[Layout, None] = None) -> None: + super().__init__( + base=image, + overlay=self._get_overlays(result), + layout=layout, + ) + + def _get_overlays(self, result: ImageResultWithSoftPrediction) -> list[Overlay]: + overlays = [] + # Use the hard prediction to get the overlays + hard_prediction = result.resultImage # shape H,W + num_classes = hard_prediction.max() + for i in range(1, num_classes + 1): # ignore background + class_map = (hard_prediction == i).astype(np.uint8) * 255 + class_map = cv2.applyColorMap(class_map, cv2.COLORMAP_JET) + class_map = cv2.cvtColor(class_map, cv2.COLOR_BGR2RGB) + overlays.append(Overlay(class_map, label=f"Class {i}")) + + # Add saliency map + if result.saliency_map.size > 0: + saliency_map = cv2.cvtColor(result.saliency_map, cv2.COLOR_BGR2RGB) + overlays.append(Overlay(saliency_map, label="Saliency Map")) + + return overlays + + @property + def default_layout(self) -> Layout: + return HStack(Overlay) diff --git a/src/python/model_api/visualizer/visualizer.py b/src/python/model_api/visualizer/visualizer.py index e665e59a..5489442e 
100644 --- a/src/python/model_api/visualizer/visualizer.py +++ b/src/python/model_api/visualizer/visualizer.py @@ -11,12 +11,23 @@ from model_api.models.result import ( AnomalyResult, ClassificationResult, + DetectedKeypoints, DetectionResult, + ImageResultWithSoftPrediction, + InstanceSegmentationResult, Result, ) from .layout import Layout -from .scene import AnomalyScene, ClassificationScene, DetectionScene, Scene +from .scene import ( + AnomalyScene, + ClassificationScene, + DetectionScene, + InstanceSegmentationScene, + KeypointScene, + Scene, + SegmentationScene, +) class Visualizer: @@ -39,8 +50,16 @@ def _scene_from_result(self, image: Image, result: Result) -> Scene: scene = AnomalyScene(image, result, self.layout) elif isinstance(result, ClassificationResult): scene = ClassificationScene(image, result, self.layout) + elif isinstance(result, InstanceSegmentationResult): + # Note: This has to be before DetectionScene because InstanceSegmentationResult is a subclass + # of DetectionResult + scene = InstanceSegmentationScene(image, result, self.layout) + elif isinstance(result, ImageResultWithSoftPrediction): + scene = SegmentationScene(image, result, self.layout) elif isinstance(result, DetectionResult): scene = DetectionScene(image, result, self.layout) + elif isinstance(result, DetectedKeypoints): + scene = KeypointScene(image, result, self.layout) else: msg = f"Unsupported result type: {type(result)}" raise ValueError(msg) diff --git a/tests/python/unit/visualizer/test_primitive.py b/tests/python/unit/visualizer/test_primitive.py index 85ef4447..0c2e028c 100644 --- a/tests/python/unit/visualizer/test_primitive.py +++ b/tests/python/unit/visualizer/test_primitive.py @@ -5,9 +5,10 @@ import numpy as np import PIL +import pytest from PIL import ImageDraw -from model_api.visualizer import BoundingBox, Label, Overlay, Polygon +from model_api.visualizer import BoundingBox, Keypoint, Label, Overlay, Polygon def test_overlay(mock_image: PIL.Image): @@ -39,8 
+40,13 @@ def test_polygon(mock_image: PIL.Image): # Test from points expected_image = mock_image.copy() draw = ImageDraw.Draw(expected_image) - draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red") - polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red") + draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red", width=1) + polygon = Polygon( + points=[(10, 10), (100, 10), (100, 100), (10, 100)], + color="red", + opacity=1, + outline_width=1, + ) assert polygon.compute(mock_image) == expected_image # Test from mask @@ -48,12 +54,23 @@ def test_polygon(mock_image: PIL.Image): mask[10:100, 10:100] = 255 expected_image = mock_image.copy() draw = ImageDraw.Draw(expected_image) - draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red") - polygon = Polygon(mask=mask, color="red") + draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red", width=1) + polygon = Polygon(mask=mask, color="red", opacity=1, outline_width=1) assert polygon.compute(mock_image) == expected_image + with pytest.raises(ValueError, match="No contours found in the mask."): + polygon = Polygon(mask=np.zeros((100, 100), dtype=np.uint8)) + polygon.compute(mock_image) + def test_label(mock_image: PIL.Image): label = Label(label="Label") # When using a single label, compute and overlay_labels should return the same image assert label.compute(mock_image) == Label.overlay_labels(mock_image, [label]) + + +def test_keypoint(mock_image: PIL.Image): + keypoint = Keypoint(keypoints=np.array([[100, 100]]), color="red", keypoint_size=3) + draw = ImageDraw.Draw(mock_image) + draw.ellipse((97, 97, 103, 103), fill="red") + assert keypoint.compute(mock_image) == mock_image diff --git a/tests/python/unit/visualizer/test_scene.py b/tests/python/unit/visualizer/test_scene.py index 7b560443..dd9351a5 100644 --- a/tests/python/unit/visualizer/test_scene.py +++ b/tests/python/unit/visualizer/test_scene.py @@ -8,7 +8,13 @@ import numpy as 
np from PIL import Image -from model_api.models.result import AnomalyResult, ClassificationResult, DetectionResult +from model_api.models.result import ( + AnomalyResult, + ClassificationResult, + DetectionResult, + ImageResultWithSoftPrediction, + InstanceSegmentationResult, +) from model_api.models.result.classification import Label from model_api.visualizer import Visualizer @@ -63,3 +69,46 @@ def test_detection_scene(mock_image: Image, tmpdir: Path): visualizer = Visualizer() visualizer.save(mock_image, detection_result, tmpdir / "detection_scene.jpg") assert Path(tmpdir / "detection_scene.jpg").exists() + + +def test_segmentation_scene(mock_image: Image, tmpdir: Path): + """Test if the segmentation scene is created.""" + visualizer = Visualizer() + + instance_segmentation_result = InstanceSegmentationResult( + bboxes=np.array([[0, 0, 128, 128], [32, 32, 96, 96]]), + labels=np.array([0, 1]), + masks=np.array( + [ + np.ones((128, 128), dtype=np.uint8), + ] + ), + scores=np.array([0.85, 0.75]), + label_names=["person", "car"], + saliency_map=[np.ones((128, 128), dtype=np.uint8) * 255], + feature_vector=np.array([1, 2, 3, 4]), + ) + + visualizer.save( + mock_image, + instance_segmentation_result, + tmpdir / "instance_segmentation_scene.jpg", + ) + assert Path(tmpdir / "instance_segmentation_scene.jpg").exists() + + # Test ImageResultWithSoftPrediction + soft_prediction_result = ImageResultWithSoftPrediction( + resultImage=np.array( + [[0, 1, 2], [1, 2, 0], [2, 0, 1]], dtype=np.uint8 + ), # 3x3 test image with 3 classes + soft_prediction=np.ones( + (3, 3, 3), dtype=np.float32 + ), # 3 classes, 3x3 prediction + saliency_map=np.ones((3, 3), dtype=np.uint8) * 255, + feature_vector=np.array([1, 2, 3, 4]), + ) + + visualizer.save( + mock_image, soft_prediction_result, tmpdir / "soft_prediction_scene.jpg" + ) + assert Path(tmpdir / "soft_prediction_scene.jpg").exists()