@@ -31,6 +31,7 @@ static constexpr char kParallelScopes[] = "parallel_scopes";
31
31
static constexpr char kParallelBlock [] = " sub_block" ;
32
32
33
33
using LoDTensor = framework::LoDTensor;
34
+ using SelectedRows = framework::SelectedRows;
34
35
35
36
static void SplitTensorAndMoveTensorToScopes (
36
37
const framework::Scope &scope, std::vector<framework::Scope *> *sub_scopes,
@@ -64,6 +65,30 @@ static void SplitTensorAndMoveTensorToScopes(
64
65
}
65
66
}
66
67
68
+ inline void CopyOrShare (const framework::Variable &src,
69
+ const platform::Place &dst_place,
70
+ framework::Variable *dst) {
71
+ if (src.IsType <LoDTensor>()) {
72
+ if (src.Get <LoDTensor>().place () == dst_place) {
73
+ dst->GetMutable <LoDTensor>()->ShareDataWith (src.Get <LoDTensor>());
74
+ } else {
75
+ Copy (src.Get <LoDTensor>(), dst_place, dst->GetMutable <LoDTensor>());
76
+ }
77
+ } else if (src.IsType <SelectedRows>()) {
78
+ auto &src_sr = src.Get <SelectedRows>();
79
+ auto *dst_sr = dst->GetMutable <SelectedRows>();
80
+ dst_sr->set_rows (src_sr.rows ());
81
+ dst_sr->set_height (src_sr.height ());
82
+ if (src_sr.value ().place () == dst_place) {
83
+ dst_sr->mutable_value ()->ShareDataWith (src_sr.value ());
84
+ } else {
85
+ Copy (src_sr.value (), dst_place, dst_sr->mutable_value ());
86
+ }
87
+ } else {
88
+ PADDLE_THROW (" Expect LoDTensor/SelectedRows, get %s" , src.Type ().name ());
89
+ }
90
+ }
91
+
67
92
void WaitOnPlace (const platform::Place place) {
68
93
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
69
94
auto &dev_ctx = *pool.Get (place);
@@ -210,30 +235,30 @@ class ParallelDoGradOp : public framework::OperatorBase {
210
235
}
211
236
WaitOnPlaces (places);
212
237
213
- // merge grad
238
+ AccumulateGrad (scope, place, sub_scopes, places);
239
+ }
240
+
241
+ void AccumulateGrad (const framework::Scope &scope,
242
+ const platform::Place &place,
243
+ const std::vector<framework::Scope *> &sub_scopes,
244
+ const platform::PlaceList &places) const {
214
245
for (auto &s : Outputs (framework::GradVarName (kParameters ))) {
215
- auto &result = sub_scopes[0 ]->FindVar (s)->Get <LoDTensor>();
216
246
std::string tmp_name;
217
- auto *tmp = sub_scopes[0 ]->Var (&tmp_name)-> GetMutable <LoDTensor>() ;
247
+ auto *tmp = sub_scopes[0 ]->Var (&tmp_name);
218
248
219
249
for (size_t i = 1 ; i < sub_scopes.size (); ++i) {
220
- auto &tensor_to_merge = sub_scopes[i]->FindVar (s)->Get <LoDTensor>();
221
- if (!(places[i] == places[0 ])) {
222
- framework::Copy (tensor_to_merge, places[0 ], tmp);
223
- WaitOnPlace (places[0 ]);
224
- } else {
225
- tmp->ShareDataWith (tensor_to_merge);
226
- }
250
+ CopyOrShare (*sub_scopes[i]->FindVar (s), places[0 ], tmp);
251
+ WaitOnPlace (places[0 ]);
227
252
228
253
auto sum_op = framework::OpRegistry::CreateOp (
229
254
" sum" , {{" X" , {s, tmp_name}}}, {{" Out" , {s}}},
230
255
framework::AttributeMap{});
256
+ VLOG (3 ) << sum_op->DebugStringEx (sub_scopes[0 ]);
231
257
sum_op->Run (*sub_scopes[0 ], places[0 ]);
232
258
WaitOnPlace (places[0 ]);
233
259
}
234
260
235
- VLOG (3 ) << result;
236
- framework::Copy (result, place, scope.FindVar (s)->GetMutable <LoDTensor>());
261
+ CopyOrShare (*sub_scopes[0 ]->FindVar (s), place, scope.FindVar (s));
237
262
}
238
263
WaitOnPlaces (places);
239
264
}
@@ -289,7 +314,7 @@ class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
289
314
290
315
PADDLE_ENFORCE (ctx->HasInputs (kParameters ));
291
316
PADDLE_ENFORCE (ctx->HasOutputs (framework::GradVarName (kParameters )));
292
- PADDLE_ENFORCE (ctx->HasInput (kInputs ));
317
+ PADDLE_ENFORCE (ctx->HasInputs (kInputs ));
293
318
294
319
for (auto &s : output) {
295
320
PADDLE_ENFORCE (ctx->HasInputs (s));
0 commit comments