Skip to content

Commit 23ba766

Browse files
authored
Merge pull request #13475 from panyx0718/ir5
avoid creating dangling ir::Node.
2 parents dffc457 + 0bd7a67 commit 23ba766

9 files changed

+85
-56
lines changed

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ struct TestBroadcastOpHandle {
9696
}
9797
param_scopes_[input_scope_idx]->Var("input");
9898

99-
std::unique_ptr<ir::Node> n(
100-
new ir::Node("node0", ir::Node::Type::kOperation));
99+
std::unique_ptr<ir::Node> n =
100+
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation);
101101
if (use_gpu_) {
102102
#ifdef PADDLE_WITH_CUDA
103103
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
@@ -115,17 +115,17 @@ struct TestBroadcastOpHandle {
115115
#endif
116116
}
117117

118-
std::unique_ptr<ir::Node> v(
119-
new ir::Node("node1", ir::Node::Type::kVariable));
118+
std::unique_ptr<ir::Node> v =
119+
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable);
120120
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
121121
gpu_list_[input_scope_idx]);
122122
vars_.emplace_back(in_var_handle);
123123
op_handle_->AddInput(in_var_handle);
124124

125125
// add dummy var
126126

127-
std::unique_ptr<ir::Node> v2(
128-
new ir::Node("node2", ir::Node::Type::kVariable));
127+
std::unique_ptr<ir::Node> v2 =
128+
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable);
129129
vars_.emplace_back(new DummyVarHandle(v2.get()));
130130
DummyVarHandle* dummy_var_handle =
131131
static_cast<DummyVarHandle*>(vars_.back().get());
@@ -136,17 +136,17 @@ struct TestBroadcastOpHandle {
136136
if (!use_gpu_) {
137137
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
138138
}
139-
std::unique_ptr<ir::Node> v3(
140-
new ir::Node("node3", ir::Node::Type::kVariable));
139+
std::unique_ptr<ir::Node> v3 =
140+
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable);
141141
VarHandle* out_var_handle =
142142
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
143143
vars_.emplace_back(out_var_handle);
144144
op_handle_->AddOutput(out_var_handle);
145145
}
146146

147147
// add dummy var
148-
std::unique_ptr<ir::Node> v4(
149-
new ir::Node("node4", ir::Node::Type::kVariable));
148+
std::unique_ptr<ir::Node> v4 =
149+
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable);
150150
vars_.emplace_back(new DummyVarHandle(v4.get()));
151151
DummyVarHandle* out_dummy_var_handle =
152152
static_cast<DummyVarHandle*>(vars_.back().get());

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
5454
paddle::framework::FeedFetchList fetches;
5555
fetches.resize(fetch_tensors.size());
5656
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
57-
std::vector<std::unique_ptr<ir::Node>> fetch_nodes;
5857
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
5958

6059
for (auto &fetch_var_name : fetch_tensors) {
@@ -75,9 +74,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
7574

7675
auto &vars = fetched_var_it->second;
7776

78-
fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
79-
auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i,
80-
&local_scopes_);
77+
ir::Node *fetch_node =
78+
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
79+
auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
8180
fetch_ops.emplace_back(op);
8281

8382
for (auto &p : places_) {
@@ -116,9 +115,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
116115
num_complete += num_comp;
117116
}
118117
// Wait FetchOps.
119-
if (!fetch_ops.empty()) {
120-
fetch_ops.clear();
121-
}
118+
ClearFetchOp(graph_.get(), &fetch_ops);
122119
return fetches;
123120
}
124121
void FastThreadedSSAGraphExecutor::RunOpAsync(

paddle/fluid/framework/details/gather_op_handle_test.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,36 +82,41 @@ struct TestGatherOpHandle {
8282
}
8383
param_scopes_[input_scope_idx]->Var("out");
8484

85-
nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation));
85+
nodes.emplace_back(
86+
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
8687
op_handle_.reset(
8788
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
8889
// add input
8990
for (size_t j = 0; j < gpu_list_.size(); ++j) {
9091
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
91-
nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable));
92+
nodes.emplace_back(
93+
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
9294
auto* in_var_handle =
9395
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
9496
vars_.emplace_back(in_var_handle);
9597
op_handle_->AddInput(in_var_handle);
9698
}
9799

98100
// add dummy var
99-
nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable));
101+
nodes.emplace_back(
102+
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
100103
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
101104
DummyVarHandle* in_dummy_var_handle =
102105
static_cast<DummyVarHandle*>(vars_.back().get());
103106
in_dummy_var_handle->ClearGeneratedOp();
104107
op_handle_->AddInput(in_dummy_var_handle);
105108

106109
// add output
107-
nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable));
110+
nodes.emplace_back(
111+
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
108112
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
109113
"out", gpu_list_[input_scope_idx]);
110114
vars_.emplace_back(out_var_handle);
111115
op_handle_->AddOutput(out_var_handle);
112116

113117
// add dummy var
114-
nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable));
118+
nodes.emplace_back(
119+
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
115120
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
116121
DummyVarHandle* dummy_var_handle =
117122
static_cast<DummyVarHandle*>(vars_.back().get());

paddle/fluid/framework/details/ssa_graph_executor.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@ namespace framework {
1919
namespace details {
2020
SSAGraphExecutor::~SSAGraphExecutor() {}
2121

22+
void ClearFetchOp(ir::Graph* graph,
23+
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
24+
if (fetch_ops->empty()) return;
25+
26+
for (auto& op : *fetch_ops) {
27+
for (auto& out_var : op->Node()->outputs) {
28+
graph->RemoveNode(out_var);
29+
}
30+
graph->RemoveNode(op->Node());
31+
}
32+
fetch_ops->clear();
33+
}
34+
2235
} // namespace details
2336
} // namespace framework
2437
} // namespace paddle

