@@ -354,18 +354,18 @@ void OperatorBase::GenerateTemporaryNames() {
354
354
}
355
355
}
356
356
357
- static bool VarIsTensor (const Variable* var) {
358
- return var-> IsType <LoDTensor>() || var-> IsType <SelectedRows>();
357
+ static bool VarIsTensor (const Variable& var) {
358
+ return var. IsType <LoDTensor>() || var. IsType <SelectedRows>();
359
359
}
360
360
361
- const Tensor* GetTensorFromVar (Variable* var) {
362
- if (var-> IsType <LoDTensor>()) {
363
- return var-> GetMutable <LoDTensor>();
364
- } else if (var-> IsType <SelectedRows>()) {
365
- return var-> GetMutable <SelectedRows>()-> mutable_value ( );
361
+ const Tensor* GetTensorFromVar (const Variable& var) {
362
+ if (var. IsType <LoDTensor>()) {
363
+ return static_cast < const Tensor*>(&(var. Get <LoDTensor>()) );
364
+ } else if (var. IsType <SelectedRows>()) {
365
+ return &( var. Get <SelectedRows>(). value () );
366
366
} else {
367
367
PADDLE_THROW (" Variable type_id %s, expect LoDTensor/SelectedRows." ,
368
- var-> Type ().name ());
368
+ var. Type ().name ());
369
369
}
370
370
}
371
371
@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
415
415
template <>
416
416
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
417
417
auto * var = InputVar (name);
418
- return var == nullptr ? nullptr
419
- : GetTensorFromVar (const_cast <Variable*>(var));
418
+ return var == nullptr ? nullptr : GetTensorFromVar (*var);
420
419
}
421
420
422
421
template <>
@@ -428,7 +427,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
428
427
std::transform (names.begin (), names.end (), std::back_inserter (res),
429
428
[&](const std::string& sub_name) {
430
429
auto var = scope_.FindVar (sub_name);
431
- return var == nullptr ? nullptr : GetTensorFromVar (var);
430
+ return var == nullptr ? nullptr : GetTensorFromVar (* var);
432
431
});
433
432
return res;
434
433
}
@@ -770,8 +769,10 @@ void OperatorWithKernel::TransferInplaceVarsBack(
770
769
for (auto & var_name : inplace_vars) {
771
770
VLOG (3 ) << " share inplace var " + var_name + " back to it's original scope" ;
772
771
auto * original_tensor = GetMutableTensorFromVar (scope.FindVar (var_name));
773
- auto * transformed_tensor =
774
- GetTensorFromVar (transfer_scope.FindVar (var_name));
772
+ auto * var = transfer_scope.FindVar (var_name);
773
+ PADDLE_ENFORCE (var != nullptr , " The var[%s] should not be nullptr" ,
774
+ var_name);
775
+ auto * transformed_tensor = GetTensorFromVar (*var);
775
776
original_tensor->ShareDataWith (*transformed_tensor);
776
777
}
777
778
}
@@ -784,11 +785,11 @@ Scope* OperatorWithKernel::TryTransferData(
784
785
for (auto & var_name : var_name_item.second ) {
785
786
auto * var = scope.FindVar (var_name);
786
787
// Only tensor can be tranfer to another device.
787
- if (var == nullptr || !VarIsTensor (var)) {
788
+ if (var == nullptr || !VarIsTensor (* var)) {
788
789
continue ;
789
790
}
790
791
791
- auto * tensor_in = GetTensorFromVar (var);
792
+ auto * tensor_in = GetTensorFromVar (* var);
792
793
if (!tensor_in->IsInitialized ()) {
793
794
continue ;
794
795
}
0 commit comments