11
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
- #include " paddle/fluid/framework/details/multi_devices_graph_builder.h "
14
+ #include < algorithm >
15
15
#include < fstream>
16
+ #include < string>
16
17
#include < utility>
18
+ #include < vector>
19
+
17
20
#include " paddle/fluid/framework/details/broadcast_op_handle.h"
18
21
#include " paddle/fluid/framework/details/computation_op_handle.h"
22
+ #include " paddle/fluid/framework/details/multi_devices_graph_builder.h"
19
23
#include " paddle/fluid/framework/details/reduce_op_handle.h"
20
24
#include " paddle/fluid/framework/details/rpc_op_handle.h"
21
25
#include " paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
26
30
#include " paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
27
31
#endif
28
32
29
- #include < string>
30
- #include < vector>
31
-
32
33
DEFINE_string (ssa_graph_path, " /tmp/ssa_graph.dot" ,
33
34
" the ssa graph path only print with GLOG_v=10,"
34
35
" default /tmp/graph.dot" );
@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
148
149
149
150
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build (
150
151
const ProgramDesc &program) const {
151
- std::unordered_map<std::string, proto::VarType::Type> var_types ;
152
+ std::unordered_map<std::string, VarDesc *> all_vars ;
152
153
for (auto *var : program.Block (0 ).AllVars ()) {
153
- var_types [var->Name ()] = var-> GetType () ;
154
+ all_vars [var->Name ()] = var;
154
155
}
155
156
156
157
auto graph = new SSAGraph ();
@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
167
168
auto send_vars = FindDistTrainSendVars (program);
168
169
auto recv_vars = FindDistTrainRecvVars (program);
169
170
170
- size_t cur_device_id = 0 ;
171
171
std::vector<std::unordered_set<std::string>> var_name_on_devices;
172
172
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
173
173
var_name_on_devices.resize (places_.size ());
174
174
bcast_var_name_set.resize (places_.size ());
175
175
176
+ size_t cur_device_id = 0 ;
177
+ std::vector<int64_t > balance_grads (places_.size (), 0 );
178
+
179
+ auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
180
+ auto var_desc = all_vars.at (g_name);
181
+ PADDLE_ENFORCE_NOT_NULL (var_desc);
182
+ auto dim = framework::make_ddim (var_desc->GetShape ());
183
+ int64_t numel = framework::product (dim);
184
+ PADDLE_ENFORCE_GE (numel, 0 );
185
+ auto smallest =
186
+ std::min_element (std::begin (balance_grads), std::end (balance_grads));
187
+ size_t dev_id =
188
+ static_cast <size_t >(std::distance (std::begin (balance_grads), smallest));
189
+ balance_grads[dev_id] += numel;
190
+ return dev_id;
191
+ };
192
+
176
193
bool is_forwarding = true ;
177
194
for (auto *op : program.Block (0 ).AllOps ()) {
178
195
if (boost::get<int >(
@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
220
237
221
238
switch (strategy_.reduce_ ) {
222
239
case BuildStrategy::ReduceStrategy::kReduce :
240
+ cur_device_id = get_appropriate_dev (g_name);
223
241
CreateReduceOp (&result, g_name, cur_device_id);
224
242
var_name_on_devices[cur_device_id].emplace (g_name);
225
243
bcast_var_name_set[cur_device_id].emplace (p_name);
226
- cur_device_id = (cur_device_id + 1 ) % places_.size ();
227
244
break ;
228
245
case BuildStrategy::ReduceStrategy::kAllReduce :
229
- if (IsSparseGradient (var_types , g_name)) {
246
+ if (IsSparseGradient (all_vars , g_name)) {
230
247
CreateReduceOp (&result, g_name, 0 );
231
248
CreateBroadcastOp (&result, g_name, 0 );
232
249
} else {
@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
269
286
}
270
287
271
288
bool MultiDevSSAGraphBuilder::IsSparseGradient (
272
- const std::unordered_map<std::string, proto::VarType::Type > &var_types ,
289
+ const std::unordered_map<std::string, VarDesc * > &all_vars ,
273
290
const std::string &og) const {
274
- PADDLE_ENFORCE (var_types .count (og) != 0 );
275
- if (var_types .at (og) == proto::VarType::SELECTED_ROWS) {
291
+ PADDLE_ENFORCE (all_vars .count (og) != 0 );
292
+ if (all_vars .at (og)-> GetType ( ) == proto::VarType::SELECTED_ROWS) {
276
293
return true ;
277
294
}
278
295
return false ;
0 commit comments