Skip to content

Commit 46b74db

Browse files
Add label primitive (#256)
Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent ae3241a commit 46b74db

File tree

7 files changed

+191
-22
lines changed

7 files changed

+191
-22
lines changed

src/python/model_api/visualizer/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
from .layout import Flatten, HStack, Layout
7-
from .primitive import BoundingBox, Overlay, Polygon
7+
from .primitive import BoundingBox, Label, Overlay, Polygon
88
from .scene import Scene
99
from .visualizer import Visualizer
1010

11-
__all__ = ["BoundingBox", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
11+
__all__ = ["BoundingBox", "Label", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]

src/python/model_api/visualizer/layout/flatten.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
from __future__ import annotations
77

8-
from typing import TYPE_CHECKING, Type, Union
8+
from typing import TYPE_CHECKING, Type, Union, cast
9+
10+
from model_api.visualizer.primitive import Label
911

1012
from .layout import Layout
1113

@@ -29,13 +31,19 @@ def __init__(self, *args: Union[Type[Primitive], Layout]) -> None:
2931
def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, scene: Scene) -> PIL.Image | None:
3032
if scene.has_primitives(primitive):
3133
primitives = scene.get_primitives(primitive)
32-
for _primitive in primitives:
33-
image = _primitive.compute(image)
34+
if primitive == Label: # Labels need to be rendered next to each other
35+
# cast is needed as mypy does not know that the primitives are of type Label.
36+
primitives_ = cast("list[Label]", primitives)
37+
image = Label.overlay_labels(image, primitives_)
38+
else:
39+
# Other primitives are rendered on top of each other
40+
for _primitive in primitives:
41+
image = _primitive.compute(image)
3442
return image
3543
return None
3644

3745
def __call__(self, scene: Scene) -> PIL.Image:
38-
_image: PIL.Image = scene.base.copy()
46+
image_: PIL.Image = scene.base.copy()
3947
for child in self.children:
40-
_image = child(scene) if isinstance(child, Layout) else self._compute_on_primitive(child, _image, scene)
41-
return _image
48+
image_ = child(scene) if isinstance(child, Layout) else self._compute_on_primitive(child, image_, scene)
49+
return image_

src/python/model_api/visualizer/layout/hstack.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, sc
3030
if scene.has_primitives(primitive):
3131
images = []
3232
for _primitive in scene.get_primitives(primitive):
33-
_image = _primitive.compute(image.copy())
34-
images.append(_image)
33+
image_ = _primitive.compute(image.copy())
34+
images.append(image_)
3535
return self._stitch(*images)
3636
return None
3737

@@ -70,9 +70,9 @@ def __call__(self, scene: Scene) -> PIL.Image:
7070
images: list[PIL.Image] = []
7171
for child in self.children:
7272
if isinstance(child, Layout):
73-
_image = child(scene)
73+
image_ = child(scene)
7474
else:
75-
_image = self._compute_on_primitive(child, scene.base.copy(), scene)
76-
if _image is not None:
77-
images.append(_image)
75+
image_ = self._compute_on_primitive(child, scene.base.copy(), scene)
76+
if image_ is not None:
77+
images.append(image_)
7878
return self._stitch(*images)

src/python/model_api/visualizer/primitive/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
from .bounding_box import BoundingBox
7+
from .label import Label
78
from .overlay import Overlay
89
from .polygon import Polygon
910
from .primitive import Primitive
1011

11-
__all__ = ["Primitive", "BoundingBox", "Overlay", "Polygon"]
12+
__all__ = ["Primitive", "BoundingBox", "Label", "Overlay", "Polygon"]
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Label primitive."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from io import BytesIO
7+
from typing import Union
8+
9+
from PIL import Image, ImageDraw, ImageFont
10+
11+
from .primitive import Primitive
12+
13+
14+
class Label(Primitive):
15+
"""Label primitive.
16+
17+
Labels require a different processing than other primitives as the class also handles the instance when the layout
18+
requests all the labels to be drawn on a single image.
19+
20+
Args:
21+
label (str): Text of the label.
22+
fg_color (str | tuple[int, int, int]): Foreground color of the label.
23+
bg_color (str | tuple[int, int, int]): Background color of the label.
24+
font_path (str | None | BytesIO): Path to the font file.
25+
size (int): Size of the font.
26+
27+
Examples:
28+
>>> label = Label(label="Label 1")
29+
>>> label.compute(image).save("label.jpg")
30+
31+
>>> label = Label(text="Label 1", fg_color="red", bg_color="blue", font_path="arial.ttf", size=20)
32+
>>> label.compute(image).save("label.jpg")
33+
34+
or multiple labels on a single image:
35+
>>> label1 = Label(text="Label 1")
36+
>>> label2 = Label(text="Label 2")
37+
>>> label3 = Label(text="Label 3")
38+
>>> Label.overlay_labels(image, [label1, label2, label3]).save("labels.jpg")
39+
"""
40+
41+
def __init__(
42+
self,
43+
label: str,
44+
fg_color: Union[str, tuple[int, int, int]] = "black",
45+
bg_color: Union[str, tuple[int, int, int]] = "yellow",
46+
font_path: Union[str, BytesIO, None] = None,
47+
size: int = 16,
48+
) -> None:
49+
self.label = label
50+
self.fg_color = fg_color
51+
self.bg_color = bg_color
52+
self.font = ImageFont.load_default(size=size) if font_path is None else ImageFont.truetype(font_path, size)
53+
54+
def compute(self, image: Image, buffer_y: int = 5) -> Image:
55+
"""Generate label on top of the image.
56+
57+
Args:
58+
image (PIL.Image): Image to paste the label on.
59+
buffer_y (int): Buffer to add to the y-axis of the label.
60+
"""
61+
label_image = self.generate_label_image(buffer_y)
62+
image.paste(label_image, (0, 0))
63+
return image
64+
65+
def generate_label_image(self, buffer_y: int = 5) -> Image:
66+
"""Generate label image.
67+
68+
Args:
69+
buffer_y (int): Buffer to add to the y-axis of the label. This is needed as the text is clipped from the
70+
bottom otherwise.
71+
72+
Returns:
73+
PIL.Image: Image that consists only of the label.
74+
"""
75+
dummy_image = Image.new("RGB", (1, 1))
76+
draw = ImageDraw.Draw(dummy_image)
77+
textbox = draw.textbbox((0, 0), self.label, font=self.font)
78+
label_image = Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), self.bg_color)
79+
draw = ImageDraw.Draw(label_image)
80+
draw.text((0, 0), self.label, font=self.font, fill=self.fg_color)
81+
return label_image
82+
83+
@classmethod
84+
def overlay_labels(cls, image: Image, labels: list["Label"], buffer_y: int = 5, buffer_x: int = 5) -> Image:
85+
"""Overlay multiple label images on top of the image.
86+
Paste the labels in a row but wrap the labels if they exceed the image width.
87+
88+
Args:
89+
image (PIL.Image): Image to paste the labels on.
90+
labels (list[Label]): Labels to be pasted on the image.
91+
buffer_y (int): Buffer to add to the y-axis of the labels.
92+
buffer_x (int): Space between the labels.
93+
94+
Returns:
95+
PIL.Image: Image with the labels pasted on it.
96+
"""
97+
offset_x = 0
98+
offset_y = 0
99+
for label in labels:
100+
label_image = label.generate_label_image(buffer_y)
101+
image.paste(label_image, (offset_x, offset_y))
102+
offset_x += label_image.width + buffer_x
103+
if offset_x + label_image.width > image.width:
104+
offset_x = 0
105+
offset_y += label_image.height
106+
return image

