Skip to content

Commit f5386a7

Browse files
authored
AC: move tensor mapping to read network (#3401)
* AC: move tensor mapping to read network * update work with batch
1 parent 99f99e7 commit f5386a7

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/launcher/openvino_launcher.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -547,27 +547,29 @@ def _set_batch_size(self, batch_size):
547547
if layer_name in self.const_inputs:
548548
input_shapes[layer_name] = parse_partial_shape(input_node.get_node().partial_shape)
549549
else:
550-
layer_shape = parse_partial_shape(input_node.get_node().partial_shape)
551-
layout = self.inputs[layer_name].layout
552-
if '...' in str(layout):
553-
layout = self.get_layout_from_config(layer_name)
554-
else:
555-
layout = str(layout).replace('[', '').replace(']', '').replace(',', '')
550+
layer_shape = list(parse_partial_shape(input_node.get_node().partial_shape))
551+
layout = self._process_layout(self.inputs[layer_name].layout, layer_name)
556552
batch_pos = layout.find('N')
557553
if batch_pos != -1:
558554
layer_shape[batch_pos] = batch_size
559555
input_shapes[layer_name] = layer_shape
560556
self._reshape_input(input_shapes, batch_size == -1)
561557
self._batch = batch_size
562558

559+
def _process_layout(self, ov_layout, layer_name):
560+
if '...' in str(ov_layout) or ov_layout is None:
561+
ov_layout = self.get_layout_from_config(layer_name)
562+
else:
563+
ov_layout = str(ov_layout).replace('[', '').replace(']', '').replace(',', '')
564+
return ov_layout
565+
563566
def _get_model_batch_size(self):
564567
input_nodes = self.network.inputs if self.network else self.exec_network.inputs
565568
input_info = input_nodes[0]
566-
layout = input_info.get_node().layout
567-
if '...' in str(layout) or layout is None:
568-
layout = self.get_layout_from_config(input_info.get_node().friendly_name)
569-
else:
570-
layout = str(layout).replace('[', '').replace(']', '').replace(',', '')
569+
layout = (
570+
self._process_layout(input_info.get_node().layout, input_info.get_node().friendly_name)
571+
or self.default_layout
572+
)
571573
batch_pos = layout.find('N')
572574
if batch_pos != -1:
573575
return parse_partial_shape(input_info.partial_shape)[batch_pos]
@@ -702,6 +704,8 @@ def read_network(self, model, weights):
702704
network = self.ie_core.read_model(model=str(model), weights=str(weights))
703705
else:
704706
network = self.ie_core.read_model(model=str(model))
707+
self.input_to_tensor_name = self.get_input_tensor_name_mapping(network)
708+
self.input_to_index = {inp.get_node().friendly_name: idx for idx, inp in enumerate(network.inputs)}
705709
return network
706710

707711
def inputs_info_for_meta(self, inputs=None):

0 commit comments

Comments
 (0)