Skip to content

Commit 56feddf

Browse files
Add detection scene
Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent f50cf8f commit 56feddf

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/python/model_api/models/result/detection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def label_names(self, value):
111111

112112
@property
113113
def saliency_map(self):
114+
"""Saliency map for XAI.
115+
116+
Returns:
117+
np.ndarray: Saliency map in dim of (B, N_CLASSES, H, W).
118+
"""
114119
return self._saliency_map
115120

116121
@saliency_map.setter

src/python/model_api/visualizer/scene/detection.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
# Copyright (C) 2024 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6+
from itertools import starmap
67
from typing import Union
78

9+
import cv2
810
from PIL import Image
911

1012
from model_api.models.result import DetectionResult
11-
from model_api.visualizer.layout import Layout
13+
from model_api.visualizer.layout import Flatten, HStack, Layout
14+
from model_api.visualizer.primitive import BoundingBox, Label, Overlay
1215

1316
from .scene import Scene
1417

@@ -17,5 +20,30 @@ class DetectionScene(Scene):
1720
"""Detection Scene."""
1821

1922
def __init__(self, image: Image, result: DetectionResult, layout: Union[Layout, None] = None) -> None:
20-
self.image = image
21-
self.result = result
23+
super().__init__(
24+
base=image,
25+
label=self._get_labels(result),
26+
bounding_box=self._get_bounding_boxes(result),
27+
overlay=self._get_overlays(result),
28+
layout=layout,
29+
)
30+
31+
def _get_labels(self, result: DetectionResult) -> list[Label]:
32+
labels = []
33+
for label, score, label_name in zip(result.labels, result.scores, result.label_names):
34+
labels.append(Label(label=f"{label} {label_name}", score=score))
35+
return labels
36+
37+
def _get_overlays(self, result: DetectionResult) -> list[Overlay]:
38+
overlays = []
39+
for saliency_map in result.saliency_map[0][1:]: # Assumes only one batch. Skip background class.
40+
saliency_map = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)
41+
overlays.append(Overlay(saliency_map))
42+
return overlays
43+
44+
def _get_bounding_boxes(self, result: DetectionResult) -> list[BoundingBox]:
45+
return list(starmap(BoundingBox, result.bboxes))
46+
47+
@property
48+
def default_layout(self) -> Layout:
49+
return HStack(Flatten(BoundingBox, Label), Overlay)

0 commit comments

Comments
 (0)