Skip to content

Commit e2c7806

Browse files
authored
AC: fix mtcnn quantization (#2995)
1 parent adc84bd commit e2c7806

File tree

2 files changed

+574
-331
lines changed

2 files changed

+574
-331
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/mtcnn_evaluator.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,26 @@
1818
from functools import partial
1919
import numpy as np
2020

21-
from .mtcnn_models import MTCNNCascadeModel
21+
from .mtcnn_models import build_stages
2222
from .mtcnn_evaluator_utils import transform_for_callback
2323
from .base_custom_evaluator import BaseCustomEvaluator
24+
from ..quantization_model_evaluator import create_dataset_attributes
2425

2526

2627
class MTCNNEvaluator(BaseCustomEvaluator):
27-
def __init__(self, dataset_config, launcher, model, orig_config):
28+
def __init__(self, dataset_config, launcher, stages, orig_config):
2829
super().__init__(dataset_config, launcher, orig_config)
29-
self.model = model
30-
if hasattr(self.model, 'adapter') and self.model.adapter is not None:
31-
self.adapter_type = self.model.adapter.__provider__
30+
self.stages = stages
31+
stage = next(iter(self.stages.values()))
32+
if hasattr(stage, 'adapter') and stage.adapter is not None:
33+
self.adapter_type = stage.adapter.__provider__
3234

3335
@classmethod
3436
def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
3537
dataset_config, launcher, _ = cls.get_dataset_and_launcher_info(config)
36-
model = MTCNNCascadeModel(
37-
config.get('network_info', {}), launcher, config.get('_models', []), delayed_model_loading
38-
)
39-
return cls(dataset_config, launcher, model, orig_config)
38+
models_info = config['network_info']
39+
stages = build_stages(models_info, [], launcher, config.get('_models'), delayed_model_loading)
40+
return cls(dataset_config, launcher, stages, orig_config)
4041

4142
def _process(self, output_callback, calculate_metrics, progress_reporter, metric_config, csv_file):
4243
def no_detections(batch_pred):
@@ -50,7 +51,7 @@ def no_detections(batch_pred):
5051
intermediate_callback = partial(output_callback, metrics_result=None,
5152
element_identifiers=batch_identifiers, dataset_indices=batch_input_ids)
5253
batch_size = 1
53-
for stage in self.model.stages.values():
54+
for stage in self.stages.values():
5455
previous_stage_predictions = batch_prediction
5556
filled_inputs, batch_meta = stage.preprocess_data(
5657
copy.deepcopy(batch_inputs), batch_annotation, previous_stage_predictions
@@ -71,10 +72,37 @@ def no_detections(batch_pred):
7172
dataset_indices=batch_input_ids)
7273
self._update_progress(progress_reporter, metric_config, batch_id, len(batch_prediction), csv_file)
7374

75+
def _release_model(self):
76+
for _, stage in self.stages.items():
77+
stage.release()
78+
7479
def reset(self):
75-
self.model.reset()
80+
super().reset()
81+
for _, stage in self.stages.items():
82+
stage.reset()
83+
84+
def load_network(self, network=None):
85+
if network is None:
86+
for stage_name, stage in self.stages.items():
87+
stage.load_network(network, self.launcher, stage_name + '_')
88+
else:
89+
for net_dict in network:
90+
stage_name = net_dict['name']
91+
network_ = net_dict['model']
92+
self.stages[stage_name].load_network(network_, self.launcher, stage_name + '_')
93+
94+
def load_network_from_ir(self, models_list):
95+
for models_dict in models_list:
96+
stage_name = models_dict['name']
97+
self.stages[stage_name].load_model(models_dict, self.launcher, stage_name + '_')
98+
99+
def get_network(self):
100+
return [{'name': stage_name, 'model': stage.network} for stage_name, stage in self.stages.items()]
76101

77102
def select_dataset(self, dataset_tag):
78-
super().select_dataset(dataset_tag)
79-
for _, stage in self.model.stages.items():
80-
stage.update_preprocessing(self.preprocessor)
103+
if self.dataset is not None and isinstance(self.dataset_config, list):
104+
return
105+
dataset_attributes = create_dataset_attributes(self.dataset_config, dataset_tag)
106+
self.dataset, self.metric_executor, preprocessor, self.postprocessor = dataset_attributes
107+
for _, stage in self.stages.items():
108+
stage.update_preprocessing(preprocessor)

0 commit comments

Comments
 (0)