@@ -496,10 +496,15 @@ def __reshape_to_2d(var):
496
496
# 5.1 Compute confidence loss.
497
497
target_label = __reshape_to_2d (target_label )
498
498
target_label = tensor .cast (x = target_label , dtype = 'int64' )
499
+
499
500
conf_loss = nn .softmax_with_cross_entropy (confidence , target_label )
500
501
target_conf_weight = __reshape_to_2d (target_conf_weight )
501
502
conf_loss = conf_loss * target_conf_weight
502
503
504
+ # the target_label and target_conf_weight do not have gradient.
505
+ target_label .stop_gradient = True
506
+ target_conf_weight .stop_gradient = True
507
+
503
508
# 5.2 Compute regression loss.
504
509
location = __reshape_to_2d (location )
505
510
target_bbox = __reshape_to_2d (target_bbox )
@@ -508,6 +513,10 @@ def __reshape_to_2d(var):
508
513
target_loc_weight = __reshape_to_2d (target_loc_weight )
509
514
loc_loss = loc_loss * target_loc_weight
510
515
516
+ # the target_bbox and target_loc_weight do not have gradient.
517
+ target_bbox .stop_gradient = True
518
+ target_loc_weight .stop_gradient = True
519
+
511
520
# 5.3 Compute overall weighted loss.
512
521
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
513
522
# reshape to [N, Np], N is the batch size and Np is the prior box number.
0 commit comments