// See the License for the specific language governing permissions and
// limitations under the License.

+ #include <algorithm>
+ #include <functional>
#include <queue>
#include <string>
+ #include <tuple>
#include <vector>

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
- #include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

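+ // Flag controlling partial eager deletion (see ShrinkGCVars below): a value
+ // <= 0 disables eager deletion of the collected variables, and a value >= 1.0
+ // deletes every collected variable.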
+ DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
+               "Fraction of eager deletion. If less than 1.0, all variables in "
+               "the program would be sorted according to their memory sizes, "
+               "and only the largest variables that account for this fraction "
+               "of the total memory size would be deleted.");
+
namespace paddle {
namespace framework {
namespace details {

+ // op -> variables which can be deleted after op runs
+ using OpToVarNameSetMap =
+     std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
+
+ // Check whether the variable is LoDTensor based on static VarDesc info
+ static bool IsLoDTensor(VarDesc *var) {
+   return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
+ }
+
+ // Get memory size of LoDTensor
+ static int64_t GetMemorySize(
+     const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
+     const std::string &var_name) {
+   auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
+   PADDLE_ENFORCE_NOT_NULL(var_desc);
+   PADDLE_ENFORCE(IsLoDTensor(var_desc));
+   auto dims = var_desc->GetShape();
+   return SizeOfType(var_desc->GetDataType()) *
+          std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
+                          std::multiplies<int64_t>());
+ }
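+ // Note: the static shape may contain -1 for dimensions that are only known at
+ // runtime (e.g. the batch size), so the product above can be negative; sizes
+ // are therefore compared through AbsMemorySize() below.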
+
+ // Split all variables in the graph into LoDTensor and non-LoDTensor variables
+ // (e.g. SelectedRows, LoDTensorArray). Partial GC is based on a static
+ // analysis of the memory size of each variable, so SelectedRows and
+ // LoDTensorArray are collected separately and skipped when sorting by size.
+ static void SplitIntoLoDTensorAndNonLoDTensorVars(
+     const OpToVarNameSetMap &m, const GraphVars &vars,
+     OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
+   lod_tensors->clear();
+   other_vars->clear();
+
+   for (auto &op_vars_pair : m) {
+     for (auto &var_name : op_vars_pair.second) {
+       auto *var_desc = TryGetLatestVarDesc(
+           vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
+       if (IsLoDTensor(var_desc)) {
+         (*lod_tensors)[op_vars_pair.first].insert(var_name);
+       } else {
+         (*other_vars)[op_vars_pair.first].insert(var_name);
+       }
+     }
+   }
+ }
+
+ struct GCVarInfo {
+   GCVarInfo(const std::string &name, int64_t memory_size,
+             ComputationOpHandle *op, size_t scope_idx)
+       : name_(name),
+         memory_size_(memory_size),
+         op_(op),
+         scope_idx_(scope_idx) {}
+
+   std::string name_;         // variable name
+   int64_t memory_size_;      // memory size
+   ComputationOpHandle *op_;  // op after which the variable could be deleted
+   size_t scope_idx_;         // index of the scope where the variable is located
+
+   int64_t AbsMemorySize() const { return std::abs(memory_size_); }
+ };
+
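+ // A GCVarInfo records one deletion candidate together with the op after which
+ // it becomes dead; ShrinkGCVars below ranks these records by absolute memory
+ // size per place.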
+ // Note: delete_lod_tensor_only is not used currently; the only call site
+ // below relies on its default value of false.
+ static OpToVarNameSetMap ShrinkGCVars(
+     const OpToVarNameSetMap &m, const GraphVars &vars,
+     const std::vector<platform::Place> &places, double fraction_of_memory_size,
+     bool delete_lod_tensor_only = false) {
+   // Do not perform gc when fraction_of_memory_size <= 0
+   if (fraction_of_memory_size <= 0.0) return {};
+
+   /**
+    * Step 1: Split all variables into LoDTensor and non-LoDTensor variables.
+    * We can only calculate the memory sizes of LoDTensors.
+    */
+   OpToVarNameSetMap lod_tensors, other_vars;
+   SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
+
+   // Perform complete gc when fraction_of_memory_size >= 1
+   if (fraction_of_memory_size >= 1.0) {
+     return delete_lod_tensor_only ? lod_tensors : m;
+   }
+
+   /**
+    * Step 2: Build GCVarInfos and calculate the total memory size on each place.
+    */
+
+   // place -> variable infos (name, memory size, op, scope_idx)
+   std::map<platform::Place, std::vector<GCVarInfo>> place_to_vars;
+
+   // place -> total memory size
+   std::map<platform::Place, int64_t> place_to_size;
+   for (auto &op_vars_pair : lod_tensors) {
+     auto *op = op_vars_pair.first;
+     auto &var_names = op_vars_pair.second;
+     auto scope_idx = op->GetScopeIdx();
+     auto &place = places[scope_idx];
+
+     for (auto &var_name : var_names) {
+       auto var_size = GetMemorySize(vars[scope_idx], var_name);
+       GCVarInfo var_info(var_name, var_size, op, scope_idx);
+       place_to_size[place] += var_info.AbsMemorySize();
+       place_to_vars[place].emplace_back(std::move(var_info));
+     }
+   }
+
+   /**
+    * Step 3: Sort GCVarInfos, and only delete the largest variables.
+    */
+   OpToVarNameSetMap partial_vars;
+   for (auto &place_to_var_pair : place_to_vars) {
+     auto &place = place_to_var_pair.first;
+     auto &gc_vars = place_to_var_pair.second;
+     std::sort(gc_vars.begin(), gc_vars.end(),
+               [](const GCVarInfo &var1, const GCVarInfo &var2) {
+                 return var1.AbsMemorySize() > var2.AbsMemorySize();
+               });
+
+     int64_t accumulated_size = 0;
+     int64_t size_threshold =
+         static_cast<int64_t>(fraction_of_memory_size * place_to_size[place]);
+     for (size_t i = 0; i < gc_vars.size() && accumulated_size < size_threshold;
+          ++i) {
+       partial_vars[gc_vars[i].op_].insert(gc_vars[i].name_);
+       accumulated_size += gc_vars[i].AbsMemorySize();
+     }
+   }
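+   // For example, with fraction_of_memory_size = 0.7 and variables of sizes
+   // {400, 300, 200, 100} on a place, the threshold is 0.7 * 1000 = 700, so
+   // the loop above keeps only the two largest variables (400 and 300).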
+
+   /**
+    * Step 4: Combine the other variables (SelectedRows, LoDTensorArray).
+    * They are always deleted when gc is enabled, since their memory sizes
+    * cannot be estimated statically.
+    */
+   if (!delete_lod_tensor_only) {
+     for (auto &op_vars_pair : other_vars) {
+       partial_vars[op_vars_pair.first].insert(op_vars_pair.second.begin(),
+                                               op_vars_pair.second.end());
+     }
+   }
+
+   return partial_vars;
+ }
+
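+ // For every op that is the last user of a set of variables, this pass
+ // attaches an EagerDeletionOpHandle so that those variables can be released
+ // as soon as the op has finished running.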
+ class EagerDeletionPass : public ir::Pass {
+  protected:
+   std::unique_ptr<ir::Graph> ApplyImpl(
+       std::unique_ptr<ir::Graph> graph) const override;
+ };
+
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto &ref_cnts =
@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(

  // a reverse map of last_live_ops
  // i.e., last op --> variable names which can be deleted.
-   std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
-       op_vars_map;
-
+   OpToVarNameSetMap op_vars_map;
  for (auto &var_ops_map : last_live_ops) {
    for (auto &var_ops_pair : var_ops_map) {
      const std::string &var_name = var_ops_pair.first;
@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
    }
  }

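+   // Optionally shrink op_vars_map: with the default flag value of 1.0,
+   // ShrinkGCVars returns the map unchanged; a smaller fraction keeps only the
+   // largest variables on each place (plus all non-LoDTensor variables).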
+   op_vars_map = ShrinkGCVars(op_vars_map, vars, places,
+                              FLAGS_memory_fraction_of_eager_deletion);
+
  for (auto &pair : op_vars_map) {
    auto *op = pair.first;
    auto &var_names = pair.second;
@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
    eager_deletion_op->AddOutput(dummy_leaf);
  }

+   VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = "
+            << FLAGS_memory_fraction_of_eager_deletion;
  VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
-   return graph;
+
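+   // Run the while_op_eager_deletion_pass on the resulting graph as the final
+   // step, presumably to extend eager deletion to variables used inside
+   // while-op sub-blocks (that pass is registered elsewhere and only
+   // referenced here by name).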
+   auto while_op_eager_deletion_pass =
+       ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
+   return while_op_eager_deletion_pass->Apply(std::move(graph));
}

} // namespace details
@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass,
    .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
    .RequirePassAttr(paddle::framework::details::kAllPlaces)
    .RequirePassAttr(paddle::framework::details::kGarbageCollector);
+
+ USE_PASS(while_op_eager_deletion_pass);