Skip to content

Commit 9ee1b7b

Browse files
committed
add some comments
1 parent bad4ea1 commit 9ee1b7b

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
431431
CreateReduceOp(&result, g_name, cur_device_id);
432432
graph->Get<ShardedVarDevice>(kShardedVarDevice)
433433
.emplace(g_name, cur_device_id);
434-
bcast_var_name_set[cur_device_id].emplace(p_name);
434+
if (!is_dist_train) {
435+
bcast_var_name_set[cur_device_id].emplace(p_name);
436+
}
435437
break;
436438
case BuildStrategy::ReduceStrategy::kAllReduce:
437439
if (IsSparseGradient(g_name)) {
@@ -461,7 +463,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
461463
if ((use_gpu &&
462464
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
463465
is_dist_train) {
464-
// Insert BCast Ops
466+
// allways broadcast receieved parameters for distributed training
465467
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
466468
auto &to_bcast_set = bcast_var_name_set[dev_id];
467469
for (auto &bcast_name : to_bcast_set) {

0 commit comments

Comments
 (0)