@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
37
37
const std::string &loss_var_name,
38
38
const std::unordered_set<std::string> ¶ms,
39
39
const std::vector<Scope *> &local_scopes,
40
- platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale)
40
+ platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
41
+ bool balance_parameter_opt_between_cards)
41
42
: loss_var_name_(loss_var_name),
42
43
places_(places),
43
44
local_scopes_(local_scopes),
44
- nccl_ctxs_(nccl_ctxs) {
45
+ nccl_ctxs_(nccl_ctxs),
46
+ balance_parameter_opt_between_cards_(
47
+ balance_parameter_opt_between_cards) {
45
48
#else
46
49
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder (
47
50
const std::vector<platform::Place> &places,
48
51
const std::string &loss_var_name,
49
52
const std::unordered_set<std::string> ¶ms,
50
- const std::vector<Scope *> &local_scopes, bool use_default_grad_scale)
53
+ const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
54
+ bool balance_parameter_opt_between_cards)
51
55
: loss_var_name_ (loss_var_name),
52
56
places_ (places),
53
- local_scopes_ (local_scopes) {
57
+ local_scopes_ (local_scopes),
58
+ balance_parameter_opt_between_cards_ (
59
+ balance_parameter_opt_between_cards) {
54
60
#endif
55
61
for (auto &p : params) {
56
62
grad_names_.insert (GradVarName (p));
@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
124
130
// Find the "send" op first, because the split op comes before send.
125
131
OpDesc *send_op = GetSendOpDesc (program);
126
132
133
+ size_t cur_device_id = 0 ;
134
+ std::vector<std::unordered_set<std::string>> var_name_on_devices;
135
+ std::vector<std::unordered_set<std::string>> bcast_var_name_set;
136
+ var_name_on_devices.resize (places_.size ());
137
+ bcast_var_name_set.resize (places_.size ());
138
+
127
139
bool is_forwarding = true ;
128
140
for (auto *op : program.Block (0 ).AllOps ()) {
129
141
if (op->Type () == " send" ) {
@@ -139,24 +151,47 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
139
151
}
140
152
is_forwarding = false ;
141
153
} else {
142
- CreateComputationalOps (&result, *op, places_.size ());
154
+ int op_dev_id = GetOpDeviceID (var_name_on_devices, *op);
155
+ if (op_dev_id == -1 ) { // var on all device
156
+ CreateComputationalOps (&result, *op, places_.size ());
157
+ } else {
158
+ CreateComputationalOp (&result, *op, op_dev_id);
159
+ for (auto &var_name : op->OutputArgumentNames ()) {
160
+ var_name_on_devices[op_dev_id].emplace (var_name);
161
+ }
162
+ }
143
163
if (!is_forwarding && places_.size () > 1 ) {
144
164
// Currently, we assume that once gradient is generated, it can be
145
165
// broadcast, and each gradient is only broadcast once.
146
166
for (auto &og : op->OutputArgumentNames ()) {
147
167
if (IsParameterGradientOnce (og, &og_has_been_broadcast)) {
148
- if (IsSparseGradient (var_types, og)) {
149
- CreateReduceOp (&result, og, 0 );
150
- CreateBroadcastOp (&result, og, 0 );
168
+ if (balance_parameter_opt_between_cards_) {
169
+ CreateReduceOp (&result, og, cur_device_id);
170
+ var_name_on_devices[cur_device_id].emplace (og);
171
+ bcast_var_name_set[cur_device_id].emplace (
172
+ og.substr (0 , og.size () - strlen (kGradVarSuffix )));
173
+ cur_device_id = (cur_device_id + 1 ) % places_.size ();
151
174
} else {
152
- InsertNCCLAllReduceOp (&result, og);
175
+ if (IsSparseGradient (var_types, og)) {
176
+ CreateReduceOp (&result, og, 0 );
177
+ CreateBroadcastOp (&result, og, 0 );
178
+ } else {
179
+ InsertNCCLAllReduceOp (&result, og);
180
+ }
153
181
}
154
182
}
155
183
}
156
184
}
157
185
}
158
186
}
159
187
188
+ // Insert BCast Ops
189
+ for (size_t dev_id = 0 ; dev_id < bcast_var_name_set.size (); ++dev_id) {
190
+ auto &to_bcast_set = bcast_var_name_set[dev_id];
191
+ for (auto &bcast_name : to_bcast_set) {
192
+ CreateBroadcastOp (&result, bcast_name, dev_id);
193
+ }
194
+ }
160
195
/*
161
196
Dependency graph has been constructed. However, there are still data
162
197
hazards that need to be handled.
@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
265
300
return is_pg_once;
266
301
}
267
302
303
+ int MultiDevSSAGraphBuilder::GetOpDeviceID (
304
+ const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
305
+ const OpDesc &op) const {
306
+ if (!balance_parameter_opt_between_cards_) {
307
+ return -1 ;
308
+ }
309
+
310
+ int var_dev_id = -1 ;
311
+ for (auto &var_name : op.InputArgumentNames ()) {
312
+ if (var_dev_id != -1 ) break ;
313
+ for (size_t i = 0 ; i < var_name_on_devices.size (); ++i) {
314
+ if (var_name_on_devices[i].count (var_name)) {
315
+ var_dev_id = static_cast <int >(i);
316
+ break ;
317
+ }
318
+ }
319
+ }
320
+ return var_dev_id;
321
+ }
322
+
268
323
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp (SSAGraph *result) const {
269
324
for (size_t i = 0 ; i < places_.size (); ++i) {
270
325
// Insert ScaleCost OpHandle
0 commit comments