File tree Expand file tree Collapse file tree 3 files changed +8
-5
lines changed Expand file tree Collapse file tree 3 files changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch";
35
35
36
36
Executor::Executor (const platform::Place& place) : place_(place) {}
37
37
38
- static void CreateTensor (Variable* var, proto::VarDesc::VarType var_type) {
38
+ void CreateTensor (Variable* var, proto::VarDesc::VarType var_type) {
39
39
if (var_type == proto::VarDesc::LOD_TENSOR) {
40
40
var->GetMutable <LoDTensor>();
41
41
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {
Original file line number Diff line number Diff line change @@ -45,5 +45,7 @@ class Executor {
45
45
const platform::Place place_;
46
46
};
47
47
48
+ void CreateTensor (Variable* var, proto::VarDesc::VarType var_type);
49
+
48
50
} // namespace framework
49
51
} // namespace paddle
Original file line number Diff line number Diff line change @@ -19,7 +19,6 @@ limitations under the License. */
19
19
20
20
#include < unistd.h>
21
21
22
- #include " paddle/framework/data_type.h"
23
22
#include " paddle/framework/executor.h"
24
23
#include " paddle/framework/framework.pb.h"
25
24
#include " paddle/framework/lod_tensor.h"
@@ -111,9 +110,11 @@ class RecvOp : public framework::OperatorBase {
111
110
<< " updating param: " << param_var_name;
112
111
auto *merged_grad = recv_scope.FindVar (grad_var_name);
113
112
if (merged_grad == nullptr ) {
114
- // create output of merged var.
115
- auto merged_var = recv_scope.Var (grad_var_name);
116
- merged_var->GetMutable <framework::LoDTensor>();
113
+ auto *ptr = recv_scope.Var (grad_var_name);
114
+ framework::CreateTensor (ptr,
115
+ framework::ToVarType (merged_grad->Type ()));
116
+ VLOG (3 ) << " Create Variable " << grad_var_name
117
+ << " on recv scope, which pointer is " << ptr;
117
118
}
118
119
119
120
if (trainer_count > 1 ) {
You can’t perform that action at this time.
0 commit comments