@@ -145,12 +145,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
145145 } else if (IsDistTrainOp (*op, send_op)) {
146146 CreateComputationalOps (&result, *op, 1 );
147147 } else if (IsScaleLossOp (*op)) {
148+ CreateComputationalOps (&result, *op, places_.size ());
148149 // user can customize loss@grad if not use_default_grad_scale_
149150 if (use_default_grad_scale_) {
150151 CreateScaleLossGradOp (&result);
151152 }
152153 is_forwarding = false ;
153154 } else {
155+ if (IsScaleLossGradOp (*op)) continue ;
154156 int op_dev_id = GetOpDeviceID (var_name_on_devices, *op);
155157 if (op_dev_id == -1 ) { // var on all device
156158 CreateComputationalOps (&result, *op, places_.size ());
@@ -399,6 +401,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
399401}
400402
401403bool MultiDevSSAGraphBuilder::IsScaleLossOp (const OpDesc &op) const {
404+ // FIXME(yy): Do not hard code like this
405+ return op.OutputArgumentNames ().size () == 1 &&
406+ (op.OutputArgumentNames ()[0 ]) == loss_var_name_;
407+ }
408+
409+ bool MultiDevSSAGraphBuilder::IsScaleLossGradOp (const OpDesc &op) const {
402410 // FIXME(yy): Do not hard code like this
403411 return op.OutputArgumentNames ().size () == 1 &&
404412 op.OutputArgumentNames ()[0 ] == GradVarName (loss_var_name_);
0 commit comments