18
18
from functools import partial
19
19
import numpy as np
20
20
21
- from .mtcnn_models import MTCNNCascadeModel
21
+ from .mtcnn_models import build_stages
22
22
from .mtcnn_evaluator_utils import transform_for_callback
23
23
from .base_custom_evaluator import BaseCustomEvaluator
24
+ from ..quantization_model_evaluator import create_dataset_attributes
24
25
25
26
26
27
class MTCNNEvaluator (BaseCustomEvaluator ):
27
- def __init__ (self , dataset_config , launcher , model , orig_config ):
28
+ def __init__ (self , dataset_config , launcher , stages , orig_config ):
28
29
super ().__init__ (dataset_config , launcher , orig_config )
29
- self .model = model
30
- if hasattr (self .model , 'adapter' ) and self .model .adapter is not None :
31
- self .adapter_type = self .model .adapter .__provider__
30
+ self .stages = stages
31
+ stage = next (iter (self .stages .values ()))
32
+ if hasattr (stage , 'adapter' ) and stage .adapter is not None :
33
+ self .adapter_type = stage .adapter .__provider__
32
34
33
35
@classmethod
34
36
def from_configs (cls , config , delayed_model_loading = False , orig_config = None ):
35
37
dataset_config , launcher , _ = cls .get_dataset_and_launcher_info (config )
36
- model = MTCNNCascadeModel (
37
- config .get ('network_info' , {}), launcher , config .get ('_models' , []), delayed_model_loading
38
- )
39
- return cls (dataset_config , launcher , model , orig_config )
38
+ models_info = config ['network_info' ]
39
+ stages = build_stages (models_info , [], launcher , config .get ('_models' ), delayed_model_loading )
40
+ return cls (dataset_config , launcher , stages , orig_config )
40
41
41
42
def _process (self , output_callback , calculate_metrics , progress_reporter , metric_config , csv_file ):
42
43
def no_detections (batch_pred ):
@@ -50,7 +51,7 @@ def no_detections(batch_pred):
50
51
intermediate_callback = partial (output_callback , metrics_result = None ,
51
52
element_identifiers = batch_identifiers , dataset_indices = batch_input_ids )
52
53
batch_size = 1
53
- for stage in self .model . stages .values ():
54
+ for stage in self .stages .values ():
54
55
previous_stage_predictions = batch_prediction
55
56
filled_inputs , batch_meta = stage .preprocess_data (
56
57
copy .deepcopy (batch_inputs ), batch_annotation , previous_stage_predictions
@@ -71,10 +72,37 @@ def no_detections(batch_pred):
71
72
dataset_indices = batch_input_ids )
72
73
self ._update_progress (progress_reporter , metric_config , batch_id , len (batch_prediction ), csv_file )
73
74
75
+ def _release_model (self ):
76
+ for _ , stage in self .stages .items ():
77
+ stage .release ()
78
+
74
79
def reset (self ):
75
- self .model .reset ()
80
+ super ().reset ()
81
+ for _ , stage in self .stages .items ():
82
+ stage .reset ()
83
+
84
+ def load_network (self , network = None ):
85
+ if network is None :
86
+ for stage_name , stage in self .stages .items ():
87
+ stage .load_network (network , self .launcher , stage_name + '_' )
88
+ else :
89
+ for net_dict in network :
90
+ stage_name = net_dict ['name' ]
91
+ network_ = net_dict ['model' ]
92
+ self .stages [stage_name ].load_network (network_ , self .launcher , stage_name + '_' )
93
+
94
+ def load_network_from_ir (self , models_list ):
95
+ for models_dict in models_list :
96
+ stage_name = models_dict ['name' ]
97
+ self .stages [stage_name ].load_model (models_dict , self .launcher , stage_name + '_' )
98
+
99
+ def get_network (self ):
100
+ return [{'name' : stage_name , 'model' : stage .network } for stage_name , stage in self .stages .items ()]
76
101
77
102
def select_dataset (self , dataset_tag ):
78
- super ().select_dataset (dataset_tag )
79
- for _ , stage in self .model .stages .items ():
80
- stage .update_preprocessing (self .preprocessor )
103
+ if self .dataset is not None and isinstance (self .dataset_config , list ):
104
+ return
105
+ dataset_attributes = create_dataset_attributes (self .dataset_config , dataset_tag )
106
+ self .dataset , self .metric_executor , preprocessor , self .postprocessor = dataset_attributes
107
+ for _ , stage in self .stages .items ():
108
+ stage .update_preprocessing (preprocessor )
0 commit comments