Skip to content

Commit ff599b9

Browse files
committed
use Reduce and Broadcast
1 parent 0441c2c commit ff599b9

File tree

2 files changed

+13
-59
lines changed

2 files changed

+13
-59
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 10 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
111111
for (auto *var : program.Block(0).AllVars()) {
112112
var_types[var->Name()] = var->GetType();
113113
}
114+
114115
auto graph = new SSAGraph();
115116
SSAGraph &result = *graph;
116117
std::unordered_set<std::string> og_has_been_broadcast;
@@ -120,13 +121,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
120121
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
121122
places_.size());
122123

123-
size_t cur_dev_id = 0;
124-
std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
125-
std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
126-
127-
sparse_var_name_on_devices.resize(places_.size());
128-
bcast_sparse_var_name_set.resize(places_.size());
129-
130124
// Find "send" op first for split is in front of send.
131125
OpDesc *send_op = GetSendOpDesc(program);
132126

@@ -145,27 +139,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
145139
}
146140
is_forwarding = false;
147141
} else {
148-
int op_dev_id = GetOpDeviceID(sparse_var_name_on_devices, *op);
149-
if (op_dev_id == -1) { // var on all device
150-
CreateComputationalOps(&result, *op, places_.size());
151-
} else {
152-
CreateComputationalOp(&result, *op, op_dev_id);
153-
for (auto &var_name : op->OutputArgumentNames()) {
154-
sparse_var_name_on_devices[op_dev_id].emplace(var_name);
155-
}
156-
}
157-
142+
CreateComputationalOps(&result, *op, places_.size());
158143
if (!is_forwarding && places_.size() > 1) {
159144
// Currently, we assume that once gradient is generated, it can be
160145
// broadcast, and each gradient is only broadcast once.
161146
for (auto &og : op->OutputArgumentNames()) {
162147
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
163148
if (IsSparseGradient(var_types, og)) {
164-
CreateReduceOp(&result, cur_dev_id, og);
165-
sparse_var_name_on_devices[cur_dev_id].emplace(og);
166-
bcast_sparse_var_name_set[cur_dev_id].emplace(
167-
og.substr(0, og.size() - strlen(kGradVarSuffix)));
168-
cur_dev_id = (cur_dev_id + 1) % places_.size();
149+
CreateReduceOp(&result, og, 0);
150+
CreateBroadcastOp(&result, og, 0);
169151
} else {
170152
InsertNCCLAllReduceOp(&result, og);
171153
}
@@ -175,14 +157,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
175157
}
176158
}
177159

178-
// Insert BCast Ops
179-
for (size_t dev_id = 0; dev_id < bcast_sparse_var_name_set.size(); ++dev_id) {
180-
auto &to_bcast_set = bcast_sparse_var_name_set[dev_id];
181-
for (auto &bcast_name : to_bcast_set) {
182-
CreateBroadcastOp(&result, bcast_name, dev_id);
183-
}
184-
}
185-
186160
/*
187161
Dependency graph has been constructed. However, there are still data
188162
hazards that need to be handled.
@@ -213,38 +187,21 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
213187
return false;
214188
}
215189

216-
int MultiDevSSAGraphBuilder::GetOpDeviceID(
217-
const std::vector<std::unordered_set<std::string>>
218-
&sparse_var_name_on_devices,
219-
const OpDesc &op) const {
220-
int var_dev_id = -1;
221-
for (auto &var_name : op.InputArgumentNames()) {
222-
if (var_dev_id != -1) break;
223-
for (size_t i = 0; i < sparse_var_name_on_devices.size(); ++i) {
224-
if (sparse_var_name_on_devices[i].count(var_name)) {
225-
var_dev_id = static_cast<int>(i);
226-
break;
227-
}
228-
}
229-
}
230-
return var_dev_id;
231-
}
232-
233190
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
234191
const std::string &p_name,
235-
size_t dev_id) const {
192+
size_t src_dev_id) const {
236193
#ifdef PADDLE_WITH_CUDA
237194
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
238195
#else
239196
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
240197
#endif
241198

242199
result->ops_.emplace_back(op_handle);
243-
auto *in = result->vars_.at(dev_id).at(p_name).back().get();
200+
auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
244201
op_handle->AddInput(in);
245202

246203
for (size_t i = 0; i < places_.size(); ++i) {
247-
auto &vars = result->vars_.at(dev_id).at(p_name);
204+
auto &vars = result->vars_.at(i).at(p_name);
248205
auto &p = places_[i];
249206
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
250207
vars.emplace_back(out_var);
@@ -345,8 +302,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
345302
}
346303
}
347304

348-
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(
349-
SSAGraph *result, int dst_dev_id, const std::string &og) const {
305+
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
306+
const std::string &og,
307+
int dst_dev_id) const {
350308
#ifdef PADDLE_WITH_CUDA
351309
result->ops_.emplace_back(
352310
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
7575
size_t num_places) const;
7676

7777
void CreateScaleLossGradOp(SSAGraph *result) const;
78-
VarHandle *CreateReduceOp(SSAGraph *result, int dst_dev_id,
79-
const std::string &og) const;
78+
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
79+
int dst_dev_id) const;
8080
void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
8181
int dev_id) const;
8282

@@ -87,11 +87,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
8787
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
8888

8989
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
90-
size_t dev_id) const;
91-
92-
int GetOpDeviceID(
93-
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
94-
const OpDesc &op) const;
90+
size_t src_dev_id) const;
9591

9692
/**
9793
* Get send op in the global block of program.

0 commit comments

Comments
 (0)