Skip to content

Commit d674b9f

Browse files
authored
AC: allow resize if impossible extend mask (#2918)
1 parent 44afbd0 commit d674b9f

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/extend_segmentation_mask.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from .postprocessor import Postprocessor
2121
from .resize_segmentation_mask import ResizeSegmentationMask
2222
from ..representation import SegmentationAnnotation, SegmentationPrediction
23-
from ..config import NumberField, ConfigError, StringField
23+
from ..config import NumberField, StringField
2424
from ..preprocessor.geometric_transformations import padding_func
25+
from ..logging import warning
2526

2627

2728
class ExtendSegmentationMask(Postprocessor):
@@ -58,7 +59,10 @@ def process_image(self, annotation, prediction):
5859
dst_height, dst_width = prediction_.mask.shape[-2:]
5960
height, width = annotation_mask.shape[-2:]
6061
if dst_width < width or dst_height < height:
61-
raise ConfigError('size for extending should be not less current mask size')
62+
warning('size for extending should be not less current mask size. resize operation will be applied')
63+
annotation_.mask = self._resize(annotation_.mask, dst_height, dst_width, False)
64+
continue
65+
6266

6367
pad = self.pad_func(dst_width, dst_height, width, height)
6468
extended_mask = cv2.copyMakeBorder(
@@ -75,18 +79,6 @@ def process_image_with_metadata(self, annotation, prediction, image_metadata=Non
7579

7680
@staticmethod
7781
def _deprocess_prediction(prediction, meta):
78-
def _resize(entry, height, width):
79-
if len(entry.shape) == 2:
80-
entry = ResizeSegmentationMask.segm_resize(entry, width, height)
81-
return entry
82-
83-
entry_mask = []
84-
for class_mask in entry:
85-
resized_mask = ResizeSegmentationMask.segm_resize(class_mask, width, height)
86-
entry_mask.append(resized_mask)
87-
entry = np.array(entry_mask)
88-
89-
return entry
9082
geom_ops = meta.get('geometric_operations', [])
9183
pad = geom_ops[-1].parameters['pad'] if geom_ops and geom_ops[-1].type == 'padding' else [0, 0, 0, 0]
9284
image_h, image_w = meta['image_size'][:2]
@@ -96,7 +88,21 @@ def _resize(entry, height, width):
9688
pred_mask = prediction_.mask[pad[0]:pred_w-pad[2], pad[1]:pred_h-pad[3]]
9789
pred_h, pred_w = pred_mask.shape[-2:]
9890
if (pred_h, pred_w) != (image_h, image_w):
99-
pred_mask = _resize(pred_mask, image_h, image_w)
91+
pred_mask = ExtendSegmentationMask._resize(pred_mask, image_h, image_w)
10092
prediction_.mask = pred_mask
10193

10294
return prediction
95+
96+
@staticmethod
97+
def _resize(entry, height, width, per_class=True):
98+
if len(entry.shape) == 2 or not per_class:
99+
entry = ResizeSegmentationMask.segm_resize(entry, width, height)
100+
return entry
101+
102+
entry_mask = []
103+
for class_mask in entry:
104+
resized_mask = ResizeSegmentationMask.segm_resize(class_mask, width, height)
105+
entry_mask.append(resized_mask)
106+
entry = np.array(entry_mask)
107+
108+
return entry

tools/accuracy_checker/tests/test_postprocessor.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -996,13 +996,6 @@ def test_extend_segmentation_mask_asymmetrical(self):
996996
assert np.array_equal(prediction[0].mask, expected_prediction_mask)
997997
assert np.array_equal(annotation[0].mask, expected_annotation_mask)
998998

999-
def test_extend_segmentation_mask_raise_config_error_if_prediction_less_annotation(self):
1000-
config = [{'type': 'extend_segmentation_mask'}]
1001-
annotation = make_segmentation_representation(np.zeros((5, 5)), ground_truth=True)
1002-
prediction = make_segmentation_representation(np.zeros((4, 4)), ground_truth=False)
1003-
with pytest.raises(ConfigError):
1004-
postprocess_data(PostprocessingExecutor(config), annotation, prediction)
1005-
1006999
def test_extend_segmentation_mask_with_filling_label(self):
10071000
config = [{'type': 'extend_segmentation_mask', 'filling_label': 1}]
10081001
annotation = make_segmentation_representation(np.zeros((5, 5)), ground_truth=True)

0 commit comments

Comments
 (0)