@@ -326,7 +326,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
326
326
ir::Graph &result = *graph;
327
327
328
328
for (auto &node : nodes) {
329
- if (node->NodeType () == ir::Node::Type:: kVariable && node->Var ()) {
329
+ if (node->IsVar () && node->Var ()) {
330
330
all_vars_.emplace (node->Name (), node->Var ());
331
331
}
332
332
}
@@ -583,18 +583,6 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
583
583
}
584
584
}
585
585
586
- bool MultiDevSSAGraphBuilder::IsParameterGradientOnce (
587
- const std::string &og,
588
- std::unordered_set<std::string> *og_has_been_broadcast) const {
589
- bool is_pg_once =
590
- grad_names_.count (og) != 0 && og_has_been_broadcast->count (og) == 0 ;
591
- if (is_pg_once) {
592
- // Insert NCCL AllReduce Op
593
- og_has_been_broadcast->insert (og);
594
- }
595
- return is_pg_once;
596
- }
597
-
598
586
int MultiDevSSAGraphBuilder::GetOpDeviceID (const ir::Graph &graph,
599
587
ir::Node *node) const {
600
588
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce ) {
@@ -688,20 +676,6 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
688
676
return var;
689
677
}
690
678
691
- // Find the first occurence of `prev_op_name` and make current `op` depend
692
- // on it.
693
- void MultiDevSSAGraphBuilder::ConnectOp (ir::Graph *result, OpHandleBase *op,
694
- const std::string &prev_op_name) const {
695
- for (auto &prev_op : result->Get <GraphOps>(kGraphOps )) {
696
- if (prev_op->Name () == prev_op_name) {
697
- auto *dep_var = new DummyVarHandle (result->CreateControlDepVar ());
698
- prev_op->AddOutput (dep_var);
699
- result->Get <GraphDepVars>(kGraphDepVars ).emplace (dep_var);
700
- op->AddInput (dep_var);
701
- }
702
- }
703
- }
704
-
705
679
void MultiDevSSAGraphBuilder::CreateDistTrainOp (ir::Graph *result,
706
680
ir::Node *node) const {
707
681
int op_dev_id = -1 ;
0 commit comments