@@ -55,6 +55,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
55
55
const ProgramDesc &program) const {
56
56
auto graph = new SSAGraph ();
57
57
SSAGraph &result = *graph;
58
+ std::unordered_set<std::string> og_has_been_broadcast;
58
59
result.vars_ .resize (places_.size ());
59
60
60
61
bool is_forwarding = true ;
@@ -122,9 +123,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
122
123
123
124
if (!is_forwarding) {
124
125
auto var_names = op->OutputArgumentNames ();
126
+ // Currently, we assume that once gradient is generated, it can be
127
+ // broadcast, and each gradient is only broadcast once. But there are no
128
+ // other cases, for example, we need to adjust the gradient according to
129
+ // the input when we get the gradient, which is not considered at present.
125
130
for (auto &og : var_names) {
126
- if (grad_names_.count (og) != 0 ) { // is param grad
127
- // Insert NCCL AllReduce Op
131
+ if (grad_names_.count (og) != 0 &&
132
+ og_has_been_broadcast.count (og) == 0 ) { // is param grad
133
+ // Insert NCCL AllReduce Op
134
+ og_has_been_broadcast.insert (og);
128
135
#ifdef PADDLE_WITH_CUDA
129
136
result.ops_ .emplace_back (
130
137
new NCCLAllReduceOpHandle (local_scopes_, places_, *nccl_ctxs_));
0 commit comments