@@ -55,7 +55,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
55
55
const ProgramDesc &program) const {
56
56
auto graph = new SSAGraph ();
57
57
SSAGraph &result = *graph;
58
- std::unordered_set<std::string> og_has_bc ;
58
+ std::unordered_set<std::string> og_has_been_broadcast ;
59
59
result.vars_ .resize (places_.size ());
60
60
61
61
bool is_forwarding = true ;
@@ -123,11 +123,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
123
123
124
124
if (!is_forwarding) {
125
125
auto var_names = op->OutputArgumentNames ();
126
+ // Currently, we assume that once gradient is generated, it can be
127
+ // broadcast, and each gradient is only broadcast once. But there are no
128
+ // other cases, for example, we need to adjust the gradient according to
129
+ // the input when we get the gradient, which is not considered at present.
126
130
for (auto &og : var_names) {
127
131
if (grad_names_.count (og) != 0 &&
128
- og_has_bc .count (og) == 0 ) { // is param grad
129
- // Insert NCCL AllReduce Op
130
- og_has_bc .insert (og);
132
+ og_has_been_broadcast .count (og) == 0 ) { // is param grad
133
+ // Insert NCCL AllReduce Op
134
+ og_has_been_broadcast .insert (og);
131
135
#ifdef PADDLE_WITH_CUDA
132
136
result.ops_ .emplace_back (
133
137
new NCCLAllReduceOpHandle (local_scopes_, places_, *nccl_ctxs_));
0 commit comments