12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
+ #include < queue>
15
16
#include < string>
16
17
#include < vector>
17
18
@@ -23,6 +24,25 @@ namespace paddle {
23
24
namespace framework {
24
25
namespace details {
25
26
27
+ static ComputationOpHandle *FindNextComputationOpHandle (VarHandle *var_in) {
28
+ std::queue<VarHandleBase *> queue;
29
+ queue.push (var_in);
30
+ do {
31
+ auto *var = queue.front ();
32
+ queue.pop ();
33
+ for (auto *op : var->PendingOps ()) {
34
+ auto *compute_op = dynamic_cast <ComputationOpHandle *>(op);
35
+ if (compute_op != nullptr && compute_op->GetPlace () == var_in->place_ ) {
36
+ return compute_op;
37
+ }
38
+ for (auto *out_var : op->Outputs ()) {
39
+ queue.push (out_var);
40
+ }
41
+ }
42
+ } while (!queue.empty ());
43
+ return nullptr ;
44
+ }
45
+
26
46
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl (
27
47
std::unique_ptr<ir::Graph> graph) const {
28
48
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount );
@@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
34
54
// Step 2: Find all variables in non-computation ops which refers to variables
35
55
// in computation ops
36
56
std::unordered_set<std::string> names;
57
+ std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>>
58
+ compute_ref_cnt_map;
59
+
37
60
auto get_ref_cnts_from_compute_op = [&](
38
61
const std::unique_ptr<OpHandleBase> &op,
39
62
const std::vector<VarHandleBase *> &vars) {
@@ -54,15 +77,18 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
54
77
VarDesc *var_desc = var_handle->Node ()->Var ();
55
78
auto var_name = var_handle->Node ()->Name ();
56
79
57
- // This is wierd but there is really some variables without var_desc
80
+ // This is weird but there is really some variables without var_desc
58
81
// in computation_op
59
82
if (var_desc == nullptr ) {
60
83
if (compute_op->Node ()->Op ()->Block ()->FindVar (var_name) == nullptr )
61
84
continue ;
62
85
} else {
63
- if (var_desc->Persistable () ||
64
- var_desc->Proto ()->type ().type () != proto::VarType::LOD_TENSOR)
86
+ if (var_desc->Persistable ()) continue ;
87
+ auto var_type = var_desc->Proto ()->type ().type ();
88
+ if (var_type != proto::VarType::LOD_TENSOR &&
89
+ var_type != proto::VarType::SELECTED_ROWS) {
65
90
continue ;
91
+ }
66
92
}
67
93
68
94
// compute op only runs in one device
@@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
93
119
if (ref_cnts.count (place.device ) &&
94
120
ref_cnts[place.device ]->count (var_name)) {
95
121
++(*ref_cnts[place.device ])[var_name];
122
+
123
+ auto *next_compute_op = FindNextComputationOpHandle (var_handle);
124
+ if (next_compute_op != nullptr ) {
125
+ if (compute_ref_cnt_map.count (next_compute_op)) {
126
+ compute_ref_cnt_map[next_compute_op]->AddVar (var_name);
127
+ VLOG (5 ) << " Add reference count of " << var_name << " to Operator "
128
+ << next_compute_op->Name ();
129
+ } else {
130
+ // Create new reference_count_op_handle
131
+ ir::Node *ref_cnt_node = graph->CreateEmptyNode (
132
+ " reference_count" , ir::Node::Type::kOperation );
133
+ auto *ref_cnt_handle = new ReferenceCountOpHandle (
134
+ ref_cnt_node, next_compute_op->GetScope (), place, {var_name},
135
+ gcs[place.device ].get (), cur_ref_cnts[place.device ].get ());
136
+ if (next_compute_op->Outputs ().empty ()) {
137
+ auto *dep_var = new DummyVarHandle (graph->CreateControlDepVar ());
138
+ next_compute_op->AddOutput (dep_var);
139
+ graph->Get <GraphDepVars>(kGraphDepVars ).emplace (dep_var);
140
+ }
141
+ ref_cnt_handle->AddInput (next_compute_op->Outputs ().front ());
142
+ compute_ref_cnt_map[next_compute_op].reset (ref_cnt_handle);
143
+ }
144
+ }
96
145
}
97
146
}
98
147
};
99
148
100
- std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
101
- compute_ref_cnt_map;
102
149
auto &all_ops = graph->Get <GraphOps>(kGraphOps );
103
150
for (auto &op : all_ops) {
104
151
auto in_var_names = get_ref_cnts_from_compute_op (op, op->Inputs ());
@@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
113
160
auto *ref_cnt_handle = new ReferenceCountOpHandle (
114
161
ref_cnt_node, compute_op->GetScope (), place, in_var_names,
115
162
gcs[place.device ].get (), cur_ref_cnts[place.device ].get ());
116
- auto *dep_var = new DummyVarHandle (graph->CreateControlDepVar ());
117
- compute_op->AddOutput (dep_var);
118
- ref_cnt_handle->AddInput (dep_var);
119
- graph->Get <GraphDepVars>(kGraphDepVars ).emplace (dep_var);
120
- compute_ref_cnt_map[compute_op] = ref_cnt_handle;
163
+ if (compute_op->Outputs ().empty ()) {
164
+ auto *dep_var = new DummyVarHandle (graph->CreateControlDepVar ());
165
+ compute_op->AddOutput (dep_var);
166
+ graph->Get <GraphDepVars>(kGraphDepVars ).emplace (dep_var);
167
+ }
168
+ ref_cnt_handle->AddInput (compute_op->Outputs ().front ());
169
+ compute_ref_cnt_map[compute_op].reset (ref_cnt_handle);
121
170
}
122
171
123
172
for (auto &op : all_ops) {
@@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
131
180
new_all_ops.emplace_back (std::move (op));
132
181
auto it = compute_ref_cnt_map.find (new_all_ops.back ().get ());
133
182
if (it != compute_ref_cnt_map.end ()) {
134
- new_all_ops.emplace_back (it->second );
183
+ // Add LeafNode to ReferenceCountOpHandle
184
+ auto *dummy_leaf = new DummyVarHandle (graph->CreateControlDepVar ());
185
+ graph->Get <GraphDepVars>(kGraphDepVars ).emplace (dummy_leaf);
186
+ it->second ->AddOutput (dummy_leaf);
187
+ new_all_ops.emplace_back (std::move (it->second ));
135
188
}
136
189
}
137
190
0 commit comments