Skip to content

Commit ece7451

Browse files
zhaoyuchen2018phlrain
authored andcommitted
Merge pull request #16857 from zhaoyuchen2018/sumreshape
Fix sum infershape issue
1 parent a9539cb commit ece7451

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

paddle/fluid/operators/sum_op.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,21 @@ class SumOp : public framework::OperatorWithKernel {
6565
if (framework::product(in_dim) == 0) {
6666
in_dim = x_dim;
6767
} else {
68-
PADDLE_ENFORCE_EQ(in_dim, x_dim, "Input tensors must have same shape");
68+
if (ctx->IsRuntime()) {
69+
PADDLE_ENFORCE_EQ(in_dim, x_dim,
70+
"Input tensors must have same shape");
71+
} else {
72+
PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(),
73+
"Input tensors must have same shape size");
74+
// if in_dim or x_dim has -1, not check equal
75+
for (int i = 0; i < x_dim.size(); ++i) {
76+
if (x_dim[i] == -1 || in_dim[i] == -1) {
77+
continue;
78+
}
79+
PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i],
80+
"Input tensors must have same shape if not -1");
81+
}
82+
}
6983
}
7084
}
7185
ctx->SetOutputDim("Out", in_dim);

0 commit comments

Comments
 (0)