Skip to content

Commit 2fa8df1

Browse files
committed
separate graph building pass and graph-based pe builder
1 parent 37e5144 commit 2fa8df1

15 files changed

+273
-182
lines changed

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ struct TestBroadcastOpHandle {
9696
}
9797
param_scopes_[input_scope_idx]->Var("input");
9898

99-
std::unique_ptr<ir::Node> n(new ir::Node(ir::Node::Type::kOperation));
99+
std::unique_ptr<ir::Node> n(new ir::Node());
100100
if (use_gpu_) {
101101
#ifdef PADDLE_WITH_CUDA
102102
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
@@ -114,15 +114,15 @@ struct TestBroadcastOpHandle {
114114
#endif
115115
}
116116

117-
std::unique_ptr<ir::Node> v(new ir::Node(ir::Node::Type::kVariable));
117+
std::unique_ptr<ir::Node> v(new ir::Node());
118118
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
119119
gpu_list_[input_scope_idx]);
120120
vars_.emplace_back(in_var_handle);
121121
op_handle_->AddInput(in_var_handle);
122122

123123
// add dummy var
124124

125-
std::unique_ptr<ir::Node> v2(new ir::Node(ir::Node::Type::kVariable));
125+
std::unique_ptr<ir::Node> v2(new ir::Node());
126126
vars_.emplace_back(new DummyVarHandle(v2.get()));
127127
DummyVarHandle* dummy_var_handle =
128128
static_cast<DummyVarHandle*>(vars_.back().get());
@@ -133,15 +133,15 @@ struct TestBroadcastOpHandle {
133133
if (!use_gpu_) {
134134
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
135135
}
136-
std::unique_ptr<ir::Node> v3(new ir::Node(ir::Node::Type::kVariable));
136+
std::unique_ptr<ir::Node> v3(new ir::Node());
137137
VarHandle* out_var_handle =
138138
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
139139
vars_.emplace_back(out_var_handle);
140140
op_handle_->AddOutput(out_var_handle);
141141
}
142142

143143
// add dummy var
144-
std::unique_ptr<ir::Node> v4(new ir::Node(ir::Node::Type::kVariable));
144+
std::unique_ptr<ir::Node> v4(new ir::Node());
145145
vars_.emplace_back(new DummyVarHandle(v4.get()));
146146
DummyVarHandle* out_dummy_var_handle =
147147
static_cast<DummyVarHandle*>(vars_.back().get());

paddle/fluid/framework/details/gather_op_handle_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,36 +82,36 @@ struct TestGatherOpHandle {
8282
}
8383
param_scopes_[input_scope_idx]->Var("out");
8484

85-
nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
85+
nodes.emplace_back(new ir::Node());
8686
op_handle_.reset(
8787
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
8888
// add input
8989
for (size_t j = 0; j < gpu_list_.size(); ++j) {
9090
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
91-
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
91+
nodes.emplace_back(new ir::Node());
9292
auto* in_var_handle =
9393
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
9494
vars_.emplace_back(in_var_handle);
9595
op_handle_->AddInput(in_var_handle);
9696
}
9797

9898
// add dummy var
99-
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
99+
nodes.emplace_back(new ir::Node());
100100
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
101101
DummyVarHandle* in_dummy_var_handle =
102102
static_cast<DummyVarHandle*>(vars_.back().get());
103103
in_dummy_var_handle->ClearGeneratedOp();
104104
op_handle_->AddInput(in_dummy_var_handle);
105105

106106
// add output
107-
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
107+
nodes.emplace_back(new ir::Node());
108108
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
109109
"out", gpu_list_[input_scope_idx]);
110110
vars_.emplace_back(out_var_handle);
111111
op_handle_->AddOutput(out_var_handle);
112112

113113
// add dummy var
114-
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
114+
nodes.emplace_back(new ir::Node());
115115
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
116116
DummyVarHandle* dummy_var_handle =
117117
static_cast<DummyVarHandle*>(vars_.back().get());

0 commit comments

Comments
 (0)