@@ -199,6 +199,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
199
199
BuildStrategy::GradientScaleStrategy::kCustomized ) {
200
200
CreateScaleLossGradOp (&result);
201
201
}
202
+ // This assumes the backward generating code will ensure IsScaleLossOp
203
+ // is true only for the op that scale the final scalar loss.
204
+ // It also assumes backward op will always follow the forward op in
205
+ // the block.
202
206
is_forwarding = false ;
203
207
} else {
204
208
int op_dev_id = GetOpDeviceID (var_name_on_devices, *op);
@@ -243,6 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
243
247
InsertAllReduceOp (&result, g_name);
244
248
}
245
249
break ;
250
+ default :
251
+ LOG (FATAL) << " Unknown reduce strategy " ;
252
+ break ;
246
253
}
247
254
}
248
255
} catch (boost::bad_get e) {
@@ -261,7 +268,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
261
268
}
262
269
/*
263
270
Dependency graph has been constructed. However, there are still data
264
- harzaeds need to be handled.
271
+ hazards need to be handled.
265
272
*/
266
273
PolishGraphToSupportDataHazards (&result);
267
274
@@ -449,6 +456,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
449
456
return var;
450
457
}
451
458
459
+ // Find the first occurence of `prev_op_name` and make current `op` depend
460
+ // on it.
452
461
void MultiDevSSAGraphBuilder::ConnectOp (SSAGraph *result, OpHandleBase *op,
453
462
const std::string &prev_op_name) const {
454
463
for (auto &prev_op : result->ops_ ) {
@@ -469,6 +478,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
469
478
}
470
479
}
471
480
481
+ // Create RPC related op handles that connects its in ops and out ops.
472
482
void MultiDevSSAGraphBuilder::CreateRPCOp (SSAGraph *result,
473
483
const OpDesc &op) const {
474
484
result->ops_ .emplace_back (
0 commit comments