Skip to content

Commit 06c0822

Browse files
authored
AC: handle port name in sequence inference (#3018)
* AC: handle port name in sequence inference * reshape using indexes * input output info better logging
1 parent 5c84931 commit 06c0822

File tree

4 files changed

+35
-20
lines changed

4 files changed

+35
-20
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/adapter.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from ..config import BaseField, ConfigValidator, StringField, ConfigError
1818
from ..dependency import ClassProvider, UnregisteredProviderException
19-
from ..utils import get_parameter_value_from_config
19+
from ..utils import get_parameter_value_from_config, postprocess_output_name
2020

2121

2222
class Adapter(ClassProvider):
@@ -53,18 +53,7 @@ def configure(self):
5353

5454
@staticmethod
5555
def check_output_name(output_name, outputs, suffix=('/sink_port_0', ':0')):
56-
suffixes = [suffix] if isinstance(suffix, str) else suffix
57-
outputs = outputs[0] if isinstance(outputs, list) else outputs
58-
if output_name in outputs:
59-
return output_name
60-
for suffix_ in suffixes:
61-
if suffix_ in output_name:
62-
preprocessed_output_name = output_name.replace(suffix_, '')
63-
else:
64-
preprocessed_output_name = '{}{}'.format(output_name, suffix_)
65-
if preprocessed_output_name in outputs:
66-
return preprocessed_output_name
67-
return output_name
56+
return postprocess_output_name(output_name, outputs, suffix, raise_error=False)
6857

6958
@classmethod
7059
def validate_config(cls, config, fetch_only=False, uri_prefix='', **kwargs):
@@ -121,6 +110,7 @@ def reset(self):
121110
def release(self):
122111
pass
123112

113+
124114
class AdapterField(BaseField):
125115
def validate(self, entry, field_uri=None, fetch_only=False, validation_scheme=None):
126116
errors_stack = super().validate(entry, field_uri, fetch_only, validation_scheme)

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/base_models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,18 @@ def input_tensors_mapping(self):
262262

263263
return node_to_tensor
264264

265+
def input_index_mapping(self):
266+
inputs = self.network.inputs if self.network is not None else self.exec_network.inputs
267+
return {inp.get_node().friendly_name: idx for idx, inp in enumerate(inputs)}
268+
265269
def _reshape_input(self, input_shapes):
266270
if self.is_dynamic:
267271
return
268272
if hasattr(self, 'exec_network') and self.exec_network is not None:
269273
del self.infer_request
270274
del self.exec_network
271-
tensor_mapping = self.input_tensors_mapping()
272-
input_shapes_for_tensors = {tensor_mapping[name]: shape for name, shape in input_shapes.items()}
275+
index_mapping = self.input_index_mapping()
276+
input_shapes_for_tensors = {index_mapping[name]: shape for name, shape in input_shapes.items()}
273277
self.launcher.reshape_network(self.network, input_shapes_for_tensors)
274278
self.dynamic_inputs, self.partial_shapes = self.launcher.get_dynamic_inputs(self.network)
275279
if not self.is_dynamic and self.dynamic_inputs:

tools/accuracy_checker/openvino/tools/accuracy_checker/launcher/openvino_launcher.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
string_to_tuple,
3939
get_or_parse_value,
4040
parse_partial_shape,
41+
postprocess_output_name
4142
)
4243
from .launcher import Launcher
4344
from ..logging import print_info
@@ -330,7 +331,7 @@ def _reshape_input(self, shapes, make_dynamic=False):
330331
p_shape = PartialShape(
331332
[Dimension(d) if not isinstance(d, tuple) else Dimension(d[0], d[1]) for d in shape]
332333
)
333-
partial_shapes[self.input_to_tensor_name[name]] = p_shape
334+
partial_shapes[self.input_to_index[name]] = p_shape
334335

335336
self.network.reshape(partial_shapes)
336337
self.dyn_input_layers, self._partial_shapes = self.get_dynamic_inputs(self.network)
@@ -548,6 +549,7 @@ def load_network(self, network=None, log=False, preprocessing=None):
548549
if self.network is not None:
549550
self.dyn_input_layers, self._partial_shapes = self.get_dynamic_inputs(self.network)
550551
self.input_to_tensor_name = self.get_input_tensor_name_mapping(self.network)
552+
self.input_to_index = {inp.get_node().friendly_name: idx for idx, inp in enumerate(self.network.inputs)}
551553

552554
if not self._postpone_input_configuration:
553555
self._set_precision()
@@ -813,8 +815,8 @@ def _fill_lstm_inputs(self, infer_outputs=None):
813815
feed_dict = {}
814816
for lstm_var, output_layer in self._lstm_inputs.items():
815817
layer_shape = parse_partial_shape(self.inputs[lstm_var].partial_shape)
816-
if infer_outputs and output_layer not in infer_outputs:
817-
raise ConfigError('Output node with name {} not found'.format(output_layer))
818+
if infer_outputs:
819+
output_layer = postprocess_output_name(output_layer, infer_outputs)
818820
input_data = infer_outputs[output_layer].reshape(layer_shape) if infer_outputs else np.zeros(
819821
layer_shape, dtype=format_map[self.inputs[lstm_var].element_type.get_type_name()]
820822
)
@@ -831,13 +833,15 @@ def print_input_output_info(network, prefix=None):
831833
network_outputs = network.outputs
832834
for input_info in network_inputs:
833835
input_node = input_info.get_node()
834-
print_info('\tLayer name: {}'.format(input_node.friendly_name))
836+
print_info('\tNode name: {}'.format(input_node.friendly_name))
837+
print_info('\tTensor names: {}'.format(', '.join(input_info.get_names())))
835838
print_info('\tprecision: {}'.format(input_node.element_type.get_type_name()))
836839
print_info('\tshape: {}\n'.format(parse_partial_shape(input_node.get_partial_shape())))
837840
print_info('Output info')
838841
for output_info in network_outputs:
839842
out_node = output_info.get_node()
840-
print_info('\tLayer name: {}'.format(out_node.friendly_name))
843+
print_info('\tNode name: {}'.format(out_node.friendly_name))
844+
print_info('\tTensor names: {}'.format(', '.join(output_info.get_names())))
841845
precision = out_node.get_output_element_type(0).get_type_name()
842846
print_info('\tprecision: {}'.format(precision))
843847
shape = parse_partial_shape(out_node.get_output_partial_shape(0))

tools/accuracy_checker/openvino/tools/accuracy_checker/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,3 +975,20 @@ def parse_partial_shape(partial_shape):
975975
shape_list.append(string_to_tuple(shape_range, casting_type=int))
976976
s_pos = min(close_brace + 2, e_pos)
977977
return shape_list
978+
979+
980+
def postprocess_output_name(output_name, outputs, suffix=('/sink_port_0', ':0'), raise_error=True):
981+
suffixes = [suffix] if isinstance(suffix, str) else suffix
982+
outputs = outputs[0] if isinstance(outputs, list) else outputs
983+
if output_name in outputs:
984+
return output_name
985+
for suffix_ in suffixes:
986+
if suffix_ in output_name:
987+
preprocessed_output_name = output_name.replace(suffix_, '')
988+
else:
989+
preprocessed_output_name = '{}{}'.format(output_name, suffix_)
990+
if preprocessed_output_name in outputs:
991+
return preprocessed_output_name
992+
if raise_error:
993+
raise ValueError('Output name: {} not found'.format(output_name))
994+
return output_name

0 commit comments

Comments
 (0)