Skip to content

Commit aa75f1e

Browse files
author
Yancey
authored
Create tensor in recv op (#7286)
* create tensor in recv op * static global function to global function
1 parent 2d10c75 commit aa75f1e

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

paddle/framework/executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch";
3535

3636
Executor::Executor(const platform::Place& place) : place_(place) {}
3737

38-
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
38+
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
3939
if (var_type == proto::VarDesc::LOD_TENSOR) {
4040
var->GetMutable<LoDTensor>();
4141
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {

paddle/framework/executor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,7 @@ class Executor {
4545
const platform::Place place_;
4646
};
4747

48+
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type);
49+
4850
} // namespace framework
4951
} // namespace paddle

paddle/operators/recv_op.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ limitations under the License. */
1919

2020
#include <unistd.h>
2121

22-
#include "paddle/framework/data_type.h"
2322
#include "paddle/framework/executor.h"
2423
#include "paddle/framework/framework.pb.h"
2524
#include "paddle/framework/lod_tensor.h"
@@ -111,9 +110,11 @@ class RecvOp : public framework::OperatorBase {
111110
<< " updating param: " << param_var_name;
112111
auto *merged_grad = recv_scope.FindVar(grad_var_name);
113112
if (merged_grad == nullptr) {
114-
// create output of merged var.
115-
auto merged_var = recv_scope.Var(grad_var_name);
116-
merged_var->GetMutable<framework::LoDTensor>();
113+
auto *ptr = recv_scope.Var(grad_var_name);
114+
framework::CreateTensor(ptr,
115+
framework::ToVarType(merged_grad->Type()));
116+
VLOG(3) << "Create Variable " << grad_var_name
117+
<< " on recv scope, which pointer is " << ptr;
117118
}
118119

119120
if (trainer_count > 1) {

0 commit comments

Comments
 (0)