@@ -723,11 +723,10 @@ def __reshape_to_2d(var):
     target_label.stop_gradient = True
     conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)
     # 3. Mining hard examples
+    actual_shape = ops.slice(conf_shape, axes=[0], starts=[0], ends=[2])
+    actual_shape.stop_gradient = True
     conf_loss = nn.reshape(
-        x=conf_loss,
-        shape=(num, num_prior),
-        actual_shape=ops.slice(
-            conf_shape, axes=[0], starts=[0], ends=[2]))
+        x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape)
     conf_loss.stop_gradient = True
     neg_indices = helper.create_tmp_variable(dtype='int32')
     dtype = matched_indices.dtype
@@ -796,11 +795,7 @@ def __reshape_to_2d(var):
     # 5.3 Compute overall weighted loss.
     loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
     # reshape to [N, Np], N is the batch size and Np is the prior box number.
-    loss = nn.reshape(
-        x=loss,
-        shape=(num, num_prior),
-        actual_shape=ops.slice(
-            conf_shape, axes=[0], starts=[0], ends=[2]))
+    loss = nn.reshape(x=loss, shape=(num, num_prior), actual_shape=actual_shape)
     loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
     if normalize:
         normalizer = nn.reduce_sum(target_loc_weight)
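
The change hoists the ops.slice call that derives the runtime [N, Np] shape out of the two nn.reshape calls, so the shape tensor is built once, marked stop_gradient = True, and reused by both reshapes. Below is a minimal sketch of the same pattern, assuming the legacy paddle.fluid 1.x static-graph API (fluid.layers.slice, and fluid.layers.reshape with its actual_shape argument); the function name reshape_losses is hypothetical, and conf_loss, loss, conf_shape, num, and num_prior mirror the variables used in ssd_loss.

# Sketch only: assumes the legacy paddle.fluid 1.x API.
import paddle.fluid as fluid

def reshape_losses(conf_loss, loss, conf_shape, num, num_prior):
    # Slice the leading [N, Np] entries out of the runtime shape once.
    actual_shape = fluid.layers.slice(
        conf_shape, axes=[0], starts=[0], ends=[2])
    # The shape tensor carries layout information only; it must not
    # receive gradients.
    actual_shape.stop_gradient = True
    # Both reshapes reuse the same runtime shape tensor.
    conf_loss = fluid.layers.reshape(
        x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape)
    loss = fluid.layers.reshape(
        x=loss, shape=(num, num_prior), actual_shape=actual_shape)
    return conf_loss, loss

Reusing one sliced shape tensor avoids emitting two identical slice ops into the program and guarantees both reshapes see the same stop_gradient setting.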