src/python/model_api/visualizer/scene/scene.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
from __future__ import annotations
77

8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, cast
99

1010
import numpy as np
1111
from PIL import Image
1212

13-
from model_api.visualizer.primitive import Overlay, Primitive
13+
from model_api.visualizer.primitive import BoundingBox, Label, Overlay, Polygon, Primitive
1414

1515
if TYPE_CHECKING:
1616
from pathlib import Path
@@ -27,16 +27,23 @@ class Scene:
2727
def __init__(
2828
self,
2929
base: Image,
30+
bounding_box: BoundingBox | list[BoundingBox] | None = None,
31+
label: Label | list[Label] | None = None,
3032
overlay: Overlay | list[Overlay] | np.ndarray | None = None,
33+
polygon: Polygon | list[Polygon] | None = None,
3134
layout: Layout | None = None,
3235
) -> None:
3336
self.base = base
3437
self.overlay = self._to_overlay(overlay)
38+
self.bounding_box = self._to_bounding_box(bounding_box)
39+
self.label = self._to_label(label)
40+
self.polygon = self._to_polygon(polygon)
3541
self.layout = layout
3642

3743
def show(self) -> Image: ...
3844

39-
def save(self, path: Path) -> None: ...
45+
def save(self, path: Path) -> None:
46+
self.render().save(path)
4047

