Skip to content

Commit 5f330ba

Browse files
Add title to overlay
Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 4a5ac67 commit 5f330ba

File tree

3 files changed

+46
-13
lines changed

3 files changed

+46
-13
lines changed

src/python/model_api/visualizer/layout/hstack.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import PIL
1111

12+
from model_api.visualizer.primitive import Overlay
13+
1214
from .layout import Layout
1315

1416
if TYPE_CHECKING:
@@ -31,6 +33,8 @@ def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, sc
3133
images = []
3234
for _primitive in scene.get_primitives(primitive):
3335
image_ = _primitive.compute(image.copy())
36+
if isinstance(_primitive, Overlay):
37+
image_ = Overlay.overlay_labels(image=image_, labels=_primitive.label)
3438
images.append(image_)
3539
return self._stitch(*images)
3640
return None

src/python/model_api/visualizer/primitive/overlay.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
from __future__ import annotations
77

8+
from typing import Union
9+
810
import numpy as np
911
import PIL
12+
from PIL import ImageFont
1013

1114
from .primitive import Primitive
1215

@@ -18,11 +21,18 @@ class Overlay(Primitive):
1821
1922
Args:
2023
image (PIL.Image | np.ndarray): Image to be overlaid.
24+
label (str | None): Optional label name to overlay.
2125
opacity (float): Opacity of the overlay.
2226
"""
2327

24-
def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
28+
def __init__(
29+
self,
30+
image: PIL.Image | np.ndarray,
31+
opacity: float = 0.4,
32+
label: Union[str, None] = None,
33+
) -> None:
2534
self.image = self._to_pil(image)
35+
self.label = label
2636
self.opacity = opacity
2737

2838
def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
@@ -33,3 +43,22 @@ def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
3343
def compute(self, image: PIL.Image) -> PIL.Image:
3444
image_ = self.image.resize(image.size)
3545
return PIL.Image.blend(image, image_, self.opacity)
46+
47+
@classmethod
48+
def overlay_labels(cls, image: PIL.Image, labels: Union[list[str], str, None] = None) -> PIL.Image:
49+
"""Draw labels at the bottom center of the image.
50+
51+
This is handy when you want to add a label to the image.
52+
"""
53+
if labels is not None:
54+
labels = [labels] if isinstance(labels, str) else labels
55+
font = ImageFont.load_default(size=18)
56+
buffer_y = 5
57+
dummy_image = PIL.Image.new("RGB", (1, 1))
58+
draw = PIL.ImageDraw.Draw(dummy_image)
59+
textbox = draw.textbbox((0, 0), ", ".join(labels), font=font)
60+
image_ = PIL.Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), "white")
61+
draw = PIL.ImageDraw.Draw(image_)
62+
draw.text((0, 0), ", ".join(labels), font=font, fill="black")
63+
image.paste(image_, (image.width // 2 - image_.width // 2, image.height - image_.height - buffer_y))
64+
return image

src/python/model_api/visualizer/scene/detection.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Copyright (C) 2024 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6-
from itertools import starmap
76
from typing import Union
87

98
import cv2
@@ -22,27 +21,28 @@ class DetectionScene(Scene):
2221
def __init__(self, image: Image, result: DetectionResult, layout: Union[Layout, None] = None) -> None:
2322
super().__init__(
2423
base=image,
25-
label=self._get_labels(result),
2624
bounding_box=self._get_bounding_boxes(result),
2725
overlay=self._get_overlays(result),
2826
layout=layout,
2927
)
3028

31-
def _get_labels(self, result: DetectionResult) -> list[Label]:
32-
labels = []
33-
for label, score, label_name in zip(result.labels, result.scores, result.label_names):
34-
labels.append(Label(label=f"{label} {label_name}", score=score))
35-
return labels
36-
3729
def _get_overlays(self, result: DetectionResult) -> list[Overlay]:
3830
overlays = []
39-
for saliency_map in result.saliency_map[0][1:]: # Assumes only one batch. Skip background class.
40-
saliency_map = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)
41-
overlays.append(Overlay(saliency_map))
31+
# Add only the overlays that are predicted
32+
label_index_mapping = dict(zip(result.labels, result.label_names))
33+
for label_index, label_name in label_index_mapping.items():
34+
# Index 0 as it assumes only one batch
35+
saliency_map = cv2.applyColorMap(result.saliency_map[0][label_index], cv2.COLORMAP_JET)
36+
overlays.append(Overlay(saliency_map, label=label_name.title()))
4237
return overlays
4338

4439
def _get_bounding_boxes(self, result: DetectionResult) -> list[BoundingBox]:
45-
return list(starmap(BoundingBox, result.bboxes))
40+
bounding_boxes = []
41+
for score, label_name, bbox in zip(result.scores, result.label_names, result.bboxes):
42+
x1, y1, x2, y2 = bbox
43+
label = f"{label_name} ({score:.2f})"
44+
bounding_boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2, label=label))
45+
return bounding_boxes
4646

4747
@property
4848
def default_layout(self) -> Layout:

0 commit comments

Comments
 (0)