Skip to content

Commit a9539cb

Browse files
committed
Merge pull request #16913 from phlrain/fix_bpr_loss
Fix bpr loss
1 parent 5c10c57 commit a9539cb

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

paddle/fluid/operators/bpr_loss_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel {
3232
int rank = x_dims.size();
3333
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
3434
"Input(X) and Input(Label) shall have the same rank.");
35-
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
36-
framework::slice_ddim(label_dims, 0, rank - 1),
37-
"Input(X) and Input(Label) shall have the same shape "
38-
"except the last dimension.");
35+
36+
if (ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
37+
framework::product(label_dims) > 0)) {
38+
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
39+
framework::slice_ddim(label_dims, 0, rank - 1),
40+
"Input(X) and Input(Label) shall have the same shape "
41+
"except the last dimension.");
42+
}
3943

4044
auto y_dims = x_dims;
4145
y_dims[rank - 1] = 1;

0 commit comments

Comments
 (0)