4148
def render(self) -> Image:
4249
if self.layout is None:
@@ -46,16 +53,42 @@ def render(self) -> Image:
4653
def has_primitives(self, primitive: type[Primitive]) -> bool:
4754
if primitive == Overlay:
4855
return bool(self.overlay)
56+
if primitive == BoundingBox:
57+
return bool(self.bounding_box)
58+
if primitive == Label:
59+
return bool(self.label)
60+
if primitive == Polygon:
61+
return bool(self.polygon)
4962
return False
5063

5164
def get_primitives(self, primitive: type[Primitive]) -> list[Primitive]:
65+
"""Get primitives of the given type.
66+
67+
Args:
68+
primitive (type[Primitive]): The type of primitive to get.
69+
70+
Example:
71+
>>> scene = Scene(base=Image.new("RGB", (100, 100)), overlay=[Overlay(Image.new("RGB", (100, 100)))])
72+
>>> scene.get_primitives(Overlay)
73+
[Overlay(image=Image.new("RGB", (100, 100)))]
74+
75+
Returns:
76+
list[Primitive]: The primitives of the given type or an empty list if no primitives are found.
77+
"""
5278
primitives: list[Primitive] | None = None
79+
# cast is needed as mypy does not know that the primitives are a subclass of Primitive.
5380
if primitive == Overlay:
54-
primitives = self.overlay # type: ignore[assignment] # TODO(ashwinvaidya17): Address this in the next PR
55-
if primitives is None:
81+
primitives = cast("list[Primitive]", self.overlay)
82+
elif primitive == BoundingBox:
83+
primitives = cast("list[Primitive]", self.bounding_box)
84+
elif primitive == Label:
85+
primitives = cast("list[Primitive]", self.label)
86+
elif primitive == Polygon:
87+
primitives = cast("list[Primitive]", self.polygon)
88+
else:
5689
msg = f"Primitive {primitive} not found"
5790
raise ValueError(msg)
58-
return primitives
91+
return primitives or []
5992

6093
@property
6194
def default_layout(self) -> Layout:
@@ -70,3 +103,18 @@ def _to_overlay(self, overlay: Overlay | list[Overlay] | np.ndarray | None) -> l
70103
if isinstance(overlay, Overlay):
71104
return [overlay]
72105
return overlay
106+
107+
def _to_bounding_box(self, bounding_box: BoundingBox | list[BoundingBox] | None) -> list[BoundingBox] | None:
108+
if isinstance(bounding_box, BoundingBox):
109+
return [bounding_box]
110+
return bounding_box
111+
112+
def _to_label(self, label: Label | list[Label] | None) -> list[Label] | None:
113+
if isinstance(label, Label):
114+
return [label]
115+
return label
116+
117+
def _to_polygon(self, polygon: Polygon | list[Polygon] | None) -> list[Polygon] | None:
118+
if isinstance(polygon, Polygon):
119+
return [polygon]
120+
return polygon

tests/python/unit/visualizer/test_primitive.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import PIL
88
from PIL import ImageDraw
99

10-
from model_api.visualizer import BoundingBox, Overlay, Polygon
10+
from model_api.visualizer import BoundingBox, Label, Overlay, Polygon
1111

1212

1313
def test_overlay(mock_image: PIL.Image):
@@ -51,3 +51,9 @@ def test_polygon(mock_image: PIL.Image):
5151
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
5252
polygon = Polygon(mask=mask, color="red")
5353
assert polygon.compute(mock_image) == expected_image
54+
55+
56+
def test_label(mock_image: PIL.Image):
57+
label = Label(label="Label")
58+
# When using a single label, compute and overlay_labels should return the same image
59+
assert label.compute(mock_image) == Label.overlay_labels(mock_image, [label])

0 commit comments

Comments
 (0)