@@ -145,14 +145,17 @@ def process(self, raw, identifiers, frame_meta):
145
145
raw_outputs = self ._extract_predictions (raw , frame_meta )
146
146
if not self .outputs_verified :
147
147
self ._get_output_names (raw_outputs )
148
+ input_shape = list (frame_meta [0 ].get ('input_shape' , {'data' : (1 , 3 , 416 , 416 )}).values ())[0 ]
149
+ nchw_layout = input_shape [1 ] == 3
148
150
prior_boxes = raw_outputs [self .priorbox_out ][0 ][0 ].reshape (- 1 , 4 ) if not self .multihead else None
149
151
prior_variances = raw_outputs [self .priorbox_out ][0 ][1 ].reshape (- 1 , 4 ) if not self .multihead else None
150
152
151
- head_shifts = self .estimate_head_shifts (raw_outputs , self . head_sizes , self .add_conf_outs , self .multihead )
153
+ head_shifts = self .estimate_head_shifts (
154
+ raw_outputs , self . head_sizes , self .add_conf_outs , self .multihead , nchw_layout )
152
155
153
156
for batch_id , identifier in enumerate (identifiers ):
154
157
labels , class_scores , x_mins , y_mins , x_maxs , y_maxs , main_scores = self .prepare_detection_for_id (
155
- batch_id , raw_outputs , prior_boxes , prior_variances , head_shifts
158
+ batch_id , raw_outputs , prior_boxes , prior_variances , head_shifts , nchw = nchw_layout
156
159
)
157
160
action_prediction = ActionDetectionPrediction (
158
161
identifier , labels , class_scores , main_scores , x_mins , y_mins , x_maxs , y_maxs
@@ -167,12 +170,15 @@ def process(self, raw, identifiers, frame_meta):
167
170
return result
168
171
169
172
def prepare_detection_for_id (self , batch_id , raw_outputs , prior_boxes , prior_variances , head_shifts ,
170
- default_label = 0 ):
173
+ default_label = 0 , nchw = True ):
171
174
num_detections = raw_outputs [self .loc_out ][batch_id ].size // 4
172
175
locs = raw_outputs [self .loc_out ][batch_id ].reshape (- 1 , 4 )
173
176
main_conf = raw_outputs [self .main_conf_out ][batch_id ].reshape (num_detections , - 1 )
174
177
175
178
add_confs = [raw_outputs [layer ][batch_id ] for layer in self .add_conf_outs ]
179
+ if not nchw :
180
+ add_confs = [np .transpose (conf_l , (2 , 0 , 1 )) for conf_l in add_confs ]
181
+
176
182
if self .multihead :
177
183
spatial_sizes = [layer .shape [1 :] for layer in add_confs ]
178
184
add_confs = [layer .reshape (self .num_action_classes , - 1 ) for layer in add_confs ]
@@ -238,14 +244,16 @@ def decode_box(prior, var, deltas):
238
244
return decoded_xmin , decoded_ymin , decoded_xmax , decoded_ymax
239
245
240
246
@staticmethod
241
- def estimate_head_shifts (raw_outputs , head_sizes , add_conf_outs , multihead_net ):
247
+ def estimate_head_shifts (raw_outputs , head_sizes , add_conf_outs , multihead_net , nchw = True ):
242
248
layer_id = 0
243
249
head_shift = 0
244
250
head_shifts = [0 ]
245
251
for head_size in head_sizes :
246
252
for _ in range (head_size ):
247
253
layer = add_conf_outs [layer_id ]
248
254
layer_shape = raw_outputs [layer ][0 ].shape
255
+ if len (layer_shape ) == 3 and not nchw :
256
+ layer_shape = layer_shape [::- 1 ]
249
257
layer_size = np .prod (layer_shape [1 :]) if multihead_net else np .prod (layer_shape [:2 ])
250
258
head_shift += layer_size
251
259
layer_id += 1
@@ -297,19 +305,17 @@ def find_layer(regex, output_name, all_outputs):
297
305
298
306
self .loc_out = find_layer (loc_out_regex , 'loc' , raw_outputs )
299
307
self .main_conf_out = find_layer (main_conf_out_regex , 'main confidence' , raw_outputs )
300
- self .priorbox_out = self .check_output_name (self .priorbox_out , raw_outputs )
308
+ if hasattr (self , 'priorbox_out' ):
309
+ self .priorbox_out = self .check_output_name (self .priorbox_out , raw_outputs )
301
310
self .outputs_verified = True
302
- if contains_all (raw_outputs , self .add_conf_outs ):
303
- return
304
- add_conf_result = [layer_name + '/sink_port_0' for layer_name in self .add_conf_outs ]
305
- if contains_all (raw_outputs , add_conf_result ):
306
- self .add_conf_outs = add_conf_result
311
+ add_conf_outs = [self .check_output_name (layer , raw_outputs ) for layer in self .add_conf_outs ]
312
+ if contains_all (add_conf_outs ):
313
+ self .add_conf_outs = add_conf_outs
307
314
return
308
- add_conf_with_bias = [layer_name + '/add_' for layer_name in self .add_conf_outs ]
315
+
316
+ add_conf_with_bias = [self .check_output_name (layer_name + '/add_' , raw_outputs )
317
+ for layer_name in self .add_conf_outs ]
309
318
if contains_all (raw_outputs , add_conf_with_bias ):
310
319
self .add_conf_outs = add_conf_with_bias
311
320
return
312
- add_conf_with_bias_result = [layer_name + '/add_/sink_port_0' for layer_name in self .add_conf_outs ]
313
- if contains_all (raw_outputs , add_conf_with_bias_result ):
314
- self .add_conf_outs = add_conf_with_bias_result
315
321
return
0 commit comments