@@ -17,7 +17,6 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <memory>
 #include <utility>
-
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
@@ -82,23 +81,43 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("inplace_pass");
     }

-    if (strategy.fuse_elewise_add_act_ops_) {
+    if (strategy_.fuse_elewise_add_act_ops_) {
       VLOG(10) << "Add fuse_elewise_add_act_pass";
       AppendPass("fuse_elewise_add_act_pass");
     }

     // for single card training, fuse_all_reduce_ops is unnecessary.
     // alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
-    if (strategy.fuse_all_reduce_ops_) {
+    if (strategy_.fuse_all_reduce_ops_) {
       VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
       AppendPass("alloc_continuous_space_for_grad_pass");
     }

+    if (strategy_.fuse_all_optimizer_ops_) {
+      if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
+          strategy_.is_distribution_) {
+        VLOG(3)
+            << "Currently, fuse_all_optimizer_ops only works under AllReduce "
+               "mode.";
+        strategy_.fuse_all_optimizer_ops_ = false;
+      } else {
+        VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
+        AppendPass("alloc_continuous_space_for_grad_pass");
+        // NOTE: fuse_all_xx_ops will count the number of xx operator first,
+        // if the number is zero, fuse_all_reduce_ops will do nothing.
+        // Currently, only one type of optimization algorithm can be fused.
+        VLOG(10) << "Add fuse_adam_op_pass";
+        AppendPass("fuse_adam_op_pass");
+        VLOG(10) << "Add fuse_sgd_op_pass";
+        AppendPass("fuse_sgd_op_pass");
+      }
+    }
+
     // Add a graph viz pass to record a graph.
     if (!strategy.debug_graphviz_path_.empty()) {
       auto viz_pass = AppendPass("graph_viz_pass");
       const std::string graph_path = string::Sprintf(
-          "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
+          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
       viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
     }

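Sketch only, not part of the patch: a minimal configuration that reaches the new fuse_all_optimizer_ops branch above. The member and enum names come from the diff; the helper function and the include are assumed for illustration.

#include "paddle/fluid/framework/details/build_strategy.h"

namespace details = paddle::framework::details;

// Returns a strategy that opts in to the new optimizer-fusion passes.
details::BuildStrategy MakeFusedOptimizerStrategy() {
  details::BuildStrategy strategy;
  // Request fusion; the builder above then appends
  // alloc_continuous_space_for_grad_pass, fuse_adam_op_pass and fuse_sgd_op_pass.
  strategy.fuse_all_optimizer_ops_ = true;
  // Fusion is only honoured under AllReduce; kReduce or a distributed run makes
  // the builder reset fuse_all_optimizer_ops_ back to false.
  strategy.reduce_ = details::BuildStrategy::ReduceStrategy::kAllReduce;
  strategy.is_distribution_ = false;
  return strategy;
}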
@@ -118,14 +137,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // the de-fact IR, any reuse on Graph is meaningless.
     // A side-effect of that, memory optimize cannot forsee the fetched vars
     // , so fetchlist should be set persistable before call the Run interface.
-    if (strategy.memory_optimize_) {
+    if (strategy_.memory_optimize_) {
       VLOG(10) << "Add memory_optimize_pass";
       AppendPass("memory_optimize_pass");
     }

-    AppendMultiDevPass(strategy);
+    AppendMultiDevPass(strategy_);

-    if (strategy.fuse_all_reduce_ops_) {
+    if (strategy_.fuse_all_reduce_ops_) {
       // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
       // first, if the number is zero, fuse_all_reduce_ops will do nothing.
       VLOG(10) << "Add fuse_all_reduce_op_pass";
@@ -151,7 +170,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("all_reduce_deps_pass");
     }

-    if (SeqOnlyAllReduceOps(strategy)) {
+    if (SeqOnlyAllReduceOps(strategy_)) {
       VLOG(10) << "Add all_reduce_deps_pass";
       AppendPass("all_reduce_deps_pass");
     }
@@ -165,7 +184,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   // Convert graph to run on multi-devices.
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass = nullptr;
-    if (strategy_.is_distribution_) {
+    if (strategy.is_distribution_) {
       VLOG(10) << "Add dist_multi_devices_pass";
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
     } else {
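The one-character change in this hunk makes AppendMultiDevPass read its strategy parameter instead of the builder's stored strategy_. A standalone toy (plain C++, not Paddle code) of why that distinction matters:

#include <iostream>

struct Strategy {
  bool is_distribution_ = false;
};

class Builder {
 public:
  explicit Builder(const Strategy &s) : strategy_(s) {}

  // Before the fix: consults the stored member and ignores the argument.
  void AppendBuggy(const Strategy & /*strategy*/) const {
    std::cout << "buggy sees is_distribution_ = " << strategy_.is_distribution_ << "\n";
  }

  // After the fix: honours the strategy that was actually handed in.
  void AppendFixed(const Strategy &strategy) const {
    std::cout << "fixed sees is_distribution_ = " << strategy.is_distribution_ << "\n";
  }

 private:
  Strategy strategy_;
};

int main() {
  Strategy local;                    // stored copy: is_distribution_ == false
  Builder builder(local);
  Strategy distributed;
  distributed.is_distribution_ = true;
  builder.AppendBuggy(distributed);  // prints 0: the argument is ignored
  builder.AppendFixed(distributed);  // prints 1: the argument is respected
  return 0;
}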
@@ -235,17 +254,22 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
       pass->Erase(kNCCLCtxs);
       pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
 #endif
-    } else if (pass->Type() == "fuse_all_reduce_op_pass") {
+    } else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
+               pass->Type() == "fuse_adam_op_pass" ||
+               pass->Type() == "fuse_sgd_op_pass" ||
+               pass->Type() == "fuse_all_reduce_op_pass") {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
       pass->Erase(kLocalScopes);
       pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                     &local_scopes);
+      if (pass->Type() == "fuse_all_reduce_op_pass") {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
-      pass->Erase(kNCCLCtxs);
-      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
+        platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+        pass->Erase(kNCCLCtxs);
+        pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
 #endif
+      }
     } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
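This branch hands the places, the local scopes, and (for the all-reduce pass only) the NCCL contexts to the fusion passes as not-owned attributes. A rough sketch of the consumer side using the same attribute keys; the free-function framing and the include paths are assumptions, not code from this PR:

#include <vector>

#include <glog/logging.h>

#include "paddle/fluid/framework/details/multi_devices_helper.h"  // assumed home of kPlaces/kLocalScopes
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"

namespace fd = paddle::framework::details;

// Reads back the attributes that BuildStrategy::Apply() injects above.
void InspectInjectedAttrs(const paddle::framework::ir::Pass &pass) {
  const auto &places =
      pass.Get<const std::vector<paddle::platform::Place>>(fd::kPlaces);
  const auto &local_scopes =
      pass.Get<const std::vector<paddle::framework::Scope *>>(fd::kLocalScopes);
  VLOG(10) << "pass sees " << places.size() << " place(s) and "
           << local_scopes.size() << " local scope(s)";
}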
@@ -294,4 +318,6 @@ USE_PASS(inplace_pass);
 USE_PASS(lock_free_optimize_pass);
 USE_PASS(alloc_continuous_space_for_grad_pass);
 USE_PASS(graph_to_program_pass);
+USE_PASS(fuse_adam_op_pass);
+USE_PASS(fuse_sgd_op_pass);
 USE_PASS(fuse_all_reduce_op_pass);