Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/python/model_api/visualizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# SPDX-License-Identifier: Apache-2.0

from .layout import Flatten, HStack, Layout
from .primitive import BoundingBox, Overlay
from .primitive import BoundingBox, Overlay, Polygon
from .scene import Scene
from .visualizer import Visualizer

__all__ = ["BoundingBox", "Overlay", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
__all__ = ["BoundingBox", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
11 changes: 11 additions & 0 deletions src/python/model_api/visualizer/primitive/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Primitive classes."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .bounding_box import BoundingBox
from .overlay import Overlay
from .polygon import Polygon
from .primitive import Primitive

__all__ = ["Primitive", "BoundingBox", "Overlay", "Polygon"]
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
"""Base class for primitives."""
"""Bounding box primitive."""

# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np
import PIL
from PIL import Image, ImageDraw


class Primitive(ABC):
"""Primitive class."""

@abstractmethod
def compute(self, image: Image) -> Image:
pass
from .primitive import Primitive


class BoundingBox(Primitive):
Expand Down Expand Up @@ -71,27 +61,3 @@ def compute(self, image: Image) -> Image:
draw.text((0, 0), self.label, fill="white")
image.paste(label_image, (self.x1, self.y1))
return image


class Overlay(Primitive):
"""Overlay primitive.

Useful for XAI and Anomaly Maps.

Args:
image (PIL.Image | np.ndarray): Image to be overlaid.
opacity (float): Opacity of the overlay.
"""

def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
self.image = self._to_pil(image)
self.opacity = opacity

def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
if isinstance(image, np.ndarray):
return PIL.Image.fromarray(image)
return image

def compute(self, image: PIL.Image) -> PIL.Image:
_image = self.image.resize(image.size)
return PIL.Image.blend(image, _image, self.opacity)
35 changes: 35 additions & 0 deletions src/python/model_api/visualizer/primitive/overlay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Overlay primitive."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import numpy as np
import PIL

from .primitive import Primitive


class Overlay(Primitive):
"""Overlay primitive.

Useful for XAI and Anomaly Maps.

Args:
image (PIL.Image | np.ndarray): Image to be overlaid.
opacity (float): Opacity of the overlay.
"""

def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
self.image = self._to_pil(image)
self.opacity = opacity

def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
if isinstance(image, np.ndarray):
return PIL.Image.fromarray(image)
return image

def compute(self, image: PIL.Image) -> PIL.Image:
image_ = self.image.resize(image.size)
return PIL.Image.blend(image, image_, self.opacity)
93 changes: 93 additions & 0 deletions src/python/model_api/visualizer/primitive/polygon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Polygon primitive."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import TYPE_CHECKING

import cv2
from PIL import Image, ImageDraw

from .primitive import Primitive

if TYPE_CHECKING:
import numpy as np


class Polygon(Primitive):
"""Polygon primitive.

Args:
points: List of points.
mask: Mask to draw the polygon.
color: Color of the polygon.

Examples:
>>> polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red")
>>> polygon = Polygon(mask=mask, color="red")
>>> polygon.compute(image).save("polygon.jpg")

>>> polygon = Polygon(mask=mask, color="red")
>>> polygon.compute(image).save("polygon.jpg")
"""

def __init__(
self,
points: list[tuple[int, int]] | None = None,
mask: np.ndarray | None = None,
color: str | tuple[int, int, int] = "blue",
) -> None:
self.points = self._get_points(points, mask)
self.color = color

def _get_points(self, points: list[tuple[int, int]] | None, mask: np.ndarray | None) -> list[tuple[int, int]]:
"""Get points from either points or mask.
Note:
Either points or mask should be provided.

Args:
points: List of points.
mask: Mask to draw the polygon.

Returns:
List of points.
"""
if points is not None and mask is not None:
msg = "Either points or mask should be provided, not both."
raise ValueError(msg)
if points is not None:
points_ = points
elif mask is not None:
points_ = self._get_points_from_mask(mask)
else:
msg = "Either points or mask should be provided."
raise ValueError(msg)
return points_

def _get_points_from_mask(self, mask: np.ndarray) -> list[tuple[int, int]]:
"""Get points from mask.

Args:
mask: Mask to draw the polygon.

Returns:
List of points.
"""
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
points_ = contours[0].squeeze().tolist()
return [tuple(point) for point in points_]

def compute(self, image: Image) -> Image:
"""Compute the polygon.

Args:
image: Image to draw the polygon on.

Returns:
Image with the polygon drawn on it.
"""
draw = ImageDraw.Draw(image)
draw.polygon(self.points, fill=self.color)
return image
20 changes: 20 additions & 0 deletions src/python/model_api/visualizer/primitive/primitive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Base class for primitives."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import PIL


class Primitive(ABC):
"""Base class for primitives."""

@abstractmethod
def compute(self, image: PIL.Image) -> PIL.Image:
"""Compute the primitive."""
21 changes: 20 additions & 1 deletion tests/python/unit/visualizer/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import PIL
from PIL import ImageDraw

from model_api.visualizer import BoundingBox, Overlay
from model_api.visualizer import BoundingBox, Overlay, Polygon


def test_overlay(mock_image: PIL.Image):
Expand All @@ -32,3 +32,22 @@ def test_bounding_box(mock_image: PIL.Image):
draw.rectangle((10, 10, 100, 100), outline="blue", width=2)
bounding_box = BoundingBox(x1=10, y1=10, x2=100, y2=100)
assert bounding_box.compute(mock_image) == expected_image


def test_polygon(mock_image: PIL.Image):
"""Test if the polygon is created correctly."""
# Test from points
expected_image = mock_image.copy()
draw = ImageDraw.Draw(expected_image)
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red")
assert polygon.compute(mock_image) == expected_image

# Test from mask
mask = np.zeros((100, 100), dtype=np.uint8)
mask[10:100, 10:100] = 255
expected_image = mock_image.copy()
draw = ImageDraw.Draw(expected_image)
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
polygon = Polygon(mask=mask, color="red")
assert polygon.compute(mock_image) == expected_image
Loading