 // limitations under the License.

 #include <algorithm>
-#include <memory>
+#include <map>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include <vector>

-#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
 #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/op_graph_view.h"
-#include "paddle/fluid/framework/details/var_handle.h"
+#include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/op_proto_maker.h"

 namespace paddle {
 namespace framework {
 namespace details {

-VarHandle* GetValidInput(const OpHandleBase* a) {
-  for (auto p : a->Inputs()) {
-    VarHandle* b = dynamic_cast<VarHandle*>(p);
-    if (b) {
-      return b;
+class AllReduceDepsPass : public ir::Pass {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override {
+    std::vector<AllReduceOpHandle*> all_reduce_op_handles =
+        GetSortedAllReduceOps(*graph);
+
+    for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
+      auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+      all_reduce_op_handles[i - 1]->AddOutput(dep_var);
+      all_reduce_op_handles[i]->AddInput(dep_var);
     }
-  }

-  return nullptr;
-}
-
-void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
-  auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
-
-  // get vars order
-  int order = 0;
-  std::unordered_map<std::string, int> vars;
-  // TODO(gongwb): use graph topology sort to find the order of operators.
-  //               Note that must assert topology sort is stable
-  auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
-  for (auto* op_desc : ops) {
-    try {
-      bool is_bk_op =
-          static_cast<bool>(boost::get<int>(op_desc->GetAttr(
-                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
-                            static_cast<int>(OpRole::kBackward));
-      if (!is_bk_op) continue;
-
-      auto backward_vars =
-          boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
-              OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-      PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
-
-      auto outputs = op_desc->Outputs();
-      for (auto& o_it : outputs) {
-        for (auto& v : o_it.second) {  // values
-          vars[v] = order;
-          VLOG(10) << "in all_reduce_deps_pass:" << v;
-        }
-      }
-      order++;
-    } catch (boost::bad_get e) {
+    if (VLOG_IS_ON(10)) {
+      DebugString(*graph, all_reduce_op_handles);
     }
   }

-  std::vector<OpHandleBase*> dist_ops;
-  // get allreduce ops.
-  for (auto& op : graph_ops) {
-    // FIXME(gongwb):add broad cast.
-    if (op->Name() == "all_reduce" || op->Name() == "reduce") {
-      dist_ops.push_back(op);
+  std::vector<AllReduceOpHandle*> GetSortedAllReduceOps(
+      const ir::Graph& graph) const {
+    std::vector<AllReduceOpHandle*> all_reduce_op_handles;
+    std::unordered_map<OpHandleBase*, size_t> pending_ops;
+    std::unordered_set<OpHandleBase*> ready_ops;
+    std::unordered_set<OpHandleBase*> next_ready_ops;
+
+    auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph);
+    size_t num_of_ops = op_handles.size();
+    for (OpHandleBase* op : op_handles) {
+      size_t not_ready_vars = op->NotReadyInputSize();
+      if (not_ready_vars) {
+        pending_ops.insert({op, not_ready_vars});
+      } else {
+        ready_ops.insert(op);
+      }
     }
-  }
-
-  VLOG(10) << "dist_ops size:" << dist_ops.size()
-           << ", outputs size:" << vars.size() << ", ops size:" << ops.size();
-
-  std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
-                                                  OpHandleBase* op2) {
-    VarHandle* i0 = dynamic_cast<VarHandle*>(GetValidInput(op1));
-    VarHandle* i1 = dynamic_cast<VarHandle*>(GetValidInput(op2));
-
-    PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error",
-                   op1->DebugString(), op2->DebugString());

-    auto l_it = vars.find(i0->name());
-    auto r_it = vars.find(i1->name());
-
-    PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(),
-                   "can't find var's name %s and %s in opdesc", i0->name(),
-                   i1->name());
-
-    if (l_it->second < r_it->second) return true;
+    GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
+
+    size_t has_run_ops = ready_ops.size();
+    while (has_run_ops != num_of_ops) {
+      for (auto* op : ready_ops) {
+        for (auto& ready_var : op->Outputs()) {
+          for (auto* pend_op : ready_var->PendingOps()) {
+            auto& deps = --pending_ops[pend_op];
+            if (deps == 0) {
+              next_ready_ops.insert(pend_op);
+            }
+          }
+        }
+      }

-    if (l_it->second == r_it->second) {
-      return i0->name() < i1->name();
+      PADDLE_ENFORCE_NE(next_ready_ops.size(), 0, "There may be a cycle.");
+      ready_ops.clear();
+      std::swap(ready_ops, next_ready_ops);
+      GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
+      has_run_ops += ready_ops.size();
     }
+    return all_reduce_op_handles;
+  }

-    return false;
-  });
-
-  // add dependency.
-  auto& sorted_ops = dist_ops;
-  for (size_t i = 1; i < sorted_ops.size(); ++i) {
-    auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-
-    auto* pre_op = sorted_ops[i - 1];
-    auto* op = sorted_ops[i];
-
-    pre_op->AddOutput(dep_var);
-    op->AddInput(dep_var);
-    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+  void GetSortedAllReduceOps(
+      const std::unordered_set<OpHandleBase*>& ready_ops,
+      std::vector<AllReduceOpHandle*>* all_reduce_op_handles) const {
+    std::vector<AllReduceOpHandle*> current_all_reduce_op_handles;
+    for (auto& op_handle : ready_ops) {
+      auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle*>(op_handle);
+      if (all_reduce_op_handle) {
+        current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
+      }
+    }

-    VLOG(10) << "add all_reduce sequential dependencies between " << pre_op
-             << " and " << op;
+    // NOTE(zcd): For distributed training, it is important to keep the order
+    // of allReduce consistent on every node; otherwise a hang may occur.
+    // Sort current_all_reduce_op_handles by the name of their first input.
+    sort(current_all_reduce_op_handles.begin(),
+         current_all_reduce_op_handles.end(),
+         [](const AllReduceOpHandle* left,
+            const AllReduceOpHandle* right) -> bool {
+           auto left_in_vars = DynamicCast<VarHandle>(left->Inputs());
+           auto right_in_vars = DynamicCast<VarHandle>(right->Inputs());
+           PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
+           PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
+           return left_in_vars[0]->Name() > right_in_vars[0]->Name();
+         });
+
+    all_reduce_op_handles->insert(all_reduce_op_handles->end(),
+                                  current_all_reduce_op_handles.begin(),
+                                  current_all_reduce_op_handles.end());
+  }

-    VLOG(10) << "pre_op:" << pre_op->DebugString()
-             << ", op:" << op->DebugString();
+  void DebugString(
+      const ir::Graph& graph,
+      const std::vector<AllReduceOpHandle*>& all_reduce_op_handles) const {
+    // get vars order
+    std::map<int, std::vector<std::string>> vars =
+        GetSoredGradientsFromStaleProgram(graph);
+    std::stringstream out;
+    size_t grads_of_stale_program = 0;
+    out << "Get Order From kStaleProgramOpDescs: ";
+    for (auto& var : vars) {
+      out << "Order " << var.first << " [";
+      for (auto& var_name : var.second) {
+        out << var_name << ", ";
+        ++grads_of_stale_program;
+      }
+      out << "], ";
+    }
+    VLOG(10) << out.str();
+
+    std::stringstream out2;
+    out2 << "Get Order From Topological order: ";
+    for (auto& op : all_reduce_op_handles) {
+      bool find_valid_input = false;
+      for (auto& in_var : op->Inputs()) {
+        if (dynamic_cast<VarHandle*>(in_var)) {
+          out2 << in_var->Name() << ", ";
+          find_valid_input = true;
+          break;
+        }
+      }
+      PADDLE_ENFORCE(find_valid_input, "Cannot find a valid input.");
+    }
+    VLOG(10) << out2.str();
+    if (grads_of_stale_program != all_reduce_op_handles.size()) {
+      VLOG(10)
+          << "The number of gradients in the stale program and the graph differ.";
+    }
   }
-}

+  std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
+      const ir::Graph& graph) const {
+    std::map<int, std::vector<std::string>> vars;
+    auto ops = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
+    int order = 0;
+    for (auto* op_desc : ops) {
+      try {
+        bool is_bk_op =
+            static_cast<bool>(boost::get<int>(op_desc->GetAttr(
+                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
+                              static_cast<int>(OpRole::kBackward));
+        if (!is_bk_op) continue;
+
+        auto backward_vars =
+            boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
+                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+        if (backward_vars.empty()) continue;
+
+        PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+        for (size_t i = 1; i < backward_vars.size(); i += 2) {
+          vars[order].emplace_back(backward_vars[i]);
+          VLOG(1) << "get parameter and gradient: " << backward_vars[i - 1]
+                  << ", " << backward_vars[i];
+        }
+        order++;
+      } catch (boost::bad_get e) {
+      }
+    }
+    return vars;
+  }
+};
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle