25
25
26
26
27
27
class SequentialBackgroundMatting (BaseCustomEvaluator ):
28
- def __init__ (self , dataset_config , launcher , model , input_feeder , adapter , orig_config ):
28
+ def __init__ (self , dataset_config , launcher , model , adapter , orig_config ):
29
29
super ().__init__ (dataset_config , launcher , orig_config )
30
30
self .model = model
31
- self .input_feeder = input_feeder
32
31
self .adapter = adapter
33
- self .inputs = input_feeder .network_inputs
34
32
35
33
def _process (self , output_callback , calculate_metrics , progress_reporter , metric_config , csv_file ):
36
34
previous_video_id = None
@@ -42,10 +40,10 @@ def _process(self, output_callback, calculate_metrics, progress_reporter, metric
42
40
filled_inputs = self .input_feeder .fill_inputs (batch_inputs )
43
41
for i , filled_input in enumerate (filled_inputs ):
44
42
filled_input .update (rnn_inputs [i ])
45
- batch_raw_results = self .model .predict (batch_identifiers , filled_inputs )
46
- batch_predictions = self .adapter .process (batch_raw_results , batch_identifiers , batch_meta )
43
+ batch_raw_results , batch_results = self .model .predict (batch_identifiers , filled_inputs )
44
+ batch_predictions = self .adapter .process (batch_results , batch_identifiers , batch_meta )
47
45
previous_video_id = batch_annotation [0 ].video_id
48
- rnn_inputs = self .model .set_rnn_inputs (batch_raw_results )
46
+ rnn_inputs = self .model .set_rnn_inputs (batch_results )
49
47
annotation , prediction = self .postprocessor .process_batch (batch_annotation , batch_predictions )
50
48
metrics_result = self ._get_metrics_result (batch_input_ids , annotation , prediction , calculate_metrics )
51
49
if output_callback :
@@ -60,21 +58,16 @@ def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
60
58
config .get ('network_info' , {}), launcher , config .get ('_models' , []), config .get ('_model_is_blob' ),
61
59
delayed_model_loading
62
60
)
63
- launcher .network = model .model .network
64
- launcher_config = config ['launchers' ][0 ]
65
- postpone_model_loading = False
66
- input_precision = launcher_config .get ('_input_precision' , [])
67
- input_layouts = launcher_config .get ('_input_layout' , '' )
68
- input_feeder = InputFeeder (
69
- launcher .config .get ('inputs' , []), model .model .launcher .inputs , cls .input_shape , launcher .fit_to_input ,
70
- launcher .default_layout , launcher_config ['framework' ] == 'dummy' or postpone_model_loading , input_precision ,
71
- input_layouts
72
- )
73
- adapter = create_adapter (launcher_config .get ('adapter' , 'background_matting_with_pha_and_fgr' ))
74
- return cls (dataset_config , launcher , model , input_feeder , adapter , orig_config )
61
+ adapter = create_adapter (config ['launchers' ][0 ].get ('adapter' , 'background_matting_with_pha_and_fgr' ))
62
+ return cls (dataset_config , launcher , model , adapter , orig_config )
75
63
76
- def input_shape (self , input_name ):
77
- return self .inputs [input_name ]
64
+ @property
65
+ def input_feeder (self ):
66
+ return self .model .model .input_feeder
67
+
68
+ @property
69
+ def inputs (self ):
70
+ return self .input_feeder .network_inputs
78
71
79
72
80
73
class SequentialBackgroundMattingModel (BaseCascadeModel ):
@@ -99,10 +92,12 @@ def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_l
99
92
100
93
def predict (self , identifiers , input_data ):
101
94
batch_raw_results = []
95
+ batch_results = []
102
96
for identifier , data in zip (identifiers , input_data ):
103
- raw_results = self .model .predict (identifier , data )
97
+ results , raw_results = self .model .predict (identifier , data )
104
98
batch_raw_results .append (raw_results )
105
- return batch_raw_results
99
+ batch_results .append (results )
100
+ return batch_raw_results , batch_results
106
101
107
102
def reset_rnn_inputs (self , batch_size ):
108
103
output = []
@@ -129,9 +124,34 @@ def set_rnn_inputs(self, outputs):
129
124
130
125
class DLSDKSequentialBackgroundMattingModel (BaseDLSDKModel ):
131
126
def predict (self , identifiers , input_data ):
132
- return self .exec_network .infer (input_data )
127
+ outputs = self .exec_network .infer (input_data )
128
+ if isinstance (outputs , tuple ):
129
+ outputs , raw_outputs = outputs
130
+ else :
131
+ raw_outputs = outputs
132
+ return outputs , raw_outputs
133
+
134
+ def input_shape (self , input_name ):
135
+ return self .launcher .inputs [input_name ]
136
+
137
+ def load_network (self , network , launcher ):
138
+ super ().load_network (network , launcher )
139
+ self .launcher = launcher
140
+ self .launcher .network = self .network
141
+ self .input_feeder = InputFeeder (self .launcher .config .get ('inputs' , []), self .launcher .inputs , self .input_shape ,
142
+ self .launcher .fit_to_input , self .launcher .default_layout )
133
143
134
144
135
145
class OpenVINOModelSequentialBackgroundMattingModel (BaseOpenVINOModel ):
136
146
def predict (self , identifiers , input_data ):
137
- return self .infer (input_data )
147
+ return self .infer (input_data , raw_results = True )
148
+
149
+ def input_shape (self , input_name ):
150
+ return self .launcher .inputs [input_name ]
151
+
152
+ def load_network (self , network , launcher ):
153
+ super ().load_network (network , launcher )
154
+ self .launcher = launcher
155
+ self .launcher .network = self .network
156
+ self .input_feeder = InputFeeder (self .launcher .config .get ('inputs' , []), self .launcher .inputs , self .input_shape ,
157
+ self .launcher .fit_to_input , self .launcher .default_layout )
0 commit comments