Skip to content

Commit 2d0849e

Browse files
restore visualization changes
Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 0ce25fe commit 2d0849e

File tree

4 files changed

+47
-7
lines changed

4 files changed

+47
-7
lines changed

model_api/python/model_api/models/result_types/anomaly.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55

66
from __future__ import annotations
77

8+
import cv2
89
import numpy as np
910

11+
from model_api.visualizer.primitives import BoundingBoxes, Label, Overlay, Polygon
1012

11-
class AnomalyResult:
13+
from .base import Result
14+
15+
16+
class AnomalyResult(Result):
1217
"""Results for anomaly models."""
1318

1419
def __init__(
@@ -19,6 +24,7 @@ def __init__(
1924
pred_mask: np.ndarray | None = None,
2025
pred_score: float | None = None,
2126
) -> None:
27+
super().__init__()
2228
self.anomaly_map = anomaly_map
2329
self.pred_boxes = pred_boxes
2430
self.pred_label = pred_label
@@ -40,3 +46,14 @@ def __str__(self) -> str:
4046
f"pred_label:{self.pred_label};"
4147
f"pred_mask min:{pred_mask_min} max:{pred_mask_max};"
4248
)
49+
50+
def _register_primitives(self) -> None:
51+
"""Converts the result to primitives."""
52+
anomaly_map = cv2.applyColorMap(self.anomaly_map, cv2.COLORMAP_JET)
53+
self._add_primitive(Overlay(anomaly_map))
54+
for box in self.pred_boxes:
55+
self._add_primitive(BoundingBoxes(*box))
56+
if self.pred_label is not None:
57+
self._add_primitive(Label(self.pred_label, bg_color="red" if self.pred_label == "Anomaly" else "green"))
58+
self._add_primitive(Label(f"Score: {self.pred_score}"))
59+
self._add_primitive(Polygon(mask=self.pred_mask))
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Base result type"""
2+
3+
# Copyright (C) 2024 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from abc import ABC
7+
8+
from model_api.visualizer.visualize_mixin import VisualizeMixin
9+
10+
11+
class Result(VisualizeMixin, ABC):
12+
"""Base result type."""

model_api/python/model_api/models/result_types/classification.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77

88
from typing import TYPE_CHECKING
99

10+
from model_api.visualizer.primitives import Label
11+
12+
from .base import Result
1013
from .utils import array_shape_to_str
1114

1215
if TYPE_CHECKING:
1316
import numpy as np
1417

1518

16-
class ClassificationResult:
19+
class ClassificationResult(Result):
1720
"""Results for classification models."""
1821

1922
def __init__(
@@ -35,3 +38,8 @@ def __str__(self) -> str:
3538
f"{labels}, {array_shape_to_str(self.saliency_map)}, {array_shape_to_str(self.feature_vector)}, "
3639
f"{array_shape_to_str(self.raw_scores)}"
3740
)
41+
42+
def _register_primitives(self) -> None:
43+
# TODO add saliency map
44+
for idx, label, confidence in self.top_labels:
45+
self._add_primitive(Label(f"Rank: {idx}, {label}: {confidence:.3f}"))

model_api/python/model_api/visualizer/visualizer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
from __future__ import annotations
77

88
from enum import Enum
9+
from typing import TYPE_CHECKING
910

1011
from PIL import Image
1112

1213
from model_api.visualizer.primitives import Label
13-
from model_api.visualizer.visualize_mixin import VisualizeMixin
14+
15+
if TYPE_CHECKING:
16+
from model_api.visualizer.visualize_mixin import VisualizeMixin
1417

1518

1619
class VisualizationType(Enum):
@@ -47,12 +50,12 @@ def save(
4750
result.save(path)
4851

4952
def _generate(self, image: Image, result: VisualizeMixin, visualization_type: VisualizationType) -> Image:
50-
result: Image
53+
_result: Image
5154
if visualization_type == VisualizationType.SIMPLE:
52-
result = self._generate_simple(image, result)
55+
_result = self._generate_simple(image, result)
5356
else:
54-
result = self._generate_full(image, result)
55-
return result
57+
_result = self._generate_full(image, result)
58+
return _result
5659

5760
def _generate_simple(self, image: Image, result: VisualizeMixin) -> Image:
5861
"""Return a single image with stacked visualizations."""

0 commit comments

Comments
 (0)