Skip to content

Commit f5b5986

Browse files
author
Anna Grebneva
authored
Fixed background_matting_evaluator in case quantization (#3296)
1 parent b0482a1 commit f5b5986

File tree

2 files changed

+46
-25
lines changed

2 files changed

+46
-25
lines changed

models/public/robust-video-matting-mobilenetv3/accuracy-check.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ evaluations:
22
- name: robust-video-matting-mobilenetv3
33
module: custom_evaluators.sequential_background_matting_evaluator.SequentialBackgroundMatting
44
module_config:
5-
5+
network_info:
6+
background_matting_model: {}
67
launchers:
78
- framework: openvino
89
adapter:

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/sequential_background_matting_evaluator.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,10 @@
2525

2626

2727
class SequentialBackgroundMatting(BaseCustomEvaluator):
28-
def __init__(self, dataset_config, launcher, model, input_feeder, adapter, orig_config):
28+
def __init__(self, dataset_config, launcher, model, adapter, orig_config):
2929
super().__init__(dataset_config, launcher, orig_config)
3030
self.model = model
31-
self.input_feeder = input_feeder
3231
self.adapter = adapter
33-
self.inputs = input_feeder.network_inputs
3432

3533
def _process(self, output_callback, calculate_metrics, progress_reporter, metric_config, csv_file):
3634
previous_video_id = None
@@ -42,10 +40,10 @@ def _process(self, output_callback, calculate_metrics, progress_reporter, metric
4240
filled_inputs = self.input_feeder.fill_inputs(batch_inputs)
4341
for i, filled_input in enumerate(filled_inputs):
4442
filled_input.update(rnn_inputs[i])
45-
batch_raw_results = self.model.predict(batch_identifiers, filled_inputs)
46-
batch_predictions = self.adapter.process(batch_raw_results, batch_identifiers, batch_meta)
43+
batch_raw_results, batch_results = self.model.predict(batch_identifiers, filled_inputs)
44+
batch_predictions = self.adapter.process(batch_results, batch_identifiers, batch_meta)
4745
previous_video_id = batch_annotation[0].video_id
48-
rnn_inputs = self.model.set_rnn_inputs(batch_raw_results)
46+
rnn_inputs = self.model.set_rnn_inputs(batch_results)
4947
annotation, prediction = self.postprocessor.process_batch(batch_annotation, batch_predictions)
5048
metrics_result = self._get_metrics_result(batch_input_ids, annotation, prediction, calculate_metrics)
5149
if output_callback:
@@ -60,21 +58,16 @@ def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
6058
config.get('network_info', {}), launcher, config.get('_models', []), config.get('_model_is_blob'),
6159
delayed_model_loading
6260
)
63-
launcher.network = model.model.network
64-
launcher_config = config['launchers'][0]
65-
postpone_model_loading = False
66-
input_precision = launcher_config.get('_input_precision', [])
67-
input_layouts = launcher_config.get('_input_layout', '')
68-
input_feeder = InputFeeder(
69-
launcher.config.get('inputs', []), model.model.launcher.inputs, cls.input_shape, launcher.fit_to_input,
70-
launcher.default_layout, launcher_config['framework'] == 'dummy' or postpone_model_loading, input_precision,
71-
input_layouts
72-
)
73-
adapter = create_adapter(launcher_config.get('adapter', 'background_matting_with_pha_and_fgr'))
74-
return cls(dataset_config, launcher, model, input_feeder, adapter, orig_config)
61+
adapter = create_adapter(config['launchers'][0].get('adapter', 'background_matting_with_pha_and_fgr'))
62+
return cls(dataset_config, launcher, model, adapter, orig_config)
7563

76-
def input_shape(self, input_name):
77-
return self.inputs[input_name]
64+
@property
65+
def input_feeder(self):
66+
return self.model.model.input_feeder
67+
68+
@property
69+
def inputs(self):
70+
return self.input_feeder.network_inputs
7871

7972

8073
class SequentialBackgroundMattingModel(BaseCascadeModel):
@@ -99,10 +92,12 @@ def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_l
9992

10093
def predict(self, identifiers, input_data):
10194
batch_raw_results = []
95+
batch_results = []
10296
for identifier, data in zip(identifiers, input_data):
103-
raw_results = self.model.predict(identifier, data)
97+
results, raw_results = self.model.predict(identifier, data)
10498
batch_raw_results.append(raw_results)
105-
return batch_raw_results
99+
batch_results.append(results)
100+
return batch_raw_results, batch_results
106101

107102
def reset_rnn_inputs(self, batch_size):
108103
output = []
@@ -129,9 +124,34 @@ def set_rnn_inputs(self, outputs):
129124

130125
class DLSDKSequentialBackgroundMattingModel(BaseDLSDKModel):
131126
def predict(self, identifiers, input_data):
132-
return self.exec_network.infer(input_data)
127+
outputs = self.exec_network.infer(input_data)
128+
if isinstance(outputs, tuple):
129+
outputs, raw_outputs = outputs
130+
else:
131+
raw_outputs = outputs
132+
return outputs, raw_outputs
133+
134+
def input_shape(self, input_name):
135+
return self.launcher.inputs[input_name]
136+
137+
def load_network(self, network, launcher):
138+
super().load_network(network, launcher)
139+
self.launcher = launcher
140+
self.launcher.network = self.network
141+
self.input_feeder = InputFeeder(self.launcher.config.get('inputs', []), self.launcher.inputs, self.input_shape,
142+
self.launcher.fit_to_input, self.launcher.default_layout)
133143

134144

135145
class OpenVINOModelSequentialBackgroundMattingModel(BaseOpenVINOModel):
136146
def predict(self, identifiers, input_data):
137-
return self.infer(input_data)
147+
return self.infer(input_data, raw_results=True)
148+
149+
def input_shape(self, input_name):
150+
return self.launcher.inputs[input_name]
151+
152+
def load_network(self, network, launcher):
153+
super().load_network(network, launcher)
154+
self.launcher = launcher
155+
self.launcher.network = self.network
156+
self.input_feeder = InputFeeder(self.launcher.config.get('inputs', []), self.launcher.inputs, self.input_shape,
157+
self.launcher.fit_to_input, self.launcher.default_layout)

0 commit comments

Comments
 (0)