@@ -107,6 +107,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
107
107
108
108
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build (
109
109
const ProgramDesc &program) const {
110
+ std::unordered_map<std::string, proto::VarType::Type> var_types;
111
+ for (auto *var : program.Block (0 ).AllVars ()) {
112
+ var_types[var->Name ()] = var->GetType ();
113
+ }
110
114
auto graph = new SSAGraph ();
111
115
SSAGraph &result = *graph;
112
116
std::unordered_set<std::string> og_has_been_broadcast;
@@ -116,7 +120,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
116
120
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
117
121
places_.size ());
118
122
119
- size_t cur_update_sparse_gp_dev_id = 0 ;
123
+ size_t cur_dev_id = 0 ;
120
124
std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
121
125
std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
122
126
@@ -156,14 +160,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
156
160
// broadcast, and each gradient is only broadcast once.
157
161
for (auto &og : op->OutputArgumentNames ()) {
158
162
if (IsParameterGradientOnce (og, &og_has_been_broadcast)) {
159
- if (IsSparseGradient (og)) {
160
- CreateReduceOp (&result, cur_update_sparse_gp_dev_id, og);
161
- sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace (
162
- og);
163
- bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace (
163
+ if (IsSparseGradient (var_types, og)) {
164
+ CreateReduceOp (&result, cur_dev_id, og);
165
+ sparse_var_name_on_devices[cur_dev_id].emplace (og);
166
+ bcast_sparse_var_name_set[cur_dev_id].emplace (
164
167
og.substr (0 , og.size () - strlen (kGradVarSuffix )));
165
- cur_update_sparse_gp_dev_id =
166
- (cur_update_sparse_gp_dev_id + 1 ) % places_.size ();
168
+ cur_dev_id = (cur_dev_id + 1 ) % places_.size ();
167
169
} else {
168
170
InsertNCCLAllReduceOp (&result, og);
169
171
}
@@ -201,10 +203,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
201
203
return std::unique_ptr<SSAGraph>(graph);
202
204
}
203
205
204
- bool MultiDevSSAGraphBuilder::IsSparseGradient (const std::string &og) const {
205
- auto og_var = local_scopes_[0 ]->FindVar (og);
206
- PADDLE_ENFORCE_NOT_NULL (og_var);
207
- return og_var->IsType <SelectedRows>();
206
+ bool MultiDevSSAGraphBuilder::IsSparseGradient (
207
+ const std::unordered_map<std::string, proto::VarType::Type> &var_types,
208
+ const std::string &og) const {
209
+ PADDLE_ENFORCE (var_types.count (og) != 0 );
210
+ if (var_types.at (og) == proto::VarType::SELECTED_ROWS) {
211
+ return true ;
212
+ }
213
+ return false ;
208
214
}
209
215
210
216
int MultiDevSSAGraphBuilder::GetOpDeviceID (
0 commit comments