@@ -89,101 +89,25 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
89
89
90
90
bool is_forwarding = true ;
91
91
for (auto *op : program.Block (0 ).AllOps ()) {
92
- bool change_forward = false ;
93
- if (!is_forwarding) {
94
- // FIXME(yy): Do not hard code like this
95
- if (op->OutputArgumentNames ().size () == 1 &&
96
- op->OutputArgumentNames ()[0 ] == GradVarName (loss_var_name_)) {
97
- continue ; // Drop fill 1. for backward coeff;
98
- }
99
- }
100
-
101
- // append send op if program is distributed trainer main program.
102
- // always use the first device
103
- if (!is_forwarding && op->Type () == " send" ) {
104
- auto &p = places_[0 ];
105
- auto *s = local_scopes_[0 ];
106
- // FIXME(wuyi): send op always copy from GPU 0
107
- result.ops_ .emplace_back (new SendOpHandle (*op, s, p));
108
- // Create inputs for output on original place and no ssa output
109
- // is created for send op.
110
- CreateOpHandleIOs (&result, *op, p, 0 );
111
- continue ;
112
- }
113
-
114
- for (size_t i = 0 ; i < places_.size (); ++i) {
115
- auto &p = places_[i];
116
- auto *s = local_scopes_[i];
117
-
118
- result.ops_ .emplace_back (new ComputationOpHandle (*op, s, p));
119
- auto *op_handle = result.ops_ .back ().get ();
120
- CreateOpHandleIOs (&result, *op, p, i);
121
-
122
- auto var_names = op->OutputArgumentNames ();
123
-
124
- if (is_forwarding) {
125
- if (var_names.size () == 1 && var_names[0 ] == loss_var_name_) {
126
- // Insert ScaleCost OpHandle
127
- #ifdef PADDLE_WITH_CUDA
128
- auto *communication_dev_ctx = nccl_ctxs_->DevCtx (p);
129
- #else
130
- auto *communication_dev_ctx =
131
- platform::DeviceContextPool::Instance ().Get (platform::CPUPlace ());
132
- #endif
133
-
134
- op_handle = new ScaleLossGradOpHandle (local_scopes_.size (), s, p,
135
- communication_dev_ctx);
136
- result.ops_ .emplace_back (op_handle);
137
-
138
- // FIXME: Currently ScaleLossGradOp only use device_count as scale
139
- // factor. So it does not depend on any other operators.
140
- // VarHandle *loss = GetVarHandle(loss_var_name, place);
141
- // loss->pending_ops_.emplace_back(op_handle);
142
- // op_handle->inputs_.emplace_back(loss);
143
-
144
- CreateOpOutput (&result, op_handle, GradVarName (loss_var_name_), p, i);
145
- change_forward = true ;
146
- }
147
- }
148
- }
149
-
150
- if (change_forward) {
92
+ if (op->Type () == " send" ) {
93
+ // append send op if program is distributed trainer main program.
94
+ // always use the first device
95
+ CreateSendOp (&result, *op);
96
+ } else if (IsScaleLossOp (*op)) {
97
+ CreateScaleLossGradOp (&result);
151
98
is_forwarding = false ;
152
- }
153
-
154
- if (!is_forwarding) {
155
- auto var_names = op->OutputArgumentNames ();
156
- // Currently, we assume that once gradient is generated, it can be
157
- // broadcast, and each gradient is only broadcast once. But there are no
158
- // other cases, for example, we need to adjust the gradient according to
159
- // the input when we get the gradient, which is not considered at present.
160
- for (auto &og : var_names) {
161
- if (grad_names_.count (og) != 0 &&
162
- og_has_been_broadcast.count (og) == 0 ) { // is param grad
163
- // Insert NCCL AllReduce Op
164
- og_has_been_broadcast.insert (og);
165
- #ifdef PADDLE_WITH_CUDA
166
- result.ops_ .emplace_back (
167
- new NCCLAllReduceOpHandle (local_scopes_, places_, *nccl_ctxs_));
168
- auto *op_handle = result.ops_ .back ().get ();
169
-
170
- for (size_t i = 0 ; i < places_.size (); ++i) {
171
- auto &p = places_[i];
172
- auto &vars = result.vars_ [i][og];
173
-
174
- if (vars.empty ()) { // This device has no data. continue.
175
- continue ;
176
- }
177
- auto &prev_grad = vars[vars.size () - 1 ];
178
- op_handle->AddInput (prev_grad.get ());
179
-
180
- auto var = new VarHandle (vars.size () - 1 , i, og, p);
181
- vars.emplace_back (var);
182
- op_handle->AddOutput (var);
99
+ } else {
100
+ CreateComputationalOps (&result, *op);
101
+ if (!is_forwarding) {
102
+ // Currently, we assume that once gradient is generated, it can be
103
+ // broadcast, and each gradient is only broadcast once. But there are no
104
+ // other cases, for example, we need to adjust the gradient according to
105
+ // the input when we get the gradient, which is not considered at
106
+ // present.
107
+ for (auto &og : op->OutputArgumentNames ()) {
108
+ if (IsParameterGradientOnce (og, &og_has_been_broadcast)) {
109
+ InsertNCCLAllReduceOp (&result, og);
183
110
}
184
- #else
185
- PADDLE_ENFORCE (" Not implemented" );
186
- #endif
187
111
}
188
112
}
189
113
}
@@ -207,7 +131,95 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
207
131
}
208
132
209
133
return std::unique_ptr<SSAGraph>(graph);
210
- } // namespace details
134
+ }
135
+
136
+ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp (
137
+ SSAGraph *result, const std::string &og) const {
138
+ #ifdef PADDLE_WITH_CUDA
139
+ result->ops_ .emplace_back (
140
+ new NCCLAllReduceOpHandle (local_scopes_, places_, *nccl_ctxs_));
141
+ auto *op_handle = result->ops_ .back ().get ();
142
+
143
+ for (size_t i = 0 ; i < places_.size (); ++i) {
144
+ auto &p = places_[i];
145
+ auto &vars = result->vars_ [i][og];
146
+ PADDLE_ENFORCE (!vars.empty ());
147
+ auto &prev_grad = vars.back ();
148
+ op_handle->AddInput (prev_grad.get ());
149
+
150
+ auto var = new VarHandle (vars.size () - 1 , i, og, p);
151
+ vars.emplace_back (var);
152
+ op_handle->AddOutput (var);
153
+ }
154
+ #else
155
+ PADDLE_ENFORCE (" Not implemented" );
156
+ #endif
157
+ }
158
+
159
+ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce (
160
+ const std::string &og,
161
+ std::unordered_set<std::string> *og_has_been_broadcast) const {
162
+ bool is_pg_once =
163
+ grad_names_.count (og) != 0 && og_has_been_broadcast->count (og) == 0 ;
164
+ if (is_pg_once) {
165
+ // Insert NCCL AllReduce Op
166
+ og_has_been_broadcast->insert (og);
167
+ }
168
+ return is_pg_once;
169
+ }
170
+
171
+ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp (SSAGraph *result) const {
172
+ for (size_t i = 0 ; i < places_.size (); ++i) {
173
+ // Insert ScaleCost OpHandle
174
+ #ifdef PADDLE_WITH_CUDA
175
+ auto *communication_dev_ctx = nccl_ctxs_->DevCtx (places_[i]);
176
+ #else
177
+ auto *communication_dev_ctx =
178
+ platform::DeviceContextPool::Instance ().Get (platform::CPUPlace ());
179
+ #endif
180
+
181
+ auto *op_handle =
182
+ new ScaleLossGradOpHandle (local_scopes_.size (), local_scopes_[i],
183
+ places_[i], communication_dev_ctx);
184
+ result->ops_ .emplace_back (op_handle);
185
+
186
+ // FIXME: Currently ScaleLossGradOp only use device_count as scale
187
+ // factor. So it does not depend on any other operators.
188
+ // VarHandle *loss = GetVarHandle(loss_var_name, place);
189
+ // loss->pending_ops_.emplace_back(op_handle);
190
+ // op_handle->inputs_.emplace_back(loss);
191
+
192
+ CreateOpOutput (result, op_handle, GradVarName (loss_var_name_), places_[i],
193
+ i);
194
+ }
195
+ }
196
+
197
+ void MultiDevSSAGraphBuilder::CreateComputationalOps (SSAGraph *result,
198
+ const OpDesc &op) const {
199
+ for (size_t scope_idx = 0 ; scope_idx < places_.size (); ++scope_idx) {
200
+ auto p = places_[scope_idx];
201
+ auto s = local_scopes_[scope_idx];
202
+ result->ops_ .emplace_back (new ComputationOpHandle (op, s, p));
203
+ CreateOpHandleIOs (result, op, p, scope_idx);
204
+ }
205
+ }
206
+
207
+ void MultiDevSSAGraphBuilder::CreateSendOp (SSAGraph *result,
208
+ const OpDesc &op) const {
209
+ auto &p = places_[0 ];
210
+ auto *s = local_scopes_[0 ];
211
+ // FIXME(wuyi): send op always copy from GPU 0
212
+ result->ops_ .emplace_back (new SendOpHandle (op, s, p));
213
+ // Create inputs for output on original place and no ssa output
214
+ // is created for send op.
215
+ CreateOpHandleIOs (result, op, p, 0 );
216
+ }
217
+
218
+ bool MultiDevSSAGraphBuilder::IsScaleLossOp (const OpDesc &op) const {
219
+ // FIXME(yy): Do not hard code like this
220
+ return op.OutputArgumentNames ().size () == 1 &&
221
+ op.OutputArgumentNames ()[0 ] == GradVarName (loss_var_name_);
222
+ }
211
223
} // namespace details
212
224
} // namespace framework
213
225
} // namespace paddle
0 commit comments