Skip to content

Commit 1679cac

Browse files
committed
Refactor segmentation handling to use RotatedSegmentationResult and update related tests
1 parent e528585 commit 1679cac

File tree

9 files changed

+244
-201
lines changed

9 files changed

+244
-201
lines changed

model_api/python/model_api/models/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,7 @@
2020
ImageResultWithSoftPrediction,
2121
InstanceSegmentationResult,
2222
PredictedMask,
23-
SegmentedObject,
24-
SegmentedObjectWithRects,
23+
RotatedSegmentationResult,
2524
VisualPromptingResult,
2625
ZSLVisualPromptingResult,
2726
)
@@ -94,8 +93,7 @@
9493
"classification_models",
9594
"detection_models",
9695
"segmentation_models",
97-
"SegmentedObject",
98-
"SegmentedObjectWithRects",
96+
"RotatedSegmentationResult",
9997
"add_rotated_rects",
10098
"get_contours",
10199
]

model_api/python/model_api/models/instance_segmentation.py

Lines changed: 30 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from model_api.adapters.inference_adapter import InferenceAdapter
1010

1111
from .image_model import ImageModel
12-
from .result_types import InstanceSegmentationResult, SegmentedObject
12+
from .result_types import InstanceSegmentationResult
1313
from .types import BooleanValue, ListValue, NumericalValue, StringValue
1414
from .utils import load_labels
1515

@@ -176,27 +176,31 @@ def postprocess(self, outputs: dict, meta: dict) -> InstanceSegmentationResult:
176176
out=boxes,
177177
)
178178

179-
objects = []
180179
has_feature_vector_name = _feature_vector_name in self.outputs
181180
if has_feature_vector_name:
182181
if not self.labels:
183182
self.raise_error("Can't get number of classes because labels are empty")
184183
saliency_maps: list = [[] for _ in range(len(self.labels))]
185184
else:
186185
saliency_maps = []
187-
for box, confidence, cls, raw_mask in zip(boxes, scores, labels, masks):
188-
x1, y1, x2, y2 = box
189-
if (x2 - x1) * (y2 - y1) < 1 or (confidence <= self.confidence_threshold and not has_feature_vector_name):
190-
continue
191186

192-
# Skip if label index is out of bounds
193-
if self.labels and cls >= len(self.labels):
194-
continue
187+
# Apply confidence threshold, bounding box area filter and label index filter.
188+
keep = (scores > self.confidence_threshold) & ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) > 1)
189+
190+
if self.labels:
191+
keep &= labels < len(self.labels)
192+
193+
boxes = boxes[keep].astype(np.int32)
194+
scores = scores[keep]
195+
labels = labels[keep]
196+
masks = masks[keep]
195197

196-
# Get label string
197-
str_label = self.labels[cls] if self.labels else f"#{cls}"
198+
resized_masks, label_names = [], []
199+
for box, label_idx, raw_mask in zip(boxes, labels, masks, strict=True):
200+
if self.labels:
201+
label_names.append(self.labels[label_idx])
198202

199-
raw_cls_mask = raw_mask[cls, ...] if self.is_segmentoly else raw_mask
203+
raw_cls_mask = raw_mask[label_idx, ...] if self.is_segmentoly else raw_mask
200204
if self.postprocess_semantic_masks or has_feature_vector_name:
201205
resized_mask = _segm_postprocess(
202206
box,
@@ -205,27 +209,21 @@ def postprocess(self, outputs: dict, meta: dict) -> InstanceSegmentationResult:
205209
)
206210
else:
207211
resized_mask = raw_cls_mask
208-
if confidence > self.confidence_threshold:
209-
output_mask = resized_mask if self.postprocess_semantic_masks else raw_cls_mask
210-
xmin, ymin, xmax, ymax = box.astype(int)
211-
objects.append(
212-
SegmentedObject(
213-
xmin,
214-
ymin,
215-
xmax,
216-
ymax,
217-
score=confidence,
218-
id=cls,
219-
str_label=str_label,
220-
mask=output_mask,
221-
),
222-
)
223-
if has_feature_vector_name and confidence > self.confidence_threshold:
224-
saliency_maps[cls - 1].append(resized_mask)
212+
213+
output_mask = resized_mask if self.postprocess_semantic_masks else raw_cls_mask
214+
resized_masks.append(output_mask)
215+
if has_feature_vector_name:
216+
saliency_maps[label_idx - 1].append(resized_mask)
217+
218+
_masks = np.stack(resized_masks) if len(resized_masks) > 0 else np.empty((0, 16, 16), dtype=np.uint8)
225219
return InstanceSegmentationResult(
226-
objects,
227-
_average_and_normalize(saliency_maps),
228-
outputs.get(_feature_vector_name, np.ndarray(0)),
220+
bboxes=boxes,
221+
labels=labels,
222+
scores=scores,
223+
masks=_masks,
224+
label_names=label_names if label_names else None,
225+
saliency_map=_average_and_normalize(saliency_maps),
226+
feature_vector=outputs.get(_feature_vector_name, np.ndarray(0)),
229227
)
230228

