Skip to content

Commit e4de957

Browse files
committed
code refine
1 parent 3301d44 commit e4de957

File tree

3 files changed

+45
-33
lines changed

3 files changed

+45
-33
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,34 @@ cc_library(var_handle SRCS var_handle.cc DEPS place)
22
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
33
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
44
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
5-
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
6-
dynload_cuda)
75
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
86
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
97

108
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
119
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
1210

1311
if(WITH_GPU)
12+
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
13+
dynload_cuda)
1414
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
15+
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim dynload_cuda)
1516
else()
1617
set(multi_devices_graph_builder_deps)
18+
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim)
1719
endif()
1820
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
19-
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
21+
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
2022

2123
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
2224
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
2325
simple_threadpool device_context)
2426

2527
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
2628
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
27-
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim)
2829

2930
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
3031
device_context broadcast_op_handle)
3132
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
3233
device_context gather_op_handle)
3334
cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
34-
device_context reduce_op_handle)
35+
device_context reduce_op_handle )

paddle/fluid/framework/details/reduce_op_handle.cc

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,16 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/details/reduce_op_handle.h"
16-
#include "paddle/fluid/framework/details/gather_op_handle.h"
1716
#include "paddle/fluid/framework/details/reduce_and_gather.h"
18-
#include "paddle/fluid/platform/nccl_helper.h"
1917

2018
namespace paddle {
2119
namespace framework {
2220
namespace details {
2321

24-
std::vector<VarHandle *> GetValidVarHandle(
25-
const std::vector<VarHandleBase *> &inputs) {
26-
std::vector<VarHandle *> in_var_handles;
27-
for (auto *in : inputs) {
28-
auto *in_handle = dynamic_cast<VarHandle *>(in);
29-
if (in_handle) {
30-
in_var_handles.push_back(in_handle);
31-
}
32-
}
33-
return in_var_handles;
34-
}
35-
3622
void ReduceOpHandle::RunImpl() {
3723
// The inputs and outputs may contain dummy vars; keep only the VarHandle entries.
38-
std::vector<VarHandle *> in_var_handles = GetValidVarHandle(inputs_);
39-
std::vector<VarHandle *> out_var_handles = GetValidVarHandle(outputs_);
24+
std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
25+
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
4026

4127
PADDLE_ENFORCE_EQ(
4228
in_var_handles.size(), places_.size(),
@@ -45,15 +31,10 @@ void ReduceOpHandle::RunImpl() {
4531
"The number of output should be one.");
4632

4733
// Wait for the inputs to be ready; this Wait is an asynchronous operation.
48-
if (in_var_handles[0]->generated_op_) {
49-
for (auto *in : in_var_handles) {
50-
auto &in_p = in->place_;
51-
in_var_handles[0]->generated_op_->Wait(dev_ctxes_[in_p]);
52-
}
53-
}
34+
WaitEvents(in_var_handles);
5435

5536
// check in the same place
56-
auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]);
37+
auto in_0_handle = in_var_handles[0];
5738
auto pre_place = in_0_handle->place_;
5839

5940
std::vector<platform::Place> in_places;
@@ -120,6 +101,7 @@ void ReduceOpHandle::RunImpl() {
120101
for (size_t i = 0; i < local_scopes_.size(); ++i) {
121102
auto &p = in_places[i];
122103
auto &lod_tensor = lod_tensors[i];
104+
123105
int dev_id = boost::get<platform::CUDAPlace>(p).device;
124106
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
125107
auto stream = nccl_ctx.stream();
@@ -139,18 +121,41 @@ void ReduceOpHandle::RunImpl() {
139121
});
140122
}
141123

142-
platform::NCCLGroupGuard guard;
143-
for (auto &call : all_reduce_calls) {
144-
call();
145-
}
124+
this->RunAndRecordEvent([&] {
125+
platform::NCCLGroupGuard guard;
126+
for (auto &call : all_reduce_calls) {
127+
call();
128+
}
129+
});
146130
#else
147131
PADDLE_THROW("CUDA is not support.");
148132
#endif
149133
} else {
150-
PADDLE_THROW("Error");
134+
PADDLE_THROW("Place should be CPUPlace or CUDAPlace.");
151135
}
152136
}
153137
}
138+
139+
// Waits on the op that generated the first input handle, once per input.
//
// If in_var_handles[0]->generated_op_ is non-null, its Wait() is called for
// every handle in in_var_handles, each time with the device context looked up
// in dev_ctxes_ by that handle's place_. If it is null, nothing is done.
//
// NOTE(review): the op waited on is always in_var_handles[0]->generated_op_,
// not each input's own generated_op_ -- confirm this is intended.
// NOTE(review): in_var_handles[0] is accessed without an emptiness check;
// presumably callers guarantee a non-empty vector -- verify against RunImpl.
void ReduceOpHandle::WaitEvents(
140+
const std::vector<VarHandle *> &in_var_handles) {
141+
if (in_var_handles[0]->generated_op_) {
142+
for (auto *in : in_var_handles) {
143+
in_var_handles[0]->generated_op_->Wait(dev_ctxes_[in->place_]);
144+
}
145+
}
146+
}
147+
148+
// Returns the entries of `inputs` that are VarHandle instances.
//
// Each element is tested with dynamic_cast<VarHandle *>; elements for which
// the cast yields null are skipped. The relative order of the retained
// handles matches their order in `inputs`.
std::vector<VarHandle *> ReduceOpHandle::GetValidVarHandles(
149+
const std::vector<VarHandleBase *> &inputs) {
150+
std::vector<VarHandle *> in_var_handles;
151+
for (auto *in : inputs) {
152+
auto *in_handle = dynamic_cast<VarHandle *>(in);
153+
if (in_handle) {
154+
in_var_handles.push_back(in_handle);
155+
}
156+
}
157+
return in_var_handles;
158+
}
154159
std::string ReduceOpHandle::Name() const { return "reduce"; }
155160
} // namespace details
156161
} // namespace framework

paddle/fluid/framework/details/reduce_op_handle.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
#include "paddle/fluid/framework/scope.h"
2424
#include "paddle/fluid/framework/selected_rows.h"
2525
#include "paddle/fluid/platform/device_context.h"
26+
#ifdef PADDLE_WITH_CUDA
2627
#include "paddle/fluid/platform/nccl_helper.h"
28+
#endif
2729

2830
namespace paddle {
2931
namespace framework {
@@ -57,6 +59,10 @@ struct ReduceOpHandle : public OpHandleBase {
5759

5860
protected:
5961
void RunImpl() override;
62+
std::vector<VarHandle *> GetValidVarHandles(
63+
const std::vector<VarHandleBase *> &inputs);
64+
65+
void WaitEvents(const std::vector<VarHandle *> &in_var_handles);
6066
};
6167

6268
} // namespace details

0 commit comments

Comments
 (0)