Commit 1096746: Fuse Adam And SGD ops (#15933)

Author: chengduo

* fuse optimizer

Parent: 2632327

23 files changed: 1,101 additions & 147 deletions

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
@@ -10,7 +10,10 @@ cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framewor
 cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
 cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
 cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
+
 cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
+cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
+cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
 
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

@@ -104,5 +107,7 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
             graph_viz_pass multi_devices_graph_pass
             multi_devices_graph_print_pass multi_devices_graph_check_pass
             fuse_elewise_add_act_pass multi_batch_merge_pass
-            fuse_relu_depthwise_conv_pass
-            memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass)
+            fuse_relu_depthwise_conv_pass
+            memory_optimize_pass lock_free_optimize_pass
+            alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
+            fuse_adam_op_pass fuse_sgd_op_pass)

paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc

Lines changed: 29 additions & 19 deletions
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_registry.h"
+
 DEFINE_uint32(fuse_parameter_memory_size, 0,  // 0 KB
               "fuse_parameter_memory_size is up limited memory size "
               "of one group parameters' gradient which is the input "

@@ -105,20 +106,29 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
       auto ele_dtype = iter->second->Var()->GetDataType();
       if (dtype == kDefaultDtype) {
         dtype = ele_dtype;
-        PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype);
+        PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
+                          "The data type should not be bool.");
       }
-      PADDLE_ENFORCE_EQ(ele_dtype, dtype);
+      PADDLE_ENFORCE_EQ(ele_dtype, dtype,
+                        "The data type of input is not consistent.");
     }
 
-    // Create the fused variable name.
+    // Create a FusedVarsSet to avoid duplicating names for fused_var in other
+    // passes.
     if (!result.Has(kFusedVars)) {
      result.Set(kFusedVars, new FusedVars);
     }
-    const std::string prefix(kFusedVarNamePrefix);
-    // The fused_var_name should be unique.
-    auto fused_var_name = prefix + "GRAD@" + params_grads[0].second;
+    // kFusedGrads is used by fuse_optimizer_op_pass.
+    result.Set(kFusedGrads, new FusedGrads);
+
+    // The fused_var_name should be unique, so it appends
+    // params_grads.begin()->second.
+    auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
+                          params_grads.begin()->second;
+    result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
     auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
-    PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
+    PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
+                      "%s is duplicate in FusedVars.", fused_var_name);
     fused_var_set.insert(fused_var_name);
 
     InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,

@@ -295,17 +305,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
     return type == proto::VarType::LOD_TENSOR;
   }
 
-  void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
-                                 const std::vector<std::string> &grads_name,
-                                 const std::string &fused_var_name,
-                                 BlockDesc *global_block) const {
-    auto op_desc = global_block->AppendOp();
-    op_desc->SetType("alloc_continuous_space");
-    op_desc->SetInput("Input", params_name);
-    op_desc->SetOutput("Output", grads_name);
-    op_desc->SetOutput("FusedOutput", {fused_var_name});
-  }
-
   void RecordParamsAndGrads(ir::Node *node,
                             ParamsAndGrads *params_grads) const {
     try {

@@ -358,6 +357,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
       }
     }
 
+    // Alloc continuous space for vars.
     std::vector<std::string> grads_name;
     std::vector<std::string> params_name;
     grads_name.reserve(params_grads.size());

@@ -370,14 +370,24 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
     AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
                               program_desc.MutableBlock(0));
 
-    // Run Only Once Programs
     for (size_t i = 0; i < local_scopes.size(); ++i) {
       for (auto &op_desc : program_desc.Block(0).AllOps()) {
         auto op = OpRegistry::CreateOp(*op_desc);
         op->Run(*local_scopes[i], places[i]);
       }
     }
   }
+
+  void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
+                                 const std::vector<std::string> &grads_name,
+                                 const std::string &fused_var_name,
+                                 BlockDesc *global_block) const {
+    auto op_desc = global_block->AppendOp();
+    op_desc->SetType("alloc_continuous_space");
+    op_desc->SetInput("Input", params_name);
+    op_desc->SetOutput("Output", grads_name);
+    op_desc->SetOutput("FusedOutput", {fused_var_name});
+  }
 };
 
 }  // namespace details
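The bookkeeping in the second hunk is the core of the handshake with the new optimizer passes: one unique fused name is composed from the first gradient's name, checked against FusedVars, and published under kFusedGrads so that fuse_optimizer_op_pass can later locate the coalesced gradient buffer. Below is a minimal standalone sketch of that logic; kFusedVarNamePrefix and the parameter/gradient pairs are stand-ins, not Paddle's real constant or container types.

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Stand-in for Paddle's kFusedVarNamePrefix; the real value differs.
static const char kFusedVarNamePrefix[] = "@FUSEDVAR@";

int main() {
  // (parameter, gradient) pairs as RecordParamsAndGrads would collect them.
  std::vector<std::pair<std::string, std::string>> params_grads = {
      {"fc_0.w_0", "fc_0.w_0@GRAD"}, {"fc_0.b_0", "fc_0.b_0@GRAD"}};
  std::unordered_set<std::string> fused_var_set;  // plays the role of FusedVars

  // Same recipe as the diff: prefix + "@GRAD@" + the first gradient's name.
  std::string fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
                               params_grads.begin()->second;
  // Mirrors PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0, ...).
  if (fused_var_set.count(fused_var_name) != 0)
    throw std::runtime_error(fused_var_name + " is duplicate in FusedVars.");
  fused_var_set.insert(fused_var_name);

  // Printing stands in for publishing the name under kFusedGrads.
  std::cout << fused_var_name << std::endl;
}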

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 5 additions & 8 deletions
@@ -27,20 +27,17 @@ void BroadcastOpHandle::RunImpl() {
   if (places_.size() == 1) return;
 
   // The input and output may have dummy vars.
-  VarHandle *in_var_handle;
-  {
-    auto in_var_handles = DynamicCast<VarHandle>(inputs_);
-    PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
-                      "The number of input should be one.");
-    in_var_handle = in_var_handles[0];
-  }
-
+  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
   auto out_var_handles = DynamicCast<VarHandle>(outputs_);
 
+  PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
+                    "The number of input should be one.");
   PADDLE_ENFORCE_EQ(
       out_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");
 
+  VarHandle *in_var_handle = in_var_handles[0];
+
   WaitInputVarGenerated();
 
   std::vector<const Scope *> var_scopes;
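This refactor does not change behavior; it only hoists the DynamicCast calls and the size checks to the top of RunImpl so the input and output handles are validated side by side. For reference, here is a self-contained sketch of the DynamicCast filtering idiom the handler relies on; VarHandleBase, VarHandle, and DummyVarHandle are mocks rather than Paddle's types.

#include <cassert>
#include <vector>

struct VarHandleBase { virtual ~VarHandleBase() = default; };
struct VarHandle : VarHandleBase {};       // a real variable handle
struct DummyVarHandle : VarHandleBase {};  // dependency-only placeholder

// Keep only the elements that are actually of type T, dropping dummy vars.
template <typename T>
std::vector<T *> DynamicCast(const std::vector<VarHandleBase *> &vars) {
  std::vector<T *> out;
  for (auto *v : vars)
    if (auto *t = dynamic_cast<T *>(v)) out.push_back(t);
  return out;
}

int main() {
  VarHandle in;
  DummyVarHandle dep;
  std::vector<VarHandleBase *> inputs{&in, &dep};

  auto in_var_handles = DynamicCast<VarHandle>(inputs);
  assert(in_var_handles.size() == 1UL);  // "The number of input should be one."
  VarHandle *in_var_handle = in_var_handles[0];
  (void)in_var_handle;
}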

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 39 additions & 13 deletions
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <memory>
 #include <utility>
-
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"

@@ -82,23 +81,43 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("inplace_pass");
     }
 
-    if (strategy.fuse_elewise_add_act_ops_) {
+    if (strategy_.fuse_elewise_add_act_ops_) {
       VLOG(10) << "Add fuse_elewise_add_act_pass";
       AppendPass("fuse_elewise_add_act_pass");
     }
 
     // for single card training, fuse_all_reduce_ops is unnecessary.
     // alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
-    if (strategy.fuse_all_reduce_ops_) {
+    if (strategy_.fuse_all_reduce_ops_) {
       VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
       AppendPass("alloc_continuous_space_for_grad_pass");
     }
 
+    if (strategy_.fuse_all_optimizer_ops_) {
+      if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
+          strategy_.is_distribution_) {
+        VLOG(3)
+            << "Currently, fuse_all_optimizer_ops only works under AllReduce "
+               "mode.";
+        strategy_.fuse_all_optimizer_ops_ = false;
+      } else {
+        VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
+        AppendPass("alloc_continuous_space_for_grad_pass");
+        // NOTE: fuse_all_xx_ops will count the number of xx operators first;
+        // if the number is zero, the pass will do nothing.
+        // Currently, only one type of optimization algorithm can be fused.
+        VLOG(10) << "Add fuse_adam_op_pass";
+        AppendPass("fuse_adam_op_pass");
+        VLOG(10) << "Add fuse_sgd_op_pass";
+        AppendPass("fuse_sgd_op_pass");
+      }
+    }
+
     // Add a graph viz pass to record a graph.
     if (!strategy.debug_graphviz_path_.empty()) {
       auto viz_pass = AppendPass("graph_viz_pass");
       const std::string graph_path = string::Sprintf(
-          "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
+          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
       viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
     }

@@ -118,14 +137,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // the de-fact IR, any reuse on Graph is meaningless.
     // A side-effect of that, memory optimize cannot forsee the fetched vars
     // , so fetchlist should be set persistable before call the Run interface.
-    if (strategy.memory_optimize_) {
+    if (strategy_.memory_optimize_) {
       VLOG(10) << "Add memory_optimize_pass";
       AppendPass("memory_optimize_pass");
     }
 
-    AppendMultiDevPass(strategy);
+    AppendMultiDevPass(strategy_);
 
-    if (strategy.fuse_all_reduce_ops_) {
+    if (strategy_.fuse_all_reduce_ops_) {
       // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
       // first, if the number is zero, fuse_all_reduce_ops will do nothing.
       VLOG(10) << "Add fuse_all_reduce_op_pass";

@@ -151,7 +170,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("all_reduce_deps_pass");
     }
 
-    if (SeqOnlyAllReduceOps(strategy)) {
+    if (SeqOnlyAllReduceOps(strategy_)) {
       VLOG(10) << "Add all_reduce_deps_pass";
       AppendPass("all_reduce_deps_pass");
     }

@@ -165,7 +184,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   // Convert graph to run on multi-devices.
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass = nullptr;
-    if (strategy_.is_distribution_) {
+    if (strategy.is_distribution_) {
       VLOG(10) << "Add dist_multi_devices_pass";
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
     } else {

@@ -235,17 +254,22 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
       pass->Erase(kNCCLCtxs);
       pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
 #endif
-    } else if (pass->Type() == "fuse_all_reduce_op_pass") {
+    } else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
+               pass->Type() == "fuse_adam_op_pass" ||
+               pass->Type() == "fuse_sgd_op_pass" ||
+               pass->Type() == "fuse_all_reduce_op_pass") {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
       pass->Erase(kLocalScopes);
       pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                     &local_scopes);
+      if (pass->Type() == "fuse_all_reduce_op_pass") {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
-      pass->Erase(kNCCLCtxs);
-      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
+        platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+        pass->Erase(kNCCLCtxs);
+        pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
 #endif
+      }
     } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);

@@ -294,4 +318,6 @@ USE_PASS(inplace_pass);
 USE_PASS(lock_free_optimize_pass);
 USE_PASS(alloc_continuous_space_for_grad_pass);
 USE_PASS(graph_to_program_pass);
+USE_PASS(fuse_adam_op_pass);
+USE_PASS(fuse_sgd_op_pass);
 USE_PASS(fuse_all_reduce_op_pass);
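After this change, Apply() routes four pass types through a single branch that injects the shared kPlaces and kLocalScopes attributes, and only fuse_all_reduce_op_pass additionally receives the NCCL context. Here is a rough standalone sketch of that dispatch shape; Pass, SetNotOwned, and the attribute keys are simplified stand-ins for the control flow, not Paddle's API.

#include <map>
#include <string>

struct NCCLContextMap {};  // stand-in

struct Pass {
  std::string type;
  std::map<std::string, const void *> attrs;  // crude SetNotOwned analogue
  const std::string &Type() const { return type; }
  void SetNotOwned(const std::string &k, const void *v) { attrs[k] = v; }
};

void Inject(Pass *pass, const void *places, const void *local_scopes,
            bool use_cuda, NCCLContextMap *nccl_ctxs) {
  if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
      pass->Type() == "fuse_adam_op_pass" ||
      pass->Type() == "fuse_sgd_op_pass" ||
      pass->Type() == "fuse_all_reduce_op_pass") {
    // All four fusion-related passes need the places and local scopes.
    pass->SetNotOwned("kPlaces", places);
    pass->SetNotOwned("kLocalScopes", local_scopes);
    // Only the all-reduce fusion pass needs the NCCL context.
    if (pass->Type() == "fuse_all_reduce_op_pass")
      pass->SetNotOwned("kNCCLCtxs", use_cuda ? nccl_ctxs : nullptr);
  }
}

int main() {
  Pass p{"fuse_adam_op_pass", {}};
  NCCLContextMap nccl;
  Inject(&p, nullptr, nullptr, /*use_cuda=*/true, &nccl);
}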

paddle/fluid/framework/details/build_strategy.h

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,6 @@
 #include <string>
 #include <utility>
 #include <vector>
-
 #include "paddle/fluid/framework/ir/pass_builder.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"

@@ -76,6 +75,8 @@ struct BuildStrategy {
 
   bool fuse_elewise_add_act_ops_{false};
 
+  bool fuse_all_optimizer_ops_{false};
+
   bool fuse_all_reduce_ops_{false};
 
   bool fuse_relu_depthwise_conv_{false};
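With the new field in place, enabling the fusion is a flag flip on the strategy before it is applied. Below is a small sketch using a pared-down BuildStrategy that mirrors only the flags visible in this diff; the real struct carries many more members.

#include <iostream>

// Pared-down mirror of the BuildStrategy fields shown above.
struct BuildStrategy {
  bool fuse_elewise_add_act_ops_{false};
  bool fuse_all_optimizer_ops_{false};  // added by this commit
  bool fuse_all_reduce_ops_{false};
  bool fuse_relu_depthwise_conv_{false};
};

int main() {
  BuildStrategy strategy;
  strategy.fuse_all_optimizer_ops_ = true;  // turns on fuse_adam/fuse_sgd passes
  std::cout << std::boolalpha << strategy.fuse_all_optimizer_ops_ << "\n";
}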
