@@ -620,8 +620,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
620
620
" There are no kernels which are registered in the %s operator." , type_);
621
621
}
622
622
623
- ExecutionContext ctx (*this , scope, *dev_ctx);
624
-
625
623
OpKernelMap& kernels = kernels_iter->second ;
626
624
627
625
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
@@ -631,7 +629,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
631
629
// Do selection
632
630
// }
633
631
634
- auto expected_kernel_key = this ->GetExpectedKernelType (ctx);
632
+ auto expected_kernel_key =
633
+ this ->GetExpectedKernelType (ExecutionContext (*this , scope, *dev_ctx));
635
634
VLOG (3 ) << " expected_kernel_key:" << expected_kernel_key;
636
635
637
636
auto kernel_iter = kernels.find (expected_kernel_key);
@@ -640,63 +639,99 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
640
639
KernelTypeToString (expected_kernel_key));
641
640
}
642
641
643
- // do data transform
644
- Scope& new_scope = scope.NewScope ();
642
+ // do data transformScope &transfer_scope;
643
+ std::vector<std::string> transfered_inplace_vars;
644
+ auto * transfer_scope =
645
+ TryTransferData (scope, expected_kernel_key, &transfered_inplace_vars);
645
646
646
- std::vector<std::string> inplace_vars;
647
- for (auto & var_name_item : this ->Inputs ()) {
648
- for (auto & var_name : var_name_item.second ) {
649
- auto * var = scope.FindVar (var_name);
650
- if (var && VarIsTensor (var)) {
651
- auto * tensor_in = GetTensorFromVar (var);
652
- if (tensor_in->IsInitialized ()) {
653
- auto kernel_type_for_var = this ->GetKernelTypeForVar (
654
- var_name_item.first , *tensor_in, expected_kernel_key);
655
- if (TransFromNeeded (kernel_type_for_var, expected_kernel_key)) {
656
- auto out_var_names = OutputVars (true );
657
- if (std::find (out_var_names.begin (), out_var_names.end (),
658
- var_name) != out_var_names.end ()) {
659
- inplace_vars.push_back (var_name);
660
- }
661
- VLOG (3 ) << " Transform Variable " << var_name << " from "
662
- << kernel_type_for_var << " to " << expected_kernel_key;
663
- auto * trans_var = new_scope.Var (var_name);
664
- std::shared_ptr<Tensor> out (new Tensor);
665
- DataTransform (expected_kernel_key, kernel_type_for_var, *tensor_in,
666
- out.get ());
667
- CopyVariableWithTensor (*var, *(out.get ()), trans_var);
668
- }
669
- }
670
- }
671
- }
647
+ // exec scope is the scope that kernel actually executed on.
648
+ const Scope& exec_scope =
649
+ (transfer_scope == nullptr ? scope : *transfer_scope);
650
+
651
+ if (!(expected_kernel_key.place_ == dev_ctx->GetPlace ())) {
652
+ dev_ctx = pool.Get (expected_kernel_key.place_ );
672
653
}
673
654
674
- auto * new_dev_ctx = pool.Get (expected_kernel_key.place_ );
675
- kernel_iter->second ->Compute (
676
- ExecutionContext (*this , new_scope, *new_dev_ctx));
655
+ kernel_iter->second ->Compute (ExecutionContext (*this , exec_scope, *dev_ctx));
677
656
678
- for (auto & var_name : inplace_vars) {
679
- VLOG (3 ) << " share inplace var " + var_name + " back to it's original scope" ;
680
- auto * original_tensor = GetMutableTensorFromVar (scope.FindVar (var_name));
681
- auto * transformed_tensor = GetTensorFromVar (new_scope.FindVar (var_name));
682
- original_tensor->ShareDataWith (*transformed_tensor);
657
+ if (!transfered_inplace_vars.empty ()) {
658
+ // there is inplace variable has been transfered.
659
+ TransferInplaceVarsBack (scope, transfered_inplace_vars, *transfer_scope);
683
660
}
684
661
685
662
/* For profiling/benchmark only*/
686
663
if (FLAGS_benchmark) {
687
- new_dev_ctx ->Wait ();
664
+ dev_ctx ->Wait ();
688
665
}
689
666
690
667
if (FLAGS_check_nan_inf) {
691
668
for (auto & vname : OutputVars (true )) {
692
- auto * var = new_scope .FindVar (vname);
669
+ auto * var = exec_scope .FindVar (vname);
693
670
if (var == nullptr ) continue ;
694
671
if (var->IsType <framework::LoDTensor>()) {
695
672
CheckTensorNANOrInf (vname, var->Get <framework::LoDTensor>());
696
673
}
697
674
}
698
675
}
699
676
}
677
+ void OperatorWithKernel::TransferInplaceVarsBack (
678
+ const Scope& scope, const std::vector<std::string>& inplace_vars,
679
+ const Scope& transfer_scope) const {
680
+ for (auto & var_name : inplace_vars) {
681
+ VLOG (3 ) << " share inplace var " + var_name + " back to it's original scope" ;
682
+ auto * original_tensor = GetMutableTensorFromVar (scope.FindVar (var_name));
683
+ auto * transformed_tensor =
684
+ GetTensorFromVar (transfer_scope.FindVar (var_name));
685
+ original_tensor->ShareDataWith (*transformed_tensor);
686
+ }
687
+ }
688
+
689
+ Scope* OperatorWithKernel::TryTransferData (
690
+ const Scope& scope, const OpKernelType& expected_kernel_key,
691
+ std::vector<std::string>* transfered_inplace_vars) const {
692
+ Scope* new_scope = nullptr ;
693
+ for (auto & var_name_item : Inputs ()) {
694
+ for (auto & var_name : var_name_item.second ) {
695
+ auto * var = scope.FindVar (var_name);
696
+ // Only tensor can be tranfer to another device.
697
+ if (var == nullptr || !VarIsTensor (var)) {
698
+ continue ;
699
+ }
700
+
701
+ auto * tensor_in = GetTensorFromVar (var);
702
+ if (!tensor_in->IsInitialized ()) {
703
+ continue ;
704
+ }
705
+
706
+ auto kernel_type_for_var = GetKernelTypeForVar (
707
+ var_name_item.first , *tensor_in, expected_kernel_key);
708
+
709
+ if (!NeedTransform (kernel_type_for_var, expected_kernel_key)) {
710
+ continue ;
711
+ }
712
+
713
+ auto out_var_names = OutputVars (true );
714
+ if (std::find (out_var_names.begin (), out_var_names.end (), var_name) !=
715
+ out_var_names.end ()) {
716
+ transfered_inplace_vars->emplace_back (var_name);
717
+ }
718
+
719
+ VLOG (3 ) << " Transform Variable " << var_name << " from "
720
+ << kernel_type_for_var << " to " << expected_kernel_key;
721
+
722
+ if (new_scope == nullptr ) {
723
+ new_scope = &scope.NewScope ();
724
+ }
725
+
726
+ auto * trans_var = new_scope->Var (var_name);
727
+ Tensor out;
728
+ TransformData (expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
729
+ SetTensorToVariable (*var, out, trans_var);
730
+ }
731
+ }
732
+
733
+ return new_scope;
734
+ }
700
735
701
736
proto::VarType::Type OperatorWithKernel::IndicateDataType (
702
737
const ExecutionContext& ctx) const {
0 commit comments