Skip to content

Commit f50cf8f

Browse files
Add classification scene
Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent ddda5e6 commit f50cf8f

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

src/python/model_api/visualizer/scene/classification.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66
from typing import Union
77

8+
import cv2
89
from PIL import Image
910

1011
from model_api.models.result import ClassificationResult
1112
from model_api.visualizer.layout import Flatten, Layout
12-
from model_api.visualizer.primitive import Overlay
13+
from model_api.visualizer.primitive import Label, Overlay
1314

1415
from .scene import Scene
1516

@@ -18,9 +19,26 @@ class ClassificationScene(Scene):
1819
"""Classification Scene."""
1920

2021
def __init__(self, image: Image, result: ClassificationResult, layout: Union[Layout, None] = None) -> None:
21-
self.image = image
22-
self.result = result
22+
super().__init__(
23+
base=image,
24+
label=self._get_labels(result),
25+
overlay=self._get_overlays(result),
26+
layout=layout,
27+
)
28+
29+
def _get_labels(self, result: ClassificationResult) -> list[Label]:
30+
labels = []
31+
if result.top_labels is not None and len(result.top_labels) > 0:
32+
labels.extend([Label(label=str(label)) for label in result.top_labels])
33+
return labels
34+
35+
def _get_overlays(self, result: ClassificationResult) -> list[Overlay]:
36+
overlays = []
37+
if result.saliency_map is not None and result.saliency_map.size > 0:
38+
saliency_map = cv2.cvtColor(result.saliency_map, cv2.COLOR_BGR2RGB)
39+
overlays.append(Overlay(saliency_map))
40+
return overlays
2341

2442
@property
2543
def default_layout(self) -> Layout:
26-
return Flatten(Overlay)
44+
return Flatten(Overlay, Label)

tests/python/unit/visualizer/test_scene.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import numpy as np
99
from PIL import Image
1010

11-
from model_api.models.result import AnomalyResult
11+
from model_api.models.result import AnomalyResult, ClassificationResult
12+
from model_api.models.result.classification import Label
1213
from model_api.visualizer import Visualizer
1314

1415

@@ -32,3 +33,19 @@ def test_anomaly_scene(mock_image: Image, tmpdir: Path):
3233
visualizer = Visualizer()
3334
visualizer.save(mock_image, anomaly_result, tmpdir / "anomaly_scene.jpg")
3435
assert Path(tmpdir / "anomaly_scene.jpg").exists()
36+
37+
38+
def test_classification_scene(mock_image: Image, tmpdir: Path):
39+
"""Test if the classification scene is created."""
40+
classification_result = ClassificationResult(
41+
top_labels=[
42+
Label(name="cat", confidence=0.95),
43+
Label(name="dog", confidence=0.90),
44+
],
45+
saliency_map=np.ones(mock_image.size, dtype=np.uint8),
46+
)
47+
visualizer = Visualizer()
48+
visualizer.save(
49+
mock_image, classification_result, tmpdir / "classification_scene.jpg"
50+
)
51+
assert Path(tmpdir / "classification_scene.jpg").exists()

0 commit comments

Comments
 (0)