Skip to content

Commit dc9e23c

Browse files
author
chengduozh
committed
fix sum_op's GetExpectedKernelType
test=develop
1 parent 7719349 commit dc9e23c

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

paddle/fluid/operators/sum_op.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,17 @@ class SumOp : public framework::OperatorWithKernel {
8181

8282
if (x_vars[0]->IsType<framework::LoDTensor>()) {
8383
int dtype = -1;
84-
for (auto& x_var : x_vars) {
85-
auto& lod_tensor = x_var->Get<framework::LoDTensor>();
86-
if (lod_tensor.numel() == 0) {
84+
auto x_var_names = ctx.Inputs("X");
85+
for (auto& x_var_n : x_var_names) {
86+
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
87+
auto tensor = ctx.Input<Tensor>(x_var_n);
88+
if (tensor->numel() == 0) {
8789
continue;
8890
}
8991
if (dtype == -1) {
90-
dtype = framework::ToDataType(lod_tensor.type());
92+
dtype = framework::ToDataType(tensor->type());
9193
} else {
92-
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type()));
94+
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
9395
}
9496
}
9597
PADDLE_ENFORCE_NE(dtype, -1,

0 commit comments

Comments
 (0)