
Commit 1e1b662

update by comment
1 parent b084dfa commit 1e1b662

4 files changed, +5 -23 lines

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 1 addition & 5 deletions
```diff
@@ -46,11 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 #endif
 
 void AllReduceOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
 
   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
```
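
For context, a minimal standalone sketch of the pattern this hunk (and the identical hunk in broadcast_op_handle.cc below) converges on. The types here are illustrative stand-ins, not the actual Paddle classes: the profiling scope is now always constructed from the first device context, which assumes `dev_ctxes_` is non-empty, since `begin()->second` on an empty map would be undefined behavior.

```cpp
#include <map>
#include <string>

// Illustrative stand-ins for platform::DeviceContext / platform::RecordEvent.
struct DeviceContext {};
struct RecordEvent {
  RecordEvent(const std::string &name, DeviceContext *ctx) { /* begin profiling scope */ }
  ~RecordEvent() { /* end profiling scope */ }
};

// Shape of RunImpl after this commit: no nullptr fallback, so the op handle
// is assumed to own at least one device context before RunImpl is called.
void RunImplSketch(const std::string &name,
                   const std::map<int, DeviceContext *> &dev_ctxes) {
  RecordEvent record_event(name, dev_ctxes.begin()->second);
  // ... the rest of the op's work runs inside the profiling scope ...
}
```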

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 1 addition & 5 deletions
```diff
@@ -22,11 +22,7 @@ namespace framework {
 namespace details {
 
 void BroadcastOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
 
   if (places_.size() == 1) return;
 
```

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 0 additions & 6 deletions
```diff
@@ -87,12 +87,6 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
 }
 
 void DataBalanceOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
-
   PADDLE_ENFORCE_GT(places_.size(), 1,
                     "Data balance can only be enabled when the number of "
                     "places to run larger than 1.");
```

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 3 additions & 7 deletions
```diff
@@ -431,10 +431,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           CreateReduceOp(&result, g_name, cur_device_id);
           graph->Get<ShardedVarDevice>(kShardedVarDevice)
               .emplace(g_name, cur_device_id);
-          if (!is_dist_train) {
-            // will send gradients directly when distributed training
-            bcast_var_name_set[cur_device_id].emplace(p_name);
-          }
+          bcast_var_name_set[cur_device_id].emplace(p_name);
           break;
         case BuildStrategy::ReduceStrategy::kAllReduce:
           if (IsSparseGradient(g_name)) {
@@ -461,9 +458,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   use_gpu = nccl_ctxs_ != nullptr;
 #endif
 
-  if ((use_gpu &&
-       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
-      is_dist_train) {
+  if (use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce &&
+      !is_dist_train) {
     // Insert BCast Ops
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
```
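
To make the logic change in this second hunk easier to follow, here is a small self-contained sketch contrasting the old and new gating conditions. The enum and flags are re-declared locally for illustration only, not taken from Paddle headers: the two gates agree whenever `is_dist_train` is false and diverge whenever it is true, where the old gate enabled the broadcast-insertion block and the new gate disables it.

```cpp
#include <cstdio>
#include <initializer_list>

// Local stand-in for BuildStrategy::ReduceStrategy; illustrative only.
enum class ReduceStrategy { kAllReduce, kReduce };

// Old gate: insert broadcast ops for GPU kReduce training, or for any
// distributed training run.
bool InsertBcastOld(bool use_gpu, ReduceStrategy reduce, bool is_dist_train) {
  return (use_gpu && reduce == ReduceStrategy::kReduce) || is_dist_train;
}

// New gate: insert broadcast ops only for non-distributed GPU kReduce training.
bool InsertBcastNew(bool use_gpu, ReduceStrategy reduce, bool is_dist_train) {
  return use_gpu && reduce == ReduceStrategy::kReduce && !is_dist_train;
}

int main() {
  // Enumerate all combinations; the gates only differ when is_dist_train is true.
  for (bool use_gpu : {false, true}) {
    for (ReduceStrategy r : {ReduceStrategy::kAllReduce, ReduceStrategy::kReduce}) {
      for (bool dist : {false, true}) {
        std::printf("gpu=%d kReduce=%d dist=%d -> old=%d new=%d\n", use_gpu,
                    r == ReduceStrategy::kReduce, dist,
                    InsertBcastOld(use_gpu, r, dist),
                    InsertBcastNew(use_gpu, r, dist));
      }
    }
  }
  return 0;
}
```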
