File tree Expand file tree Collapse file tree 3 files changed +10
-9
lines changed Expand file tree Collapse file tree 3 files changed +10
-9
lines changed Original file line number Diff line number Diff line change @@ -358,11 +358,11 @@ static bool VarIsTensor(const Variable* var) {
358
358
return var->IsType <LoDTensor>() || var->IsType <SelectedRows>();
359
359
}
360
360
361
- static const Tensor* GetTensorFromVar (Variable* var) {
361
+ const Tensor* GetTensorFromVar (const Variable* var) {
362
362
if (var->IsType <LoDTensor>()) {
363
- return var->GetMutable <LoDTensor>();
363
+ return static_cast < const Tensor*>(&( var->Get <LoDTensor>()) );
364
364
} else if (var->IsType <SelectedRows>()) {
365
- return var->GetMutable <SelectedRows>()-> mutable_value ( );
365
+ return &( var->Get <SelectedRows>(). value () );
366
366
} else {
367
367
PADDLE_THROW (" Variable type_id %s, expect LoDTensor/SelectedRows." ,
368
368
var->Type ().name ());
@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
415
415
template <>
416
416
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
417
417
auto * var = InputVar (name);
418
- return var == nullptr ? nullptr
419
- : GetTensorFromVar (const_cast <Variable*>(var));
418
+ return var == nullptr ? nullptr : GetTensorFromVar (var);
420
419
}
421
420
422
421
template <>
Original file line number Diff line number Diff line change @@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
63
63
}
64
64
65
65
proto::VarType::Type GetDataTypeOfVar (const Variable* var);
66
+ const Tensor* GetTensorFromVar (const Variable* var);
66
67
67
68
class OperatorBase ;
68
69
class ExecutionContext ;
Original file line number Diff line number Diff line change @@ -82,14 +82,15 @@ class SumOp : public framework::OperatorWithKernel {
82
82
if (x_vars[0 ]->IsType <framework::LoDTensor>()) {
83
83
int dtype = -1 ;
84
84
for (auto & x_var : x_vars) {
85
- auto & lod_tensor = x_var->Get <framework::LoDTensor>();
86
- if (lod_tensor.numel () == 0 ) {
85
+ // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
86
+ auto tensor = framework::GetTensorFromVar (x_var);
87
+ if (tensor->numel () == 0 ) {
87
88
continue ;
88
89
}
89
90
if (dtype == -1 ) {
90
- dtype = framework::ToDataType (lod_tensor. type ());
91
+ dtype = framework::ToDataType (tensor-> type ());
91
92
} else {
92
- PADDLE_ENFORCE_EQ (dtype, framework::ToDataType (lod_tensor. type ()));
93
+ PADDLE_ENFORCE_EQ (dtype, framework::ToDataType (tensor-> type ()));
93
94
}
94
95
}
95
96
PADDLE_ENFORCE_NE (dtype, -1 ,
You can’t perform that action at this time.
0 commit comments