diff --git a/model_api/python/model_api/visualizer/__init__.py b/model_api/python/model_api/visualizer/__init__.py index df33ea49..2c8a0062 100644 --- a/model_api/python/model_api/visualizer/__init__.py +++ b/model_api/python/model_api/visualizer/__init__.py @@ -1,11 +1,11 @@ """Visualizer.""" -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .layout import Flatten, HStack, Layout -from .primitive import Overlay +from .primitive import BoundingBox, Overlay from .scene import Scene from .visualizer import Visualizer -__all__ = ["Overlay", "Scene", "Visualizer", "Layout", "Flatten", "HStack"] +__all__ = ["BoundingBox", "Overlay", "Scene", "Visualizer", "Layout", "Flatten", "HStack"] diff --git a/model_api/python/model_api/visualizer/primitive.py b/model_api/python/model_api/visualizer/primitive.py index d24fc217..8a67c719 100644 --- a/model_api/python/model_api/visualizer/primitive.py +++ b/model_api/python/model_api/visualizer/primitive.py @@ -9,16 +9,70 @@ import numpy as np import PIL +from PIL import Image, ImageDraw class Primitive(ABC): """Primitive class.""" @abstractmethod - def compute(self, image: PIL.Image) -> PIL.Image: + def compute(self, image: Image) -> Image: pass +class BoundingBox(Primitive): + """Bounding box primitive. + + Args: + x1 (int): x-coordinate of the top-left corner of the bounding box. + y1 (int): y-coordinate of the top-left corner of the bounding box. + x2 (int): x-coordinate of the bottom-right corner of the bounding box. + y2 (int): y-coordinate of the bottom-right corner of the bounding box. + label (str | None): Label of the bounding box. + color (str | tuple[int, int, int]): Color of the bounding box. + + Example: + >>> bounding_box = BoundingBox(x1=10, y1=10, x2=100, y2=100, label="Label Name") + >>> bounding_box.compute(image) + """ + + def __init__( + self, + x1: int, + y1: int, + x2: int, + y2: int, + label: str | None = None, + color: str | tuple[int, int, int] = "blue", + ) -> None: + self.x1 = x1 + self.y1 = y1 + self.x2 = x2 + self.y2 = y2 + self.label = label + self.color = color + self.y_buffer = 5 # Text at the bottom of the text box is clipped. This prevents that. + + def compute(self, image: Image) -> Image: + draw = ImageDraw.Draw(image) + # draw rectangle + draw.rectangle((self.x1, self.y1, self.x2, self.y2), outline=self.color, width=2) + # add label + if self.label: + # draw the background of the label + textbox = draw.textbbox((0, 0), self.label) + label_image = Image.new( + "RGB", + (textbox[2] - textbox[0], textbox[3] + self.y_buffer - textbox[1]), + self.color, + ) + draw = ImageDraw.Draw(label_image) + # write the label on the background + draw.text((0, 0), self.label, fill="white") + image.paste(label_image, (self.x1, self.y1)) + return image + + class Overlay(Primitive): """Overlay primitive. diff --git a/tests/python/unit/visualizer/test_primitive.py b/tests/python/unit/visualizer/test_primitive.py index 9f6a210b..d9a84624 100644 --- a/tests/python/unit/visualizer/test_primitive.py +++ b/tests/python/unit/visualizer/test_primitive.py @@ -5,8 +5,9 @@ import numpy as np import PIL +from PIL import ImageDraw -from model_api.visualizer import Overlay +from model_api.visualizer import BoundingBox, Overlay def test_overlay(mock_image: PIL.Image): @@ -22,3 +23,12 @@ def test_overlay(mock_image: PIL.Image): data *= 255 overlay = Overlay(data) assert overlay.compute(empty_image) == expected_image + + +def test_bounding_box(mock_image: PIL.Image): + """Test if the bounding box is created correctly.""" + expected_image = mock_image.copy() + draw = ImageDraw.Draw(expected_image) + draw.rectangle((10, 10, 100, 100), outline="blue", width=2) + bounding_box = BoundingBox(x1=10, y1=10, x2=100, y2=100) + assert bounding_box.compute(mock_image) == expected_image