Skip to content

Commit 65efebb

Browse files
authored
Fix detection.py after merge slice_op. (#13435)
1 parent 289acfa commit 65efebb

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

python/paddle/fluid/layers/detection.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -723,11 +723,10 @@ def __reshape_to_2d(var):
723723
target_label.stop_gradient = True
724724
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)
725725
# 3. Mining hard examples
726+
actual_shape = ops.slice(conf_shape, axes=[0], starts=[0], ends=[2])
727+
actual_shape.stop_gradient = True
726728
conf_loss = nn.reshape(
727-
x=conf_loss,
728-
shape=(num, num_prior),
729-
actual_shape=ops.slice(
730-
conf_shape, axes=[0], starts=[0], ends=[2]))
729+
x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape)
731730
conf_loss.stop_gradient = True
732731
neg_indices = helper.create_tmp_variable(dtype='int32')
733732
dtype = matched_indices.dtype
@@ -796,11 +795,7 @@ def __reshape_to_2d(var):
796795
# 5.3 Compute overall weighted loss.
797796
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
798797
# reshape to [N, Np], N is the batch size and Np is the prior box number.
799-
loss = nn.reshape(
800-
x=loss,
801-
shape=(num, num_prior),
802-
actual_shape=ops.slice(
803-
conf_shape, axes=[0], starts=[0], ends=[2]))
798+
loss = nn.reshape(x=loss, shape=(num, num_prior), actual_shape=actual_shape)
804799
loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
805800
if normalize:
806801
normalizer = nn.reduce_sum(target_loc_weight)

0 commit comments

Comments
 (0)