@@ -111,6 +111,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
111
111
for (auto *var : program.Block (0 ).AllVars ()) {
112
112
var_types[var->Name ()] = var->GetType ();
113
113
}
114
+
114
115
auto graph = new SSAGraph ();
115
116
SSAGraph &result = *graph;
116
117
std::unordered_set<std::string> og_has_been_broadcast;
@@ -120,13 +121,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
120
121
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
121
122
places_.size ());
122
123
123
- size_t cur_dev_id = 0 ;
124
- std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
125
- std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
126
-
127
- sparse_var_name_on_devices.resize (places_.size ());
128
- bcast_sparse_var_name_set.resize (places_.size ());
129
-
130
124
// Find "send" op first for split is in front of send.
131
125
OpDesc *send_op = GetSendOpDesc (program);
132
126
@@ -145,27 +139,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
145
139
}
146
140
is_forwarding = false ;
147
141
} else {
148
- int op_dev_id = GetOpDeviceID (sparse_var_name_on_devices, *op);
149
- if (op_dev_id == -1 ) { // var on all device
150
- CreateComputationalOps (&result, *op, places_.size ());
151
- } else {
152
- CreateComputationalOp (&result, *op, op_dev_id);
153
- for (auto &var_name : op->OutputArgumentNames ()) {
154
- sparse_var_name_on_devices[op_dev_id].emplace (var_name);
155
- }
156
- }
157
-
142
+ CreateComputationalOps (&result, *op, places_.size ());
158
143
if (!is_forwarding && places_.size () > 1 ) {
159
144
// Currently, we assume that once gradient is generated, it can be
160
145
// broadcast, and each gradient is only broadcast once.
161
146
for (auto &og : op->OutputArgumentNames ()) {
162
147
if (IsParameterGradientOnce (og, &og_has_been_broadcast)) {
163
148
if (IsSparseGradient (var_types, og)) {
164
- CreateReduceOp (&result, cur_dev_id, og);
165
- sparse_var_name_on_devices[cur_dev_id].emplace (og);
166
- bcast_sparse_var_name_set[cur_dev_id].emplace (
167
- og.substr (0 , og.size () - strlen (kGradVarSuffix )));
168
- cur_dev_id = (cur_dev_id + 1 ) % places_.size ();
149
+ CreateReduceOp (&result, og, 0 );
150
+ CreateBroadcastOp (&result, og, 0 );
169
151
} else {
170
152
InsertNCCLAllReduceOp (&result, og);
171
153
}
@@ -175,14 +157,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
175
157
}
176
158
}
177
159
178
- // Insert BCast Ops
179
- for (size_t dev_id = 0 ; dev_id < bcast_sparse_var_name_set.size (); ++dev_id) {
180
- auto &to_bcast_set = bcast_sparse_var_name_set[dev_id];
181
- for (auto &bcast_name : to_bcast_set) {
182
- CreateBroadcastOp (&result, bcast_name, dev_id);
183
- }
184
- }
185
-
186
160
/*
187
161
Dependency graph has been constructed. However, there are still data
188
162
harzaeds need to be handled.
@@ -213,38 +187,21 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
213
187
return false ;
214
188
}
215
189
216
- int MultiDevSSAGraphBuilder::GetOpDeviceID (
217
- const std::vector<std::unordered_set<std::string>>
218
- &sparse_var_name_on_devices,
219
- const OpDesc &op) const {
220
- int var_dev_id = -1 ;
221
- for (auto &var_name : op.InputArgumentNames ()) {
222
- if (var_dev_id != -1 ) break ;
223
- for (size_t i = 0 ; i < sparse_var_name_on_devices.size (); ++i) {
224
- if (sparse_var_name_on_devices[i].count (var_name)) {
225
- var_dev_id = static_cast <int >(i);
226
- break ;
227
- }
228
- }
229
- }
230
- return var_dev_id;
231
- }
232
-
233
190
void MultiDevSSAGraphBuilder::CreateBroadcastOp (SSAGraph *result,
234
191
const std::string &p_name,
235
- size_t dev_id ) const {
192
+ size_t src_dev_id ) const {
236
193
#ifdef PADDLE_WITH_CUDA
237
194
auto *op_handle = new BroadcastOpHandle (local_scopes_, places_, nccl_ctxs_);
238
195
#else
239
196
auto *op_handle = new BroadcastOpHandle (local_scopes_, places_);
240
197
#endif
241
198
242
199
result->ops_ .emplace_back (op_handle);
243
- auto *in = result->vars_ .at (dev_id ).at (p_name).back ().get ();
200
+ auto *in = result->vars_ .at (src_dev_id ).at (p_name).back ().get ();
244
201
op_handle->AddInput (in);
245
202
246
203
for (size_t i = 0 ; i < places_.size (); ++i) {
247
- auto &vars = result->vars_ .at (dev_id ).at (p_name);
204
+ auto &vars = result->vars_ .at (i ).at (p_name);
248
205
auto &p = places_[i];
249
206
auto *out_var = new VarHandle (vars.size (), i, p_name, p);
250
207
vars.emplace_back (out_var);
@@ -345,8 +302,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
345
302
}
346
303
}
347
304
348
- VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp (
349
- SSAGraph *result, int dst_dev_id, const std::string &og) const {
305
+ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp (SSAGraph *result,
306
+ const std::string &og,
307
+ int dst_dev_id) const {
350
308
#ifdef PADDLE_WITH_CUDA
351
309
result->ops_ .emplace_back (
352
310
new ReduceOpHandle (local_scopes_, places_, nccl_ctxs_));
0 commit comments