@@ -200,6 +200,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
200
200
BuildStrategy::GradientScaleStrategy::kCustomized ) {
201
201
CreateScaleLossGradOp (&result);
202
202
}
203
+ // This assumes the backward generating code will ensure IsScaleLossOp
204
+ // is true only for the op that scale the final scalar loss.
205
+ // It also assumes backward op will always follow the forward op in
206
+ // the block.
203
207
is_forwarding = false ;
204
208
} else {
205
209
int op_dev_id = GetOpDeviceID (*op);
@@ -244,6 +248,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
244
248
InsertAllReduceOp (&result, g_name);
245
249
}
246
250
break ;
251
+ default :
252
+ LOG (FATAL) << " Unknown reduce strategy " ;
253
+ break ;
247
254
}
248
255
}
249
256
} catch (boost::bad_get e) {
@@ -262,7 +269,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
262
269
}
263
270
/*
264
271
Dependency graph has been constructed. However, there are still data
265
- harzaeds need to be handled.
272
+ hazards need to be handled.
266
273
*/
267
274
PolishGraphToSupportDataHazards (&result);
268
275
@@ -447,6 +454,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
447
454
return var;
448
455
}
449
456
457
+ // Find the first occurence of `prev_op_name` and make current `op` depend
458
+ // on it.
450
459
void MultiDevSSAGraphBuilder::ConnectOp (SSAGraph *result, OpHandleBase *op,
451
460
const std::string &prev_op_name) const {
452
461
for (auto &prev_op : result->ops_ ) {
@@ -490,6 +499,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
490
499
}
491
500
}
492
501
502
+ // Create RPC related op handles that connects its in ops and out ops.
493
503
void MultiDevSSAGraphBuilder::CreateRPCOp (SSAGraph *result,
494
504
const OpDesc &op) const {
495
505
int op_dev_id = -1 ;
0 commit comments