File tree Expand file tree Collapse file tree 3 files changed +5
-4
lines changed Expand file tree Collapse file tree 3 files changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable* var) {
358
358
return var->IsType <LoDTensor>() || var->IsType <SelectedRows>();
359
359
}
360
360
361
- static const Tensor* GetTensorFromVar (Variable* var) {
361
+ const Tensor* GetTensorFromVar (Variable* var) {
362
362
if (var->IsType <LoDTensor>()) {
363
363
return var->GetMutable <LoDTensor>();
364
364
} else if (var->IsType <SelectedRows>()) {
Original file line number Diff line number Diff line change @@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
63
63
}
64
64
65
65
proto::VarType::Type GetDataTypeOfVar (const Variable* var);
66
+ const Tensor* GetTensorFromVar (Variable* var);
66
67
67
68
class OperatorBase ;
68
69
class ExecutionContext ;
Original file line number Diff line number Diff line change @@ -81,10 +81,10 @@ class SumOp : public framework::OperatorWithKernel {
81
81
82
82
if (x_vars[0 ]->IsType <framework::LoDTensor>()) {
83
83
int dtype = -1 ;
84
- auto x_var_names = ctx.Inputs (" X" );
85
- for (auto & x_var_n : x_var_names) {
84
+ for (auto & x_var : x_vars) {
86
85
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
87
- auto tensor = ctx.Input <Tensor>(x_var_n);
86
+ auto tensor = framework::GetTensorFromVar (
87
+ const_cast <framework::Variable*>(x_var));
88
88
if (tensor->numel () == 0 ) {
89
89
continue ;
90
90
}
You can’t perform that action at this time.
0 commit comments