231229

model_api/python/model_api/models/result_types/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,7 @@
1616
Contour,
1717
ImageResultWithSoftPrediction,
1818
InstanceSegmentationResult,
19-
SegmentedObject,
20-
SegmentedObjectWithRects,
19+
RotatedSegmentationResult,
2120
)
2221
from .visual_prompting import PredictedMask, VisualPromptingResult, ZSLVisualPromptingResult
2322

@@ -29,12 +28,11 @@
2928
"DetectionResult",
3029
"DetectedKeypoints",
3130
"MultipleOutputParser",
32-
"SegmentedObject",
33-
"SegmentedObjectWithRects",
3431
"SingleOutputParser",
3532
"ImageResultWithSoftPrediction",
3633
"InstanceSegmentationResult",
3734
"PredictedMask",
3835
"VisualPromptingResult",
3936
"ZSLVisualPromptingResult",
37+
"RotatedSegmentationResult",
4038
]

model_api/python/model_api/models/result_types/detection.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -53,16 +53,22 @@ def __init__(
5353
self._feature_vector = feature_vector
5454

5555
def __len__(self) -> int:
56-
return len(self._bboxes)
56+
return len(self.bboxes)
5757

5858
def __str__(self) -> str:
59-
return (
60-
f"Num of boxes: {self._bboxes.shape}, "
61-
f"Num of labels: {len(self._labels)}, "
62-
f"Num of scores: {len(self._scores)}, "
63-
f"Saliency Map: {array_shape_to_str(self._saliency_map)}, "
64-
f"Feature Vec: {array_shape_to_str(self._feature_vector)}"
65-
)
59+
repr_str = ""
60+
for box, score, label, name in zip(
61+
self.bboxes,
62+
self.scores,
63+
self.labels,
64+
self.label_names,
65+
strict=True,
66+
):
67+
x1, y1, x2, y2 = box
68+
repr_str += f"{x1}, {y1}, {x2}, {y2}, {label} ({name}): {score:.3f}; "
69+
70+
repr_str += f"{array_shape_to_str(self.saliency_map)}; {array_shape_to_str(self.feature_vector)}"
71+
return repr_str
6672

6773
def get_obj_sizes(self) -> np.ndarray:
6874
"""Get object sizes.
@@ -117,11 +123,11 @@ def label_names(self, value):
117123
self._label_names = value
118124

119125
@property
120-
def saliency_map(self) -> np.ndarray:
126+
def saliency_map(self):
121127
return self._saliency_map
122128

123129
@saliency_map.setter
124-
def saliency_map(self, value):
130+
def saliency_map(self, value: np.ndarray):
125131
if not isinstance(value, np.ndarray):
126132
msg = "Saliency map must be numpy array."
127133
raise ValueError(msg)
@@ -168,8 +174,12 @@ def __call__(self, outputs) -> DetectionResult:
168174
labels.append(label)
169175
bboxes = np.array(bboxes)
170176
scores = np.array(scores)
171-
labels = np.array(labels)
172-
return DetectionResult(bboxes, scores, labels)
177+
labels = np.array(labels).astype(np.int32)
178+
return DetectionResult(
179+
bboxes=bboxes,
180+
labels=labels,
181+
scores=scores,
182+
)
173183

174184

175185
class MultipleOutputParser:

model_api/python/model_api/models/result_types/segmentation.py

Lines changed: 119 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -10,81 +10,143 @@
1010
import cv2
1111
import numpy as np
1212

13+
from .detection import DetectionResult
1314
from .utils import array_shape_to_str
1415

1516
if TYPE_CHECKING:
1617
from cv2.typing import RotatedRect
1718

1819

19-
class SegmentedObject:
20-
def __init__(
21-
self,
22-
xmin: int,
23-
ymin: int,
24-
xmax: int,
25-
ymax: int,
26-
score: float,
27-
id: int,
28-
str_label: str,
29-
mask: np.ndarray,
30-
) -> None:
31-
self.xmin = xmin
32-
self.ymin = ymin
33-
self.xmax = xmax
34-
self.ymax = ymax
35-
self.score = score
36-
self.id = id
37-
self.str_label = str_label
38-
self.mask = mask
39-
40-
def __str__(self):
41-
return (
42-
f"{self.xmin}, {self.ymin}, {self.xmax}, {self.ymax}, {self.id} ({self.str_label}): {self.score:.3f}"
43-
f", {(self.mask > 0.5).sum()}"
44-
)
45-
20+
class InstanceSegmentationResult(DetectionResult):
21+
"""Instance segmentation result type.
4622
47-
class SegmentedObjectWithRects(SegmentedObject):
48-
def __init__(self, segmented_object: SegmentedObject, rotated_rect: RotatedRect) -> None:
49-
super().__init__(
50-
segmented_object.xmin,
51-
segmented_object.ymin,
52-
segmented_object.xmax,
53-
segmented_object.ymax,
54-
segmented_object.score,
55-
segmented_object.id,
56-
segmented_object.str_label,
57-
segmented_object.mask,
58-
)
59-
self.rotated_rect = rotated_rect
23+
Args:
24+
bboxes (np.ndarray): bounding boxes in dim (N, 4) where N is the number of boxes.
25+
labels (np.ndarray): labels for each bounding box in dim (N,).
26+
masks (np.ndarray): masks for each bounding box in dim (N, H, W).
27+
scores (np.ndarray | None, optional): confidence scores for each bounding box in dim (N,). Defaults to None.
28+
label_names (list[str] | None, optional): class names for each label. Defaults to None.
29+
saliency_map (list[np.ndarray] | None, optional): saliency maps for XAI. Defaults to None.
30+
feature_vector (np.ndarray | None, optional): feature vector for XAI. Defaults to None.
31+
"""
6032

61-
def __str__(self):
62-
res = super().__str__()
63-
rect = self.rotated_rect
64-
res += f", RotatedRect: {rect[0][0]:.3f} {rect[0][1]:.3f} {rect[1][0]:.3f} {rect[1][1]:.3f} {rect[2]:.3f}"
65-
return res
33+
def __init__(
34+
self,
35+
bboxes: np.ndarray,
36+
labels: np.ndarray,
37+
masks: np.ndarray,
38+
scores: np.ndarray | None = None,
39+
label_names: list[str] | None = None,
40+
saliency_map: list[np.ndarray] | None = None,
41+
feature_vector: np.ndarray | None = None,
42+
):
43+
super().__init__(bboxes, labels, scores, label_names, saliency_map, feature_vector)
44+
self._masks = masks
45+
46+
def __str__(self) -> str:
47+
repr_str = ""
48+
for box, score, label, name, mask in zip(
49+
self.bboxes,
50+
self.scores,
51+
self.labels,
52+
self.label_names,
53+
self.masks,
54+
strict=True,
55+
):
56+
x1, y1, x2, y2 = box
57+
repr_str += f"{x1}, {y1}, {x2}, {y2}, {label} ({name}): {score:.3f}, {(mask > 0.5).sum()}; "
6658

59+
filled = 0
60+
for cls_map in self.saliency_map:
61+
if cls_map.size:
62+
filled += 1
63+
prefix = f"{repr_str}" if len(repr_str) else ""
64+
return prefix + f"{filled}; {array_shape_to_str(self.feature_vector)}"
65+
66+
@property
67+
def masks(self) -> np.ndarray:
68+
return self._masks
69+
70+
@masks.setter
71+
def masks(self, value):
72+
if not isinstance(value, np.ndarray):
73+
msg = "Masks must be numpy array."
74+
raise ValueError(msg)
75+
self._masks = value
76+
77+
@property
78+
def saliency_map(self):
79+
return self._saliency_map
80+
81+
@saliency_map.setter
82+
def saliency_map(self, value: list[np.ndarray]):
83+
if not isinstance(value, list):
84+
msg = "Saliency maps must be list."
85+
raise ValueError(msg)
86+
self._saliency_map = value
87+
88+
89+
class RotatedSegmentationResult(InstanceSegmentationResult):
90+
"""Rotated instance segmentation result type.
91+
92+
Args:
93+
bboxes (np.ndarray): bounding boxes in dim (N, 4) where N is the number of boxes.
94+
labels (np.ndarray): labels for each bounding box in dim (N,).
95+
masks (np.ndarray): masks for each bounding box in dim (N, H, W).
96+
rotated_rects (list[RotatedRect]): rotated rectangles for each bounding box.
97+
scores (np.ndarray | None, optional): confidence scores for each bounding box in dim (N,). Defaults to None.
98+
label_names (list[str] | None, optional): class names for each label. Defaults to None.
99+
saliency_map (list[np.ndarray] | None, optional): saliency maps for XAI. Defaults to None.
100+
feature_vector (np.ndarray | None, optional): feature vector for XAI. Defaults to None.
101+
"""
67102

68-
class InstanceSegmentationResult:
69103
def __init__(
70104
self,
71-
segmentedObjects: list[SegmentedObject | SegmentedObjectWithRects],
72-
saliency_map: list[np.ndarray],
73-
feature_vector: np.ndarray,
105+
bboxes: np.ndarray,
106+
labels: np.ndarray,
107+
masks: np.ndarray,
108+
rotated_rects: list[RotatedRect],
109+
scores: np.ndarray | None = None,
110+
label_names: list[str] | None = None,
111+
saliency_map: list[np.ndarray] | None = None,
112+
feature_vector: np.ndarray | None = None,
74113
):
75-
self.segmentedObjects = segmentedObjects
76-
# Contain per class saliency_maps and "feature_vector" model output if feature_vector exists
77-
self.saliency_map = saliency_map
78-
self.feature_vector = feature_vector
114+
super().__init__(bboxes, labels, masks, scores, label_names, saliency_map, feature_vector)
115+
self.rotated_rects = rotated_rects
116+
117+
def __str__(self) -> str:
118+
repr_str = ""
119+
for box, score, label, name, mask, rotated_rect in zip(
120+
self.bboxes,
121+
self.scores,
122+
self.labels,
123+
self.label_names,
124+
self.masks,
125+
self.rotated_rects,
126+
strict=True,
127+
):
128+
x1, y1, x2, y2 = box
129+
(cx, cy), (w, h), angle = rotated_rect
130+
repr_str += f"{x1}, {y1}, {x2}, {y2}, {label} ({name}): {score:.3f}, {(mask > 0.5).sum()},"
131+
repr_str += f" RotatedRect: {cx:.3f} {cy:.3f} {w:.3f} {h:.3f} {angle:.3f}; "
79132

80-
def __str__(self):
81-
obj_str = "; ".join(str(obj) for obj in self.segmentedObjects)
82133
filled = 0
83134
for cls_map in self.saliency_map:
84135
if cls_map.size:
85136
filled += 1
86-
prefix = f"{obj_str}; " if len(obj_str) else ""
87-
return prefix + f"{filled}; [{','.join(str(i) for i in self.feature_vector.shape)}]"
137+
prefix = f"{repr_str}" if len(repr_str) else ""
138+
return prefix + f"{filled}; {array_shape_to_str(self.feature_vector)}"
139+
140+
@property
141+
def rotated_rects(self) -> list[RotatedRect]:
142+
return self._rotated_rects
143+
144+
@rotated_rects.setter
145+
def rotated_rects(self, value):
146+
if not isinstance(value, list):
147+
msg = "RotatedRects must be list."
148+
raise ValueError(msg)
149+
self._rotated_rects = value
88150

89151

90152
class Contour:

0 commit comments

Comments
 (0)