Skip to content

Commit 1ee228f

Browse files
authored
AC: update custom sr evaluator (#3180)
1 parent 7d71879 commit 1ee228f

File tree

1 file changed

+10
-0
lines changed
  • tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators

1 file changed

+10
-0
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/sr_evaluator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def fit_to_input(self, input_data):
196196
for name, info in self.inputs.items():
197197
data = input_data[self._name_to_idx[name]]
198198
data = np.expand_dims(data, axis=0)
199+
if parse_partial_shape(info.get_partial_shape())[1] == 3:
200+
data = np.transpose(data, (0, 3, 1, 2))
199201
if not info.get_partial_shape().is_dynamic:
200202
assert tuple(parse_partial_shape(info.get_partial_shape())) == np.shape(data)
201203
fitted[name] = data
@@ -209,6 +211,14 @@ def load_network(self, network, launcher):
209211
def set_input_and_output(self):
210212
input_info = self.inputs
211213
input_blob = next(iter(input_info))
214+
out_mapping = {}
215+
outputs = self.network.outputs if self.network is not None else self.exec_network.outputs
216+
for out in outputs:
217+
if not out.names:
218+
continue
219+
for name in out.names:
220+
out_mapping[name] = out.get_node().friendly_name
221+
self.adapter.additional_output_mapping = out_mapping
212222
with_prefix = input_blob.startswith(self.default_model_suffix + '_')
213223
if (with_prefix != self.with_prefix) and with_prefix:
214224
self.network_info['feedback_input'] = '_'.join([self.default_model_suffix,

0 commit comments

Comments
 (0)