
Commit 4cc6144

Authored by chengduo
[Cherry-pick]Fix the bug of all_reduce_deps_pass (#16648)
* fix the bug of all_reduce_deps_pass test=release/1.4
1 parent d3b6291 commit 4cc6144

8 files changed: 166 additions, 141 deletions

paddle/fluid/framework/details/all_reduce_deps_pass.cc

Lines changed: 151 additions & 90 deletions

@@ -13,125 +13,186 @@
 // limitations under the License.
 
 #include <algorithm>
-#include <memory>
+#include <map>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include <vector>
 
-#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
 #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/op_graph_view.h"
-#include "paddle/fluid/framework/details/var_handle.h"
+#include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
 
-VarHandle* GetValidInput(const OpHandleBase* a) {
-  for (auto p : a->Inputs()) {
-    VarHandle* b = dynamic_cast<VarHandle*>(p);
-    if (b) {
-      return b;
+class AllReduceDepsPass : public ir::Pass {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override {
+    std::vector<AllReduceOpHandle*> all_reduce_op_handles =
+        GetSortedAllReduceOps(*graph);
+
+    for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
+      auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+      all_reduce_op_handles[i - 1]->AddOutput(dep_var);
+      all_reduce_op_handles[i]->AddInput(dep_var);
     }
-  }
 
-  return nullptr;
-}
-
-void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
-  auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
-
-  // get vars order
-  int order = 0;
-  std::unordered_map<std::string, int> vars;
-  // TODO(gongwb): use graph topology sort to find the order of operators.
-  // Note that must assert topology sort is stable
-  auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
-  for (auto* op_desc : ops) {
-    try {
-      bool is_bk_op =
-          static_cast<bool>(boost::get<int>(op_desc->GetAttr(
-                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
-                            static_cast<int>(OpRole::kBackward));
-      if (!is_bk_op) continue;
-
-      auto backward_vars =
-          boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
-              OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-      PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
-
-      auto outputs = op_desc->Outputs();
-      for (auto& o_it : outputs) {
-        for (auto& v : o_it.second) {  // values
-          vars[v] = order;
-          VLOG(10) << "in all_reduce_deps_pass:" << v;
-        }
-      }
-      order++;
-    } catch (boost::bad_get e) {
+    if (VLOG_IS_ON(10)) {
+      DebugString(*graph, all_reduce_op_handles);
     }
   }
 
-  std::vector<OpHandleBase*> dist_ops;
-  // get allreduce ops.
-  for (auto& op : graph_ops) {
-    // FIXME(gongwb):add broad cast.
-    if (op->Name() == "all_reduce" || op->Name() == "reduce") {
-      dist_ops.push_back(op);
+  std::vector<AllReduceOpHandle*> GetSortedAllReduceOps(
+      const ir::Graph& graph) const {
+    std::vector<AllReduceOpHandle*> all_reduce_op_handles;
+    std::unordered_map<OpHandleBase*, size_t> pending_ops;
+    std::unordered_set<OpHandleBase*> ready_ops;
+    std::unordered_set<OpHandleBase*> next_ready_ops;
+
+    auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph);
+    size_t num_of_ops = op_handles.size();
+    for (OpHandleBase* op : op_handles) {
+      size_t not_ready_vars = op->NotReadyInputSize();
+      if (not_ready_vars) {
+        pending_ops.insert({op, not_ready_vars});
+      } else {
+        ready_ops.insert(op);
+      }
     }
-  }
-
-  VLOG(10) << "dist_ops size:" << dist_ops.size()
-           << ", outputs size:" << vars.size() << ", ops size:" << ops.size();
-
-  std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
-                                                  OpHandleBase* op2) {
-    VarHandle* i0 = dynamic_cast<VarHandle*>(GetValidInput(op1));
-    VarHandle* i1 = dynamic_cast<VarHandle*>(GetValidInput(op2));
-
-    PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error",
-                   op1->DebugString(), op2->DebugString());
 
-    auto l_it = vars.find(i0->name());
-    auto r_it = vars.find(i1->name());
-
-    PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(),
-                   "can't find var's name %s and %s in opdesc", i0->name(),
-                   i1->name());
-
-    if (l_it->second < r_it->second) return true;
+    GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
+
+    size_t has_run_ops = ready_ops.size();
+    while (has_run_ops != num_of_ops) {
+      for (auto* op : ready_ops) {
+        for (auto& ready_var : op->Outputs()) {
+          for (auto* pend_op : ready_var->PendingOps()) {
+            auto& deps = --pending_ops[pend_op];
+            if (deps == 0) {
+              next_ready_ops.insert(pend_op);
+            }
+          }
+        }
+      }
 
-    if (l_it->second == r_it->second) {
-      return i0->name() < i1->name();
+      PADDLE_ENFORCE_NE(next_ready_ops.size(), 0, "There maybe have a cycle.");
+      ready_ops.clear();
+      std::swap(ready_ops, next_ready_ops);
+      GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
+      has_run_ops += ready_ops.size();
     }
+    return all_reduce_op_handles;
+  }
 
-    return false;
-  });
-
-  // add dependency.
-  auto& sorted_ops = dist_ops;
-  for (size_t i = 1; i < sorted_ops.size(); ++i) {
-    auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-
-    auto* pre_op = sorted_ops[i - 1];
-    auto* op = sorted_ops[i];
-
-    pre_op->AddOutput(dep_var);
-    op->AddInput(dep_var);
-    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+  void GetSortedAllReduceOps(
+      const std::unordered_set<OpHandleBase*>& ready_ops,
+      std::vector<AllReduceOpHandle*>* all_reduce_op_handles) const {
+    std::vector<AllReduceOpHandle*> current_all_reduce_op_handles;
+    for (auto& op_handle : ready_ops) {
+      auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle*>(op_handle);
+      if (all_reduce_op_handle) {
+        current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
+      }
+    }
 
-    VLOG(10) << "add all_reduce sequential dependencies between " << pre_op
-             << " and " << op;
+    // NOTE(zcd): For distributed training, it is important to keep the order of
+    // allReduce on each node consistent. Otherwise, hang may occur.
+    // Sort the current_all_reduce_op_handles according to the name of input.
+    sort(current_all_reduce_op_handles.begin(),
+         current_all_reduce_op_handles.end(),
+         [](const AllReduceOpHandle* left,
+            const AllReduceOpHandle* right) -> bool {
+           auto left_in_vars = DynamicCast<VarHandle>(left->Inputs());
+           auto right_in_vars = DynamicCast<VarHandle>(right->Inputs());
+           PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
+           PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
+           return left_in_vars[0]->Name() > right_in_vars[0]->Name();
+         });
+
+    all_reduce_op_handles->insert(all_reduce_op_handles->end(),
+                                  current_all_reduce_op_handles.begin(),
+                                  current_all_reduce_op_handles.end());
+  }
 
-    VLOG(10) << "pre_op:" << pre_op->DebugString()
-             << ", op:" << op->DebugString();
+  void DebugString(
+      const ir::Graph& graph,
+      const std::vector<AllReduceOpHandle*>& all_reduce_op_handles) const {
+    // get vars order
+    std::map<int, std::vector<std::string>> vars =
+        GetSoredGradientsFromStaleProgram(graph);
+    std::stringstream out;
+    size_t grads_of_stale_program = 0;
+    out << "Get Order From kStaleProgramOpDescs: ";
+    for (auto& var : vars) {
+      out << "Order " << var.first << " [";
+      for (auto& var_name : var.second) {
+        out << var_name << ", ";
+        ++grads_of_stale_program;
+      }
+      out << "], ";
+    }
+    VLOG(10) << out.str();
+
+    std::stringstream out2;
+    out2 << "Get Order From Topological order: ";
+    for (auto& op : all_reduce_op_handles) {
+      bool find_valid_input = false;
+      for (auto& in_var : op->Inputs()) {
+        if (dynamic_cast<VarHandle*>(in_var)) {
+          out2 << in_var->Name() << ", ";
+          find_valid_input = true;
+          break;
+        }
+      }
+      PADDLE_ENFORCE(find_valid_input, "Doesn't find valid input.");
+    }
+    VLOG(10) << out2.str();
+    if (grads_of_stale_program != all_reduce_op_handles.size()) {
+      VLOG(10)
+          << "The gradients number of stale program and graph is not equal.";
+    }
   }
-}
 
+  std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
+      const ir::Graph& graph) const {
+    std::map<int, std::vector<std::string>> vars;
+    auto ops = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
+    int order = 0;
+    for (auto* op_desc : ops) {
+      try {
+        bool is_bk_op =
+            static_cast<bool>(boost::get<int>(op_desc->GetAttr(
+                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
+                              static_cast<int>(OpRole::kBackward));
+        if (!is_bk_op) continue;
+
+        auto backward_vars =
+            boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
+                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+        if (backward_vars.empty()) continue;
+
+        PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+        for (size_t i = 1; i < backward_vars.size(); i += 2) {
+          vars[order].emplace_back(backward_vars[i]);
+          VLOG(1) << "get parameter and gradient: " << backward_vars[i - 1]
+                  << ", " << backward_vars[i];
+        }
+        order++;
+      } catch (boost::bad_get e) {
+      }
+    }
+    return vars;
+  }
+};
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
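
For readers skimming the diff, here is a minimal, self-contained sketch (not Paddle code) of the ordering scheme the rewritten pass relies on: a Kahn-style, level-by-level topological walk over the op graph, where the all-reduce handles found in each level are sorted by name so every trainer derives the same sequence before control-dependency edges are chained between consecutive all-reduce ops. Node, SortedAllReduceOps, and the toy graph in main are hypothetical stand-ins for OpHandleBase, GetSortedAllReduceOps, and the SSA graph; the real pass compares the name of the first VarHandle input rather than an op name.

// ordering_sketch.cc -- a sketch under the assumptions named above.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string name;
  bool is_all_reduce = false;
  std::vector<Node*> outputs;  // downstream ops that consume this op's result
  size_t pending_inputs = 0;   // in-degree: inputs not yet produced
};

// Level-by-level (Kahn-style) topological traversal; within each level the
// all-reduce nodes are sorted by name, giving a deterministic global order.
std::vector<Node*> SortedAllReduceOps(std::vector<Node>* graph) {
  std::vector<Node*> result;
  std::vector<Node*> ready, next_ready;
  std::unordered_map<Node*, size_t> pending;

  for (Node& n : *graph) {
    if (n.pending_inputs == 0) {
      ready.push_back(&n);
    } else {
      pending[&n] = n.pending_inputs;
    }
  }

  size_t visited = 0;
  while (!ready.empty()) {
    // Deterministic order inside one topological level.
    std::vector<Node*> level;
    for (Node* n : ready) {
      if (n->is_all_reduce) level.push_back(n);
    }
    std::sort(level.begin(), level.end(),
              [](Node* a, Node* b) { return a->name < b->name; });
    result.insert(result.end(), level.begin(), level.end());

    visited += ready.size();
    next_ready.clear();
    for (Node* n : ready) {
      for (Node* succ : n->outputs) {
        if (--pending[succ] == 0) next_ready.push_back(succ);
      }
    }
    std::swap(ready, next_ready);
  }
  assert(visited == graph->size() && "cycle detected");
  return result;  // the pass then chains these with control-dependency vars
}

int main() {
  // Tiny diamond graph: fc -> {allreduce_w2, allreduce_w1} -> sgd.
  std::vector<Node> g(4);
  g[0].name = "fc";
  g[1].name = "allreduce_w2";
  g[1].is_all_reduce = true;
  g[1].pending_inputs = 1;
  g[2].name = "allreduce_w1";
  g[2].is_all_reduce = true;
  g[2].pending_inputs = 1;
  g[3].name = "sgd";
  g[3].pending_inputs = 2;
  g[0].outputs = {&g[1], &g[2]};
  g[1].outputs = {&g[3]};
  g[2].outputs = {&g[3]};
  // Prints allreduce_w1 then allreduce_w2 on every process, regardless of the
  // iteration order of the ready set.
  for (Node* op : SortedAllReduceOps(&g)) std::cout << op->name << "\n";
  return 0;
}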

paddle/fluid/framework/details/all_reduce_deps_pass.h

Lines changed: 0 additions & 32 deletions
This file was deleted.

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 // asynchronous nccl allreduce or synchronous issue:
 // https://github.com/PaddlePaddle/Paddle/issues/15049
 DEFINE_bool(
-    sync_nccl_allreduce, false,
+    sync_nccl_allreduce, true,
    "If set true, will call `cudaStreamSynchronize(nccl_stream)`"
    "after allreduce, this mode can get better performance in some scenarios.");

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 8 additions & 9 deletions
@@ -163,22 +163,21 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
           "graph_printer", new details::GraphvizSSAGraphPrinter);
     }
 
-    // Verify that the graph is correct for multi-device executor.
-    AppendPass("multi_devices_check_pass");
-
-    if (VLOG_IS_ON(2)) {
-      AppendPass("all_reduce_deps_pass");
-    }
-
-    if (SeqOnlyAllReduceOps(strategy_)) {
-      VLOG(10) << "Add all_reduce_deps_pass";
+    // experimental shows that the program will be faster if append
+    // all_reduce_deps_pass here.
+    if (!strategy_.enable_parallel_graph_ &&
+        (SeqOnlyAllReduceOps(strategy_) ||
+         strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
       AppendPass("all_reduce_deps_pass");
     }
 
     if (strategy_.remove_unnecessary_lock_) {
       VLOG(10) << "Add modify_op_lock_and_record_event_pass";
       AppendPass("modify_op_lock_and_record_event_pass");
     }
+
+    // Verify that the graph is correct for multi-device executor.
+    AppendPass("multi_devices_check_pass");
   }
 
   // Convert graph to run on multi-devices.

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ void OpHandleBase::Run(bool use_cuda) {
       if (out_var_handle) {
         PADDLE_ENFORCE(
             platform::is_same_place(place, out_var_handle->place()),
-            "The place of input(%s) is not consistent with the "
+            "The place of output(%s) is not consistent with the "
            "place of current op(%s).",
            out_var_handle->Name(), Name());
        out_var_handle->SetGenerateEvent(events_.at(dev_id));

paddle/fluid/framework/ir/multi_batch_merge_pass.cc

Lines changed: 2 additions & 1 deletion
@@ -84,7 +84,8 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
 
   // 1. record op nodes of different roles
   for (auto node : nodes) {
-    if (node->IsVar()) continue;
+    if (!node->IsOp()) continue;
+    PADDLE_ENFORCE(node->Op(), "must find opdesc");
     int op_role = boost::get<int>(node->Op()->GetAttr(
         framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
     if ((op_role == static_cast<int>(framework::OpRole::kForward)) ||

paddle/fluid/framework/parallel_executor.cc

Lines changed: 2 additions & 5 deletions
@@ -19,17 +19,14 @@ limitations under the License. */
 #include <tuple>
 #include <utility>
 #include <vector>
-#include "paddle/fluid/framework/ir/graph_helper.h"
-
-#include "paddle/fluid/framework/ir/graph.h"
-
-#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/profiler.h"
 
 #ifdef WITH_GPERFTOOLS

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 1 addition & 2 deletions
@@ -139,8 +139,7 @@ def run_trainer(self, args):
         pass_builder = None
         if args.batch_merge_repeat > 1:
             pass_builder = build_stra._finalize_strategy_and_create_passes()
-            mypass = pass_builder.insert_pass(
-                len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
+            mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass")
             mypass.set("num_repeats", args.batch_merge_repeat)
 
         if args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer":
