Skip to content

Commit ae3241a

Browse files
Add Polygon Primitive (#254)
* Refactor primitives Signed-off-by: Ashwin Vaidya <[email protected]> * Add polygon primitive Signed-off-by: Ashwin Vaidya <[email protected]> * Add docstrings Signed-off-by: Ashwin Vaidya <[email protected]> --------- Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 6e7d108 commit ae3241a

File tree

7 files changed

+184
-40
lines changed

7 files changed

+184
-40
lines changed

src/python/model_api/visualizer/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
from .layout import Flatten, HStack, Layout
7-
from .primitive import BoundingBox, Overlay
7+
from .primitive import BoundingBox, Overlay, Polygon
88
from .scene import Scene
99
from .visualizer import Visualizer
1010

11-
__all__ = ["BoundingBox", "Overlay", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
11+
__all__ = ["BoundingBox", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Primitive classes."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from .bounding_box import BoundingBox
7+
from .overlay import Overlay
8+
from .polygon import Polygon
9+
from .primitive import Primitive
10+
11+
__all__ = ["Primitive", "BoundingBox", "Overlay", "Polygon"]

src/python/model_api/visualizer/primitive.py renamed to src/python/model_api/visualizer/primitive/bounding_box.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,13 @@
1-
"""Base class for primitives."""
1+
"""Bounding box primitive."""
22

3-
# Copyright (C) 2024 Intel Corporation
3+
# Copyright (C) 2025 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

66
from __future__ import annotations
77

8-
from abc import ABC, abstractmethod
9-
10-
import numpy as np
11-
import PIL
128
from PIL import Image, ImageDraw
139

14-
15-
class Primitive(ABC):
16-
"""Primitive class."""
17-
18-
@abstractmethod
19-
def compute(self, image: Image) -> Image:
20-
pass
10+
from .primitive import Primitive
2111

2212

2313
class BoundingBox(Primitive):
@@ -71,27 +61,3 @@ def compute(self, image: Image) -> Image:
7161
draw.text((0, 0), self.label, fill="white")
7262
image.paste(label_image, (self.x1, self.y1))
7363
return image
74-
75-
76-
class Overlay(Primitive):
77-
"""Overlay primitive.
78-
79-
Useful for XAI and Anomaly Maps.
80-
81-
Args:
82-
image (PIL.Image | np.ndarray): Image to be overlaid.
83-
opacity (float): Opacity of the overlay.
84-
"""
85-
86-
def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
87-
self.image = self._to_pil(image)
88-
self.opacity = opacity
89-
90-
def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
91-
if isinstance(image, np.ndarray):
92-
return PIL.Image.fromarray(image)
93-
return image
94-
95-
def compute(self, image: PIL.Image) -> PIL.Image:
96-
_image = self.image.resize(image.size)
97-
return PIL.Image.blend(image, _image, self.opacity)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Overlay primitive."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from __future__ import annotations
7+
8+
import numpy as np
9+
import PIL
10+
11+
from .primitive import Primitive
12+
13+
14+
class Overlay(Primitive):
15+
"""Overlay primitive.
16+
17+
Useful for XAI and Anomaly Maps.
18+
19+
Args:
20+
image (PIL.Image | np.ndarray): Image to be overlaid.
21+
opacity (float): Opacity of the overlay.
22+
"""
23+
24+
def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
25+
self.image = self._to_pil(image)
26+
self.opacity = opacity
27+
28+
def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
29+
if isinstance(image, np.ndarray):
30+
return PIL.Image.fromarray(image)
31+
return image
32+
33+
def compute(self, image: PIL.Image) -> PIL.Image:
34+
image_ = self.image.resize(image.size)
35+
return PIL.Image.blend(image, image_, self.opacity)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Polygon primitive."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from __future__ import annotations
7+
8+
from typing import TYPE_CHECKING
9+
10+
import cv2
11+
from PIL import Image, ImageDraw
12+
13+
from .primitive import Primitive
14+
15+
if TYPE_CHECKING:
16+
import numpy as np
17+
18+
19+
class Polygon(Primitive):
20+
"""Polygon primitive.
21+
22+
Args:
23+
points: List of points.
24+
mask: Mask to draw the polygon.
25+
color: Color of the polygon.
26+
27+
Examples:
28+
>>> polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red")
29+
>>> polygon = Polygon(mask=mask, color="red")
30+
>>> polygon.compute(image).save("polygon.jpg")
31+
32+
>>> polygon = Polygon(mask=mask, color="red")
33+
>>> polygon.compute(image).save("polygon.jpg")
34+
"""
35+
36+
def __init__(
37+
self,
38+
points: list[tuple[int, int]] | None = None,
39+
mask: np.ndarray | None = None,
40+
color: str | tuple[int, int, int] = "blue",
41+
) -> None:
42+
self.points = self._get_points(points, mask)
43+
self.color = color
44+
45+
def _get_points(self, points: list[tuple[int, int]] | None, mask: np.ndarray | None) -> list[tuple[int, int]]:
46+
"""Get points from either points or mask.
47+
Note:
48+
Either points or mask should be provided.
49+
50+
Args:
51+
points: List of points.
52+
mask: Mask to draw the polygon.
53+
54+
Returns:
55+
List of points.
56+
"""
57+
if points is not None and mask is not None:
58+
msg = "Either points or mask should be provided, not both."
59+
raise ValueError(msg)
60+
if points is not None:
61+
points_ = points
62+
elif mask is not None:
63+
points_ = self._get_points_from_mask(mask)
64+
else:
65+
msg = "Either points or mask should be provided."
66+
raise ValueError(msg)
67+
return points_
68+
69+
def _get_points_from_mask(self, mask: np.ndarray) -> list[tuple[int, int]]:
70+
"""Get points from mask.
71+
72+
Args:
73+
mask: Mask to draw the polygon.
74+
75+
Returns:
76+
List of points.
77+
"""
78+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
79+
points_ = contours[0].squeeze().tolist()
80+
return [tuple(point) for point in points_]
81+
82+
def compute(self, image: Image) -> Image:
83+
"""Compute the polygon.
84+
85+
Args:
86+
image: Image to draw the polygon on.
87+
88+
Returns:
89+
Image with the polygon drawn on it.
90+
"""
91+
draw = ImageDraw.Draw(image)
92+
draw.polygon(self.points, fill=self.color)
93+
return image
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Base class for primitives."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from __future__ import annotations
7+
8+
from abc import ABC, abstractmethod
9+
from typing import TYPE_CHECKING
10+
11+
if TYPE_CHECKING:
12+
import PIL
13+
14+
15+
class Primitive(ABC):
16+
"""Base class for primitives."""
17+
18+
@abstractmethod
19+
def compute(self, image: PIL.Image) -> PIL.Image:
20+
"""Compute the primitive."""

tests/python/unit/visualizer/test_primitive.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import PIL
88
from PIL import ImageDraw
99

10-
from model_api.visualizer import BoundingBox, Overlay
10+
from model_api.visualizer import BoundingBox, Overlay, Polygon
1111

1212

1313
def test_overlay(mock_image: PIL.Image):
@@ -32,3 +32,22 @@ def test_bounding_box(mock_image: PIL.Image):
3232
draw.rectangle((10, 10, 100, 100), outline="blue", width=2)
3333
bounding_box = BoundingBox(x1=10, y1=10, x2=100, y2=100)
3434
assert bounding_box.compute(mock_image) == expected_image
35+
36+
37+
def test_polygon(mock_image: PIL.Image):
38+
"""Test if the polygon is created correctly."""
39+
# Test from points
40+
expected_image = mock_image.copy()
41+
draw = ImageDraw.Draw(expected_image)
42+
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
43+
polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red")
44+
assert polygon.compute(mock_image) == expected_image
45+
46+
# Test from mask
47+
mask = np.zeros((100, 100), dtype=np.uint8)
48+
mask[10:100, 10:100] = 255
49+
expected_image = mock_image.copy()
50+
draw = ImageDraw.Draw(expected_image)
51+
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
52+
polygon = Polygon(mask=mask, color="red")
53+
assert polygon.compute(mock_image) == expected_image

0 commit comments

Comments
 (0)