paddle/fluid/framework/details/ssa_graph_executor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <string>
1919
#include <vector>
2020

21+
#include "paddle/fluid/framework/details/fetch_op_handle.h"
2122
#include "paddle/fluid/framework/feed_fetch_type.h"
2223
#include "paddle/fluid/framework/ir/graph.h"
2324

@@ -36,6 +37,9 @@ class SSAGraphExecutor {
3637

3738
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
3839
};
40+
41+
void ClearFetchOp(ir::Graph* graph,
42+
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
3943
} // namespace details
4044
} // namespace framework
4145
} // namespace paddle

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
6969

7070
// Step 2. Insert FetchOps
7171
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
72-
std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
7372
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
7473
FeedFetchList fetch_data(fetch_tensors.size());
7574

76-
InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
77-
&pending_ops, &pending_vars, &ready_vars, &fetch_data);
75+
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
76+
&pending_vars, &ready_vars, &fetch_data);
7877

7978
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
8079
for (auto *op : set) {
@@ -136,17 +135,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
136135
PADDLE_ENFORCE(ready_ops.empty());
137136

138137
// Wait FetchOps.
139-
if (!fetch_ops.empty()) {
140-
fetch_ops.clear();
141-
}
138+
ClearFetchOp(graph_.get(), &fetch_ops);
142139

143140
return fetch_data;
144141
}
145142

146143
void ThreadedSSAGraphExecutor::InsertFetchOps(
147144
const std::vector<std::string> &fetch_tensors,
148145
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
149-
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
150146
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
151147
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
152148
std::unordered_set<VarHandleBase *> *pending_vars,
@@ -171,9 +167,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
171167

172168
auto &vars = fetched_var_it->second;
173169

174-
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
175-
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
176-
&local_scopes_);
170+
ir::Node *fetch_node =
171+
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
172+
auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_);
177173
fetch_ops->emplace_back(op);
178174

179175
for (auto &p : places_) {
@@ -184,8 +180,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
184180
op->AddInput(var);
185181
}
186182

187-
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
188-
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
183+
ir::Node *fetch_var =
184+
graph_->CreateEmptyNode("fetch", ir::Node::Type::kVariable);
185+
auto *fetch_dummy = new DummyVarHandle(fetch_var);
189186
op->AddOutput(fetch_dummy);
190187
fetch_dependencies->emplace(fetch_dummy);
191188
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
7373
void InsertFetchOps(
7474
const std::vector<std::string> &fetch_tensors,
7575
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
76-
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
7776
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
7877
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
7978
std::unordered_set<VarHandleBase *> *pending_vars,

paddle/fluid/framework/ir/node.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ namespace framework {
1919
namespace ir {
2020
constexpr char Node::kControlDepVarName[];
2121
int Node::count_ = 0;
22+
23+
std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
24+
Node::Type type) {
25+
return std::unique_ptr<Node>(new Node(name, type));
26+
}
2227
} // namespace ir
2328
} // namespace framework
2429
} // namespace paddle

paddle/fluid/framework/ir/node.h

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,12 @@ namespace paddle {
2424
namespace framework {
2525
namespace ir {
2626

27+
// Node should normally be created by Graph::CreateXXXNode().
2728
class Node {
2829
public:
2930
enum class Type { kOperation, kVariable };
3031
static constexpr char kControlDepVarName[] = "__control_var";
3132

32-
explicit Node(const std::string& name, Type type)
33-
: name_(name),
34-
var_desc_(nullptr),
35-
op_desc_(nullptr),
36-
type_(type),
37-
id_(count_++) {}
38-
39-
explicit Node(VarDesc* var_desc)
40-
: name_(var_desc->Name()),
41-
var_desc_(new VarDesc(*var_desc)),
42-
op_desc_(nullptr),
43-
type_(Type::kVariable),
44-
id_(count_++) {}
45-
46-
explicit Node(OpDesc* op_desc)
47-
: name_(op_desc->Type()),
48-
var_desc_(nullptr),
49-
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
50-
type_(Type::kOperation),
51-
id_(count_++) {}
52-
5333
Type NodeType() const { return type_; }
5434

5535
std::string Name() const { return name_; }
@@ -81,11 +61,40 @@ class Node {
8161

8262
private:
8363
friend class Graph;
64+
friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
65+
Node::Type type);
66+
67+
explicit Node(const std::string& name, Type type)
68+
: name_(name),
69+
var_desc_(nullptr),
70+
op_desc_(nullptr),
71+
type_(type),
72+
id_(count_++) {}
73+
74+
explicit Node(VarDesc* var_desc)
75+
: name_(var_desc->Name()),
76+
var_desc_(new VarDesc(*var_desc)),
77+
op_desc_(nullptr),
78+
type_(Type::kVariable),
79+
id_(count_++) {}
80+
81+
explicit Node(OpDesc* op_desc)
82+
: name_(op_desc->Type()),
83+
var_desc_(nullptr),
84+
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
85+
type_(Type::kOperation),
86+
id_(count_++) {}
87+
88+
Node() = delete;
89+
8490
static int count_;
8591
static void ResetId() { count_ = 0; }
8692
DISABLE_COPY_AND_ASSIGN(Node);
8793
};
8894

95+
std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
96+
Node::Type type);
97+
8998
} // namespace ir
9099
} // namespace framework
91100
} // namespace paddle

0 commit comments

Comments
 (0)