Skip to content

Commit 2f63911

Browse files
author
chengduo
authored
Fix sum_op's GetExpectedKernelType (#14112)
* fix sum_op's GetExpectedKernelType test=develop * fix ci fail test=develop
1 parent 5577f9b commit 2f63911

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable* var) {
358358
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
359359
}
360360

361-
static const Tensor* GetTensorFromVar(Variable* var) {
361+
const Tensor* GetTensorFromVar(Variable* var) {
362362
if (var->IsType<LoDTensor>()) {
363363
return var->GetMutable<LoDTensor>();
364364
} else if (var->IsType<SelectedRows>()) {

paddle/fluid/framework/operator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
6363
}
6464

6565
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
66+
const Tensor* GetTensorFromVar(Variable* var);
6667

6768
class OperatorBase;
6869
class ExecutionContext;

paddle/fluid/operators/sum_op.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,16 @@ class SumOp : public framework::OperatorWithKernel {
8282
if (x_vars[0]->IsType<framework::LoDTensor>()) {
8383
int dtype = -1;
8484
for (auto& x_var : x_vars) {
85-
auto& lod_tensor = x_var->Get<framework::LoDTensor>();
86-
if (lod_tensor.numel() == 0) {
85+
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
86+
auto tensor = framework::GetTensorFromVar(
87+
const_cast<framework::Variable*>(x_var));
88+
if (tensor->numel() == 0) {
8789
continue;
8890
}
8991
if (dtype == -1) {
90-
dtype = framework::ToDataType(lod_tensor.type());
92+
dtype = framework::ToDataType(tensor->type());
9193
} else {
92-
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type()));
94+
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
9395
}
9496
}
9597
PADDLE_ENFORCE_NE(dtype, -1,

0 commit comments

Comments
 (0)