Skip to content

Commit 67c3a60

Browse files
Add PoC 2
Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 2d0849e commit 67c3a60

File tree

4 files changed

+136
-144
lines changed

4 files changed

+136
-144
lines changed

model_api/python/model_api/models/result_types/anomaly.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import cv2
99
import numpy as np
1010

11+
from model_api.visualizer.layout import Flatten, Layout
1112
from model_api.visualizer.primitives import BoundingBoxes, Label, Overlay, Polygon
1213

1314
from .base import Result
@@ -57,3 +58,11 @@ def _register_primitives(self) -> None:
5758
self._add_primitive(Label(self.pred_label, bg_color="red" if self.pred_label == "Anomaly" else "green"))
5859
self._add_primitive(Label(f"Score: {self.pred_score}"))
5960
self._add_primitive(Polygon(mask=self.pred_mask))
61+
62+
@property
63+
def default_layout(self) -> Layout:
64+
return Flatten(
65+
Overlay,
66+
Polygon,
67+
Label,
68+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Visualization Layout"""
2+
3+
# Copyright (C) 2024 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
from __future__ import annotations
6+
7+
from abc import ABC
8+
from typing import TYPE_CHECKING, Type
9+
10+
from PIL import Image
11+
12+
if TYPE_CHECKING:
13+
from model_api.visualizer.primitives import Primitive
14+
15+
from .visualize_mixin import VisualizeMixin
16+
17+
18+
class Layout(ABC):
19+
"""Base class for layouts."""
20+
21+
def _compute_on_primitive(self, primitive: Primitive, image: Image, result: VisualizeMixin) -> Image | None:
22+
if result.has_primitive(primitive):
23+
primitives = result.get_primitive(primitive)
24+
for primitive in primitives:
25+
image = primitive.compute(image)
26+
return image
27+
return None
28+
29+
30+
class HStack(Layout):
31+
"""Horizontal stack layout."""
32+
33+
def __init__(self, *args: Layout | Type[Primitive]) -> None:
34+
self.children = args
35+
36+
def __call__(self, image: Image, result: VisualizeMixin) -> Image:
37+
images: list[Image] = []
38+
for child in self.children:
39+
if isinstance(child, Layout):
40+
images.append(child(image, result))
41+
else:
42+
_image = image.copy()
43+
_image = self._compute_on_primitive(child, _image, result)
44+
if _image is not None:
45+
images.append(_image)
46+
return self._stitch(*images)
47+
48+
def _stitch(self, *images: Image) -> Image:
49+
"""Stitch images together.
50+
51+
Args:
52+
images (Image): Images to stitch.
53+
54+
Returns:
55+
Image: Stitched image.
56+
"""
57+
new_image = Image.new(
58+
"RGB",
59+
(
60+
sum(image.width for image in images),
61+
max(image.height for image in images),
62+
),
63+
)
64+
x_offset = 0
65+
for image in images:
66+
new_image.paste(image, (x_offset, 0))
67+
x_offset += image.width
68+
return new_image
69+
70+
71+
class VStack(Layout):
72+
"""Vertical stack layout."""
73+
74+
75+
class Flatten(Layout):
76+
"""Put all primitives on top of each other"""
77+
78+
def __init__(self, *args: Type[Primitive]) -> None:
79+
self.children = args
80+
81+
def __call__(self, image: Image, result: VisualizeMixin) -> Image:
82+
_image: Image = image.copy()
83+
for child in self.children:
84+
_image = self._compute_on_primitive(child, _image, result)
85+
return _image

model_api/python/model_api/visualizer/visualize_mixin.py

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
from abc import ABC, abstractmethod
7+
from typing import Type
78

9+
from .layout import Layout
810
from .primitives import BoundingBoxes, Label, Overlay, Polygon, Primitive
911

1012

@@ -22,6 +24,11 @@ def __init__(self) -> None:
2224
def _register_primitives(self) -> None:
2325
"""Convert result entities to primitives."""
2426

27+
@property
28+
@abstractmethod
29+
def default_layout(self) -> Layout:
30+
"""Default layout."""
31+
2532
def _add_primitive(self, primitive: Primitive) -> None:
2633
"""Add primitive."""
2734
if isinstance(primitive, Label):
@@ -33,49 +40,32 @@ def _add_primitive(self, primitive: Primitive) -> None:
3340
elif isinstance(primitive, BoundingBoxes):
3441
self._bounding_boxes.append(primitive)
3542

36-
@property
37-
def has_labels(self) -> bool:
38-
"""Check if there are labels."""
39-
self._register_primitives_if_needed()
40-
return bool(self._labels)
41-
42-
@property
43-
def has_bounding_boxes(self) -> bool:
44-
"""Check if there are bounding boxes."""
45-
self._register_primitives_if_needed()
46-
return bool(self._bounding_boxes)
47-
48-
@property
49-
def has_polygons(self) -> bool:
50-
"""Check if there are polygons."""
51-
self._register_primitives_if_needed()
52-
return bool(self._polygons)
53-
54-
@property
55-
def has_overlays(self) -> bool:
56-
"""Check if there are overlays."""
57-
self._register_primitives_if_needed()
58-
return bool(self._overlays)
59-
60-
def get_labels(self) -> list[Label]:
61-
"""Get labels."""
43+
def has_primitive(self, primitive: Type[Primitive]) -> bool:
44+
"""Check if the primitive type is registered."""
6245
self._register_primitives_if_needed()
63-
return self._labels
64-
65-
def get_polygons(self) -> list[Polygon]:
66-
"""Get polygons."""
67-
self._register_primitives_if_needed()
68-
return self._polygons
69-
70-
def get_overlays(self) -> list[Overlay]:
71-
"""Get overlays."""
72-
self._register_primitives_if_needed()
73-
return self._overlays
74-
75-
def get_bounding_boxes(self) -> list[BoundingBoxes]:
76-
"""Get bounding boxes."""
46+
if primitive == Label:
47+
return bool(self._labels)
48+
if primitive == Polygon:
49+
return bool(self._polygons)
50+
if primitive == Overlay:
51+
return bool(self._overlays)
52+
if primitive == BoundingBoxes:
53+
return bool(self._bounding_boxes)
54+
return False
55+
56+
def get_primitive(self, primitive: Type[Primitive]) -> Primitive:
57+
"""Get primitive."""
7758
self._register_primitives_if_needed()
78-
return self._bounding_boxes
59+
if primitive == Label:
60+
return self._labels
61+
if primitive == Polygon:
62+
return self._polygons
63+
if primitive == Overlay:
64+
return self._overlays
65+
if primitive == BoundingBoxes:
66+
return self._bounding_boxes
67+
msg = f"Primitive {primitive} not found"
68+
raise ValueError(msg)
7969

8070
def _register_primitives_if_needed(self):
8171
if not self._registered_primitives:

model_api/python/model_api/visualizer/visualizer.py

Lines changed: 11 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -8,130 +8,38 @@
88
from enum import Enum
99
from typing import TYPE_CHECKING
1010

11-
from PIL import Image
12-
1311
from model_api.visualizer.primitives import Label
1412

1513
if TYPE_CHECKING:
16-
from model_api.visualizer.visualize_mixin import VisualizeMixin
14+
from PIL import Image
1715

16+
from model_api.visualizer.visualize_mixin import VisualizeMixin
1817

19-
class VisualizationType(Enum):
20-
"""Visualization type."""
21-
22-
FULL = "full"
23-
SIMPLE = "simple"
18+
from .layout import Layout
2419

2520

2621
class Visualizer:
27-
def __init__(self) -> None:
28-
# TODO: add transforms for the source image so that it has the same crop, and size as the model.
29-
pass
22+
def __init__(self, layout: Layout | None = None) -> None:
23+
self.layout = layout
3024

3125
def show(
3226
self,
3327
image: Image,
3428
result: VisualizeMixin,
35-
visualization_type: VisualizationType | str = VisualizationType.SIMPLE,
3629
) -> None:
37-
visualization_type = VisualizationType(visualization_type)
38-
result: Image = self._generate(image, result, visualization_type)
30+
result: Image = self._generate(image, result)
3931
result.show()
4032

4133
def save(
4234
self,
4335
image: Image,
4436
result: VisualizeMixin,
4537
path: str,
46-
visualization_type: VisualizationType | str = VisualizationType.SIMPLE,
4738
) -> None:
48-
visualization_type = VisualizationType(visualization_type)
49-
result: Image = self._generate(image, result, visualization_type)
39+
result: Image = self._generate(image, result)
5040
result.save(path)
5141

52-
def _generate(self, image: Image, result: VisualizeMixin, visualization_type: VisualizationType) -> Image:
53-
_result: Image
54-
if visualization_type == VisualizationType.SIMPLE:
55-
_result = self._generate_simple(image, result)
56-
else:
57-
_result = self._generate_full(image, result)
58-
return _result
59-
60-
def _generate_simple(self, image: Image, result: VisualizeMixin) -> Image:
61-
"""Return a single image with stacked visualizations."""
62-
# 1. Use Overlay
63-
_image = image.copy()
64-
if result.has_overlays:
65-
overlays = result.get_overlays()
66-
for overlay in overlays:
67-
image = overlay.compute(_image)
68-
69-
elif result.has_polygons: # 2. else use polygons
70-
polygons = result.get_polygons()
71-
for polygon in polygons:
72-
image = polygon.compute(_image)
73-
74-
elif result.has_bounding_boxes: # 3. else use bounding boxes
75-
bounding_boxes = result.get_bounding_boxes()
76-
for bounding_box in bounding_boxes:
77-
image = bounding_box.compute(_image)
78-
79-
# Finally add labels
80-
if result.has_labels:
81-
labels = result.get_labels()
82-
label_images = []
83-
for label in labels:
84-
label_images.append(label.compute(_image, overlay_on_image=False))
85-
_image = Label.overlay_labels(_image, label_images)
86-
87-
return _image
88-
89-
def _generate_full(self, image: Image, result: VisualizeMixin) -> Image:
90-
"""Return a single image with visualizations side by side."""
91-
images: list[Image] = [image]
92-
93-
if result.has_overlays:
94-
overlays = result.get_overlays()
95-
_image = image.copy()
96-
for overlay in overlays:
97-
_image = overlay.compute(_image)
98-
images.append(_image)
99-
if result.has_polygons:
100-
polygons = result.get_polygons()
101-
_image = image.copy()
102-
for polygon in polygons:
103-
_image = polygon.compute(_image)
104-
images.append(_image)
105-
if result.has_bounding_boxes:
106-
bounding_boxes = result.get_bounding_boxes()
107-
_image = image.copy()
108-
for bounding_box in bounding_boxes:
109-
_image = bounding_box.compute(_image)
110-
images.append(_image)
111-
if result.has_labels:
112-
labels = result.get_labels()
113-
for label in labels:
114-
images.append(label.compute(image.copy(), overlay_on_image=True))
115-
return self._stitch(*images)
116-
117-
def _stitch(self, *images: Image) -> Image:
118-
"""Stitch images together.
119-
120-
Args:
121-
images (Image): Images to stitch.
122-
123-
Returns:
124-
Image: Stitched image.
125-
"""
126-
new_image = Image.new(
127-
"RGB",
128-
(
129-
sum(image.width for image in images),
130-
max(image.height for image in images),
131-
),
132-
)
133-
x_offset = 0
134-
for image in images:
135-
new_image.paste(image, (x_offset, 0))
136-
x_offset += image.width
137-
return new_image
42+
def _generate(self, image: Image, result: VisualizeMixin) -> Image:
43+
if self.layout is not None:
44+
return self.layout(image, result)
45+
return result.default_layout(image, result)

0 commit comments

Comments
 (0)