Skip to content

Commit 2ccf77d

Browse files
author
chengduo
authored
Refine GetTensorFromVar (#14160)
* fix GetTensorFromVar test=release/1.1 * refine GetTensorFromVar test=develop
1 parent d2e622f commit 2ccf77d

File tree

3 files changed

+22
-19
lines changed

3 files changed

+22
-19
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -354,18 +354,18 @@ void OperatorBase::GenerateTemporaryNames() {
354354
}
355355
}
356356

357-
static bool VarIsTensor(const Variable* var) {
358-
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
357+
static bool VarIsTensor(const Variable& var) {
358+
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
359359
}
360360

361-
const Tensor* GetTensorFromVar(Variable* var) {
362-
if (var->IsType<LoDTensor>()) {
363-
return var->GetMutable<LoDTensor>();
364-
} else if (var->IsType<SelectedRows>()) {
365-
return var->GetMutable<SelectedRows>()->mutable_value();
361+
const Tensor* GetTensorFromVar(const Variable& var) {
362+
if (var.IsType<LoDTensor>()) {
363+
return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
364+
} else if (var.IsType<SelectedRows>()) {
365+
return &(var.Get<SelectedRows>().value());
366366
} else {
367367
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
368-
var->Type().name());
368+
var.Type().name());
369369
}
370370
}
371371

@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
415415
template <>
416416
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
417417
auto* var = InputVar(name);
418-
return var == nullptr ? nullptr
419-
: GetTensorFromVar(const_cast<Variable*>(var));
418+
return var == nullptr ? nullptr : GetTensorFromVar(*var);
420419
}
421420

422421
template <>
@@ -428,7 +427,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
428427
std::transform(names.begin(), names.end(), std::back_inserter(res),
429428
[&](const std::string& sub_name) {
430429
auto var = scope_.FindVar(sub_name);
431-
return var == nullptr ? nullptr : GetTensorFromVar(var);
430+
return var == nullptr ? nullptr : GetTensorFromVar(*var);
432431
});
433432
return res;
434433
}
@@ -770,8 +769,10 @@ void OperatorWithKernel::TransferInplaceVarsBack(
770769
for (auto& var_name : inplace_vars) {
771770
VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
772771
auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
773-
auto* transformed_tensor =
774-
GetTensorFromVar(transfer_scope.FindVar(var_name));
772+
auto* var = transfer_scope.FindVar(var_name);
773+
PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
774+
var_name);
775+
auto* transformed_tensor = GetTensorFromVar(*var);
775776
original_tensor->ShareDataWith(*transformed_tensor);
776777
}
777778
}
@@ -784,11 +785,11 @@ Scope* OperatorWithKernel::TryTransferData(
784785
for (auto& var_name : var_name_item.second) {
785786
auto* var = scope.FindVar(var_name);
786787
// Only tensor can be tranfer to another device.
787-
if (var == nullptr || !VarIsTensor(var)) {
788+
if (var == nullptr || !VarIsTensor(*var)) {
788789
continue;
789790
}
790791

791-
auto* tensor_in = GetTensorFromVar(var);
792+
auto* tensor_in = GetTensorFromVar(*var);
792793
if (!tensor_in->IsInitialized()) {
793794
continue;
794795
}

paddle/fluid/framework/operator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
6363
}
6464

6565
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
66-
const Tensor* GetTensorFromVar(Variable* var);
66+
const Tensor* GetTensorFromVar(const Variable& var);
6767

6868
class OperatorBase;
6969
class ExecutionContext;

paddle/fluid/operators/sum_op.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class SumOp : public framework::OperatorWithKernel {
6767
framework::OpKernelType GetExpectedKernelType(
6868
const framework::ExecutionContext& ctx) const override {
6969
auto x_vars = ctx.MultiInputVar("X");
70+
auto x_vars_name = ctx.Inputs("X");
7071

7172
framework::LibraryType library{framework::LibraryType::kPlain};
7273
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
@@ -81,10 +82,11 @@ class SumOp : public framework::OperatorWithKernel {
8182

8283
if (x_vars[0]->IsType<framework::LoDTensor>()) {
8384
int dtype = -1;
84-
for (auto& x_var : x_vars) {
85+
for (size_t idx = 0; idx < x_vars.size(); ++idx) {
86+
PADDLE_ENFORCE(x_vars[idx] != nullptr,
87+
"Input var[%s] should not be nullptr", x_vars_name[idx]);
8588
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
86-
auto tensor = framework::GetTensorFromVar(
87-
const_cast<framework::Variable*>(x_var));
89+
auto tensor = framework::GetTensorFromVar(*x_vars[idx]);
8890
if (tensor->numel() == 0) {
8991
continue;
9092
}

0 commit comments

Comments
 (0)