Skip to content

Commit 66024e9

Browse files
authored
Merge pull request #14149 from chengduoZH/fix_sum_op_bug_relase
Fix sum op's GetExpectedKernelType
2 parents c5591f7 + 618d7e3 commit 66024e9

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,11 +358,11 @@ static bool VarIsTensor(const Variable* var) {
358358
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
359359
}
360360

361-
static const Tensor* GetTensorFromVar(Variable* var) {
361+
const Tensor* GetTensorFromVar(const Variable* var) {
362362
if (var->IsType<LoDTensor>()) {
363-
return var->GetMutable<LoDTensor>();
363+
return static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
364364
} else if (var->IsType<SelectedRows>()) {
365-
return var->GetMutable<SelectedRows>()->mutable_value();
365+
return &(var->Get<SelectedRows>().value());
366366
} else {
367367
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
368368
var->Type().name());
@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
415415
template <>
416416
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
417417
auto* var = InputVar(name);
418-
return var == nullptr ? nullptr
419-
: GetTensorFromVar(const_cast<Variable*>(var));
418+
return var == nullptr ? nullptr : GetTensorFromVar(var);
420419
}
421420

422421
template <>

paddle/fluid/framework/operator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
6363
}
6464

6565
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
66+
const Tensor* GetTensorFromVar(const Variable* var);
6667

6768
class OperatorBase;
6869
class ExecutionContext;

paddle/fluid/operators/sum_op.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,15 @@ class SumOp : public framework::OperatorWithKernel {
8282
if (x_vars[0]->IsType<framework::LoDTensor>()) {
8383
int dtype = -1;
8484
for (auto& x_var : x_vars) {
85-
auto& lod_tensor = x_var->Get<framework::LoDTensor>();
86-
if (lod_tensor.numel() == 0) {
85+
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
86+
auto tensor = framework::GetTensorFromVar(x_var);
87+
if (tensor->numel() == 0) {
8788
continue;
8889
}
8990
if (dtype == -1) {
90-
dtype = framework::ToDataType(lod_tensor.type());
91+
dtype = framework::ToDataType(tensor->type());
9192
} else {
92-
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type()));
93+
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
9394
}
9495
}
9596
PADDLE_ENFORCE_NE(dtype, -1,

0 commit comments

Comments
 (0)