File tree Expand file tree Collapse file tree 3 files changed +8
-5
lines changed Expand file tree Collapse file tree 3 files changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable* var) {
358
358
return var->IsType <LoDTensor>() || var->IsType <SelectedRows>();
359
359
}
360
360
361
- static const Tensor* GetTensorFromVar (Variable* var) {
361
+ const Tensor* GetTensorFromVar (Variable* var) {
362
362
if (var->IsType <LoDTensor>()) {
363
363
return var->GetMutable <LoDTensor>();
364
364
} else if (var->IsType <SelectedRows>()) {
Original file line number Diff line number Diff line change @@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
63
63
}
64
64
65
65
proto::VarType::Type GetDataTypeOfVar (const Variable* var);
66
+ const Tensor* GetTensorFromVar (Variable* var);
66
67
67
68
class OperatorBase ;
68
69
class ExecutionContext ;
Original file line number Diff line number Diff line change @@ -82,14 +82,16 @@ class SumOp : public framework::OperatorWithKernel {
82
82
if (x_vars[0 ]->IsType <framework::LoDTensor>()) {
83
83
int dtype = -1 ;
84
84
for (auto & x_var : x_vars) {
85
- auto & lod_tensor = x_var->Get <framework::LoDTensor>();
86
- if (lod_tensor.numel () == 0 ) {
85
+ // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
86
+ auto tensor = framework::GetTensorFromVar (
87
+ const_cast <framework::Variable*>(x_var));
88
+ if (tensor->numel () == 0 ) {
87
89
continue ;
88
90
}
89
91
if (dtype == -1 ) {
90
- dtype = framework::ToDataType (lod_tensor. type ());
92
+ dtype = framework::ToDataType (tensor-> type ());
91
93
} else {
92
- PADDLE_ENFORCE_EQ (dtype, framework::ToDataType (lod_tensor. type ()));
94
+ PADDLE_ENFORCE_EQ (dtype, framework::ToDataType (tensor-> type ()));
93
95
}
94
96
}
95
97
PADDLE_ENFORCE_NE (dtype, -1 ,
You can’t perform that action at this time.
0 commit comments