Skip to content

Commit c97a1c4

Browse files
committed
Add render method to visualizer
1 parent 7622768 commit c97a1c4

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

src/python/model_api/visualizer/visualizer.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Visualizer for modelAPI."""
22

3-
# Copyright (C) 2024 Intel Corporation
3+
# Copyright (C) 2024-2025 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

66
from __future__ import annotations # TODO: remove when Python3.9 support is dropped
@@ -42,18 +42,32 @@ class Visualizer:
4242
def __init__(self, layout: Layout | None = None) -> None:
4343
self.layout = layout
4444

45-
def show(self, image: Image | np.ndarray, result: Result) -> None:
45+
def show(self, image: Image.Image | np.ndarray, result: Result) -> None:
4646
if isinstance(image, np.ndarray):
4747
image = Image.fromarray(image)
4848
scene = self._scene_from_result(image, result)
4949
return scene.show()
5050

51-
def save(self, image: Image | np.ndarray, result: Result, path: Path) -> None:
51+
def save(self, image: Image.Image | np.ndarray, result: Result, path: Path) -> None:
5252
if isinstance(image, np.ndarray):
5353
image = Image.fromarray(image)
5454
scene = self._scene_from_result(image, result)
5555
scene.save(path)
5656

57+
def render(self, image: Image.Image | np.ndarray, result: Result) -> Image.Image | np.ndarray:
58+
is_numpy = isinstance(image, np.ndarray)
59+
60+
if is_numpy:
61+
image = Image.fromarray(image)
62+
63+
scene = self._scene_from_result(image, result)
64+
result_img: Image = scene.render()
65+
66+
if is_numpy:
67+
return np.array(result_img)
68+
69+
return result_img
70+
5771
def _scene_from_result(self, image: Image, result: Result) -> Scene:
5872
scene: Scene
5973
if isinstance(result, AnomalyResult):
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Tests for visualizer."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
7+
from pathlib import Path
8+
9+
import numpy as np
10+
from PIL import Image
11+
import pytest
12+
13+
from model_api.models.result import (
14+
AnomalyResult,
15+
)
16+
from model_api.models.result.classification import Label
17+
from model_api.visualizer import Visualizer
18+
19+
20+
def test_anomaly_scene(mock_image: Image, tmpdir: Path):
21+
"""Test if the anomaly scene is created."""
22+
heatmap = np.ones(mock_image.size, dtype=np.uint8)
23+
heatmap *= 255
24+
25+
mask = np.zeros(mock_image.size, dtype=np.uint8)
26+
mask[32:96, 32:96] = 255
27+
mask[40:80, 0:128] = 255
28+
29+
anomaly_result = AnomalyResult(
30+
anomaly_map=heatmap,
31+
pred_boxes=np.array([[0, 0, 128, 128], [32, 32, 96, 96]]),
32+
pred_label="Anomaly",
33+
pred_mask=mask,
34+
pred_score=0.85,
35+
)
36+
37+
visualizer = Visualizer()
38+
rendered_img = visualizer.render(mock_image, anomaly_result)
39+
40+
assert isinstance(rendered_img, Image.Image)
41+
assert np.array(rendered_img).shape == np.array(mock_image).shape
42+
43+
rendered_img_np = visualizer.render(np.array(mock_image), anomaly_result)
44+
45+
assert isinstance(rendered_img_np, np.ndarray)
46+
assert rendered_img_np.shape == np.array(mock_image).shape
47+

0 commit comments

Comments
 (0)