File tree Expand file tree Collapse file tree 3 files changed +6
-8
lines changed Expand file tree Collapse file tree 3 files changed +6
-8
lines changed Original file line number Diff line number Diff line change @@ -358,11 +358,11 @@ static bool VarIsTensor(const Variable* var) {
358
358
return var->IsType <LoDTensor>() || var->IsType <SelectedRows>();
359
359
}
360
360
361
- const Tensor* GetTensorFromVar (Variable* var) {
361
+ const Tensor* GetTensorFromVar (const Variable* var) {
362
362
if (var->IsType <LoDTensor>()) {
363
- return var->GetMutable <LoDTensor>();
363
+ return static_cast < const Tensor*>(&( var->Get <LoDTensor>()) );
364
364
} else if (var->IsType <SelectedRows>()) {
365
- return var->GetMutable <SelectedRows>()-> mutable_value ( );
365
+ return &( var->Get <SelectedRows>(). value () );
366
366
} else {
367
367
PADDLE_THROW (" Variable type_id %s, expect LoDTensor/SelectedRows." ,
368
368
var->Type ().name ());
@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
415
415
template <>
416
416
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
417
417
auto * var = InputVar (name);
418
- return var == nullptr ? nullptr
419
- : GetTensorFromVar (const_cast <Variable*>(var));
418
+ return var == nullptr ? nullptr : GetTensorFromVar (var);
420
419
}
421
420
422
421
template <>
Original file line number Diff line number Diff line change @@ -63,7 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
63
63
}
64
64
65
65
proto::VarType::Type GetDataTypeOfVar (const Variable* var);
66
- const Tensor* GetTensorFromVar (Variable* var);
66
+ const Tensor* GetTensorFromVar (const Variable* var);
67
67
68
68
class OperatorBase ;
69
69
class ExecutionContext ;
Original file line number Diff line number Diff line change @@ -83,8 +83,7 @@ class SumOp : public framework::OperatorWithKernel {
83
83
int dtype = -1 ;
84
84
for (auto & x_var : x_vars) {
85
85
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
86
- auto tensor = framework::GetTensorFromVar (
87
- const_cast <framework::Variable*>(x_var));
86
+ auto tensor = framework::GetTensorFromVar (x_var);
88
87
if (tensor->numel () == 0 ) {
89
88
continue ;
90
89
}
You can’t perform that action at this time.
0 commit comments