Skip to content

Commit 14224d1

Browse files
eaidovaAnna Grebneva
andauthored
AC: enable auto resize for NHWC layout (#3007)
* AC: enable auto resize for NHWC layout * Update tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/resize.py Co-authored-by: Anna Grebneva <[email protected]> Co-authored-by: Anna Grebneva <[email protected]>
1 parent 258d401 commit 14224d1

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/launcher/openvino_launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ def _fill_lstm_inputs(self, infer_outputs=None):
814814
for lstm_var, output_layer in self._lstm_inputs.items():
815815
layer_shape = parse_partial_shape(self.inputs[lstm_var].partial_shape)
816816
if infer_outputs and output_layer not in infer_outputs:
817-
raise 'Output node with name {} not found'.format(output_layer)
817+
raise ConfigError('Output node with name {} not found'.format(output_layer))
818818
input_data = infer_outputs[output_layer].reshape(layer_shape) if infer_outputs else np.zeros(
819819
layer_shape, dtype=format_map[self.inputs[lstm_var].element_type.get_type_name()]
820820
)

tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/resize.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -479,17 +479,22 @@ def __init__(self, config, name=None):
479479

480480
def set_input_shape(self, input_shape):
481481
def is_image_input(shape):
482-
return len(shape) == 4 and shape[1] in [1, 3, 4]
482+
return len(shape) == 4 and (shape[1] in [1, 3, 4] or shape[-1] in [1, 3, 4])
483+
484+
def is_nhwc(shape):
485+
return shape[-1] in [1, 3, 4]
483486
if input_shape is None:
484487
raise ConfigError('resize to input size impossible')
485488
image_inputs = [value for value in input_shape.values() if is_image_input(value)]
486489
if not image_inputs:
487490
raise ConfigError('image input is not detected')
488491
if len(image_inputs) == 1:
489-
self.dst_height, self.dst_width = image_inputs[0][2:]
492+
self.dst_height, self.dst_width = (
493+
image_inputs[0][2:] if not is_nhwc(image_inputs[0]) else image_inputs[0][1:3]
494+
)
490495
else:
491-
self.dst_height = [im_input[2] for im_input in image_inputs]
492-
self.dst_width = [im_input[3] for im_input in image_inputs]
496+
self.dst_height = [im_input[2] if not is_nhwc(im_input) else im_input[1] for im_input in image_inputs]
497+
self.dst_width = [im_input[3] if not is_nhwc(im_input) else im_input[2] for im_input in image_inputs]
493498

494499
def process(self, image, annotation_meta=None):
495500
is_simple_case = not isinstance(image.data, list) # otherwise -- pyramid, tiling, etc

0 commit comments

Comments
 (0)