Skip to content

Commit adc84bd

Browse files
authored
AC: fix output blob selection for adapter (#2991)
* AC: fix output blob selection for adapter * fix cascade model for single model evaluators
1 parent d67718f commit adc84bd

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _extract_predictions(outputs_list, meta):
102102

103103
def select_output_blob(self, outputs):
104104
if self.output_blob is None:
105-
self.output_blob = next(iter(outputs))
105+
self.output_blob = next(iter(outputs)) if isinstance(outputs, dict) else next(iter(outputs[0]))
106106

107107
@classmethod
108108
def validation_scheme(cls, provider=None):

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/base_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,16 @@ def release(self):
5555
model.release()
5656

5757
def load_network(self, network_list, launcher):
58+
if len(self._part_by_name) == 1 and 'name' not in network_list[0]:
59+
next(iter(self._part_by_name.values())).load_model(network_list[0]['model'], launcher)
60+
return
5861
for network_dict in network_list:
5962
self._part_by_name[network_dict['name']].load_network(network_dict['model'], launcher)
6063

6164
def load_model(self, network_list, launcher):
65+
if len(self._part_by_name) == 1 and 'name' not in network_list[0]:
66+
next(iter(self._part_by_name.values())).load_model(network_list[0], launcher)
67+
return
6268
for network_dict in network_list:
6369
self._part_by_name[network_dict['name']].load_model(network_dict, launcher)
6470

@@ -204,6 +210,8 @@ def set_input_and_output(self):
204210
self.input_blob = input_blob
205211
self.output_blob = output_blob
206212
self.with_prefix = with_prefix
213+
if hasattr(self, 'adapter') and self.adapter is not None:
214+
self.adapter.output_blob = output_blob
207215

208216
def load_model(self, network_info, launcher, log=False):
209217
if 'onnx_model' in network_info:
@@ -326,6 +334,8 @@ def set_input_and_output(self):
326334
self.input_blob = input_blob
327335
self.output_blob = output_blob
328336
self.with_prefix = with_prefix
337+
if hasattr(self, 'adapter') and self.adapter is not None:
338+
self.adapter.output_blob = output_blob
329339

330340
@property
331341
def inputs(self):

0 commit comments

Comments
 (0)