
Commit eb8e14c

Author: chengduo
Merge pull request #10081 from chengduoZH/refine/gather_broadcast
Fix scope of gather and broadcast, and code clean
2 parents acd7309 + 9a4ae4d, commit eb8e14c

10 files changed: +190 -136 lines
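The recurring fix below: BroadcastOpHandle and GatherOpHandle used to look variables up directly in local_scopes_, but during execution each of those scopes only holds a pointer, stored under kLocalExecScopeName, to the scope that actually owns the variables. A minimal sketch of the lookup pattern this commit introduces (taken from the hunks below; kLocalExecScopeName, Scope, and local_scopes_ are the framework's own names):

// Resolve each device's local execution scope once, up front.
// Every entry of local_scopes_ stores a Scope* variable named
// kLocalExecScopeName pointing at the scope used during execution.
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
  var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
// All later lookups then go through var_scopes, e.g.:
// auto *in_var = var_scopes.at(h->scope_idx_)->FindVar(h->name_);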

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 9 additions & 8 deletions
@@ -8,27 +8,28 @@ cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope plac
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
+cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
+
 if(WITH_GPU)
   nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
           dynload_cuda)
   set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
-  nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim dynload_cuda)
+  nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
 else()
   set(multi_devices_graph_builder_deps)
-  cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim)
+  cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
 endif()
+
+cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
+cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
+
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
+        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
 
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)
 
-cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
-
-cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
-cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)
-
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)
 cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 13 additions & 8 deletions
@@ -44,9 +44,15 @@ void BroadcastOpHandle::RunImpl() {
   // &in_place;
   WaitInputVarGenerated(*in_var_handle);
 
-  auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
-                     ->FindVar(in_var_handle->name_);
+  std::vector<const Scope *> var_scopes;
+  for (auto *s : local_scopes_) {
+    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
+  }
+
+  auto *in_var =
+      var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(in_var);
+
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
   for (auto *out : out_var_handles) {
@@ -55,17 +61,16 @@ void BroadcastOpHandle::RunImpl() {
     }
 
     auto &out_p = out->place_;
-    auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
-
+    auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_);
+    PADDLE_ENFORCE_NOT_NULL(out_var);
     PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
                       "Places must be all on CPU or all on CUDA.");
 
     VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var)
-        .Resize(in_tensor.dims())
-        .mutable_data(out_p, in_tensor.type());
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+                                                            in_tensor.type());
 
-    auto dev_ctx = dev_ctxes_[out_p];
+    auto dev_ctx = dev_ctxes_.at(out_p);
     RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
       paddle::framework::TensorCopy(
           in_tensor, out_p, *(dev_ctx),
paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 18 additions & 6 deletions
@@ -30,6 +30,7 @@ const f::DDim kDims = {20, 20};
 struct TestBroadcastOpHandle {
   std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
   std::vector<Scope*> local_scopes_;
+  std::vector<Scope*> param_scopes_;
   Scope g_scope_;
   std::unique_ptr<OpHandleBase> op_handle_;
   std::vector<std::unique_ptr<VarHandleBase>> vars_;
@@ -72,11 +73,17 @@ struct TestBroadcastOpHandle {
   void InitBroadcastOp(size_t input_scope_idx) {
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
-      local_scopes_[j]->Var("out");
+      Scope& local_scope = local_scopes_.back()->NewScope();
+      *local_scopes_.back()
+           ->Var(details::kLocalExecScopeName)
+           ->GetMutable<Scope*>() = &local_scope;
+      local_scope.Var("out");
+      param_scopes_.emplace_back(&local_scope);
     }
-    local_scopes_[input_scope_idx]->Var("input");
+    param_scopes_[input_scope_idx]->Var("input");
 
     op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
+
     auto* in_var_handle =
         new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
@@ -105,7 +112,8 @@ struct TestBroadcastOpHandle {
   }
 
   void TestBroadcastLodTensor(size_t input_scope_idx) {
-    auto in_var = local_scopes_[input_scope_idx]->Var("input");
+    auto in_var = param_scopes_[input_scope_idx]->FindVar("input");
+    PADDLE_ENFORCE_NOT_NULL(in_var);
     auto in_lod_tensor = in_var->GetMutable<f::LoDTensor>();
     in_lod_tensor->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
 
@@ -117,14 +125,16 @@ struct TestBroadcastOpHandle {
     paddle::framework::TensorFromVector<float>(
         send_vector, *(ctxs_[input_scope_idx]), in_lod_tensor);
     in_lod_tensor->set_lod(lod);
+    in_lod_tensor->Resize(kDims);
 
     op_handle_->Run(false);
 
     WaitAll();
 
     p::CPUPlace cpu_place;
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
-      auto out_var = local_scopes_[j]->Var("out");
+      auto out_var = param_scopes_[j]->FindVar("out");
+      PADDLE_ENFORCE_NOT_NULL(out_var);
       auto out_tensor = out_var->Get<f::LoDTensor>();
       PADDLE_ENFORCE_EQ(out_tensor.lod(), lod, "lod is not equal.");
 
@@ -139,7 +149,8 @@ struct TestBroadcastOpHandle {
   }
 
   void TestBroadcastSelectedRows(size_t input_scope_idx) {
-    auto in_var = local_scopes_[input_scope_idx]->Var("input");
+    auto in_var = param_scopes_[input_scope_idx]->FindVar("input");
+    PADDLE_ENFORCE_NOT_NULL(in_var);
     auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
     auto value = in_selected_rows->mutable_value();
     value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
@@ -162,7 +173,8 @@ struct TestBroadcastOpHandle {
 
     p::CPUPlace cpu_place;
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
-      auto out_var = local_scopes_[j]->Var("out");
+      auto out_var = param_scopes_[j]->FindVar("out");
+      PADDLE_ENFORCE_NOT_NULL(out_var);
       auto& out_select_rows = out_var->Get<f::SelectedRows>();
       auto rt = out_select_rows.value();
 
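The fixture now reproduces the executor's two-level scope layout: each scope in local_scopes_ gets a child scope registered under details::kLocalExecScopeName, and the test variables live in that child, tracked via param_scopes_. Condensed from the setup hunk above:

for (size_t j = 0; j < gpu_list_.size(); ++j) {
  local_scopes_.push_back(&(g_scope_.NewScope()));
  // The child scope plays the role of the local execution scope.
  Scope& local_scope = local_scopes_.back()->NewScope();
  *local_scopes_.back()->Var(details::kLocalExecScopeName)
      ->GetMutable<Scope*>() = &local_scope;
  local_scope.Var("out");  // test variables live in the child
  param_scopes_.emplace_back(&local_scope);
}

The checks accordingly switch from Var(), which creates a fresh empty variable on a miss, to FindVar() plus PADDLE_ENFORCE_NOT_NULL, so a lookup in the wrong scope now fails the test instead of silently handing an empty tensor along.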

paddle/fluid/framework/details/gather_op_handle.cc

Lines changed: 10 additions & 5 deletions
@@ -41,14 +41,19 @@ void GatherOpHandle::RunImpl() {
     out_var_handle = out_var_handles.front();
   }
 
+  std::vector<const Scope *> var_scopes;
+  for (auto *s : local_scopes_) {
+    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
+  }
+
   auto in_0_handle = in_var_handles[0];
   auto pre_in_var =
-      local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
-  auto pre_place = in_0_handle->place_;
-
+      var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_);
+  PADDLE_ENFORCE_NOT_NULL(pre_in_var);
   PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
                  "Currently, gather_op only can gather SelectedRows.");
 
+  auto pre_place = in_0_handle->place_;
   PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
                     "The place of input and output should be the same.");
 
@@ -67,7 +72,7 @@ void GatherOpHandle::RunImpl() {
     PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
                       "Places must be all on CPU or all on CUDA.");
     auto *in_var =
-        local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
+        var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
     auto &in_sr = in_var->Get<framework::SelectedRows>();
 
     PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
@@ -86,7 +91,7 @@ void GatherOpHandle::RunImpl() {
   // write the output
   auto &out_place = out_var_handle->place_;
   auto out_scope_idx = out_var_handle->scope_idx_;
-  auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);
+  auto out_var = var_scopes.at(out_scope_idx)->FindVar(out_var_handle->name_);
 
   auto out = out_var->GetMutable<framework::SelectedRows>();
   out->set_height(pre_in.height());
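For context on the checks above: a SelectedRows variable stores a subset of the rows of a conceptually taller tensor, so gathering across devices concatenates the per-device rows while the logical height stays fixed, which is why the output receives pre_in.height(). An illustrative standalone sketch, assuming the usual SelectedRows accessors of this era (rows(), set_rows(), set_height()):

// Hypothetical illustration of gather semantics, not the handle itself.
std::vector<int64_t> out_rows;
for (const auto *in : inputs) {  // inputs: per-device SelectedRows*
  out_rows.insert(out_rows.end(), in->rows().begin(), in->rows().end());
}
out->set_rows(out_rows);                    // concatenated row indices
out->set_height(inputs.front()->height());  // logical height unchanged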

paddle/fluid/framework/details/gather_op_handle_test.cc

Lines changed: 15 additions & 6 deletions
@@ -29,6 +29,7 @@ const f::DDim kDims = {20, 20};
 struct TestGatherOpHandle {
   std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
   std::vector<Scope*> local_scopes_;
+  std::vector<Scope*> param_scopes_;
   Scope g_scope_;
   std::unique_ptr<OpHandleBase> op_handle_;
   std::vector<std::unique_ptr<VarHandleBase>> vars_;
@@ -71,9 +72,14 @@ struct TestGatherOpHandle {
   void InitGatherOp(size_t input_scope_idx) {
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
-      local_scopes_[j]->Var("out");
+      Scope& local_scope = local_scopes_.back()->NewScope();
+      *local_scopes_.back()
+           ->Var(details::kLocalExecScopeName)
+           ->GetMutable<Scope*>() = &local_scope;
+      local_scope.Var("input");
+      param_scopes_.emplace_back(&local_scope);
     }
-    local_scopes_[input_scope_idx]->Var("input");
+    param_scopes_[input_scope_idx]->Var("out");
 
     op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
     // add input
@@ -115,7 +121,8 @@ struct TestGatherOpHandle {
 
     for (size_t input_scope_idx = 0; input_scope_idx < gpu_list_.size();
          ++input_scope_idx) {
-      auto in_var = local_scopes_[input_scope_idx]->Var("input");
+      auto in_var = param_scopes_.at(input_scope_idx)->FindVar("input");
+      PADDLE_ENFORCE_NOT_NULL(in_var);
       auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
       auto value = in_selected_rows->mutable_value();
       value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
@@ -128,10 +135,11 @@ struct TestGatherOpHandle {
       value->Resize(kDims);
     }
 
-    auto out_var = local_scopes_[output_scope_idx]->Var("out");
+    auto out_var = param_scopes_.at(output_scope_idx)->FindVar("out");
+    PADDLE_ENFORCE_NOT_NULL(out_var);
     auto out_selected_rows = out_var->GetMutable<f::SelectedRows>();
 
-    auto in_var = local_scopes_[output_scope_idx]->Var("input");
+    auto in_var = param_scopes_.at(output_scope_idx)->FindVar("input");
     auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
 
     out_selected_rows->mutable_value()->ShareDataWith(
@@ -155,7 +163,8 @@ struct TestGatherOpHandle {
     f::TensorCopy(rt, cpu_place, *(ctxs_[output_scope_idx]), &result_tensor);
     float* ct = result_tensor.data<float>();
 
-    for (int64_t j = 0; j < f::product(kDims); ++j) {
+    for (int64_t j = 0;
+         j < f::product(kDims) * static_cast<int64_t>(gpu_list_.size()); ++j) {
       ASSERT_NEAR(ct[j], send_vector[j % send_vector.size()], 1e-5);
     }
   }
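The widened loop bound matches gather's concatenation semantics: the output value tensor holds f::product(kDims) * gpu_list_.size() elements, one block of send_vector per input scope, so the old bound of f::product(kDims) verified only the first block. For example (hypothetical device count), with kDims = {20, 20} and 4 places the output holds 20 * 20 * 4 = 1600 floats, and ct[j] must equal send_vector[j % 400] throughout.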

paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc

Lines changed: 5 additions & 5 deletions
@@ -43,21 +43,21 @@ void NCCLAllReduceOpHandle::RunImpl() {
   int dtype = -1;
   size_t numel = 0;
 
-  std::vector<LoDTensor> lod_tensors;
+  std::vector<const LoDTensor *> lod_tensors;
 
   for (size_t i = 0; i < local_scopes_.size(); ++i) {
     auto *s = local_scopes_[i];
     auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
 
     auto &lod_tensor = local_scope.FindVar(var_name)->Get<LoDTensor>();
-    lod_tensors.emplace_back(lod_tensor);
+    lod_tensors.emplace_back(&lod_tensor);
   }
 
-  if (platform::is_gpu_place(lod_tensors[0].place())) {
+  if (platform::is_gpu_place(lod_tensors[0]->place())) {
     std::vector<std::function<void()>> all_reduce_calls;
     for (size_t i = 0; i < local_scopes_.size(); ++i) {
       auto &p = places_[i];
-      auto &lod_tensor = lod_tensors[i];
+      auto &lod_tensor = *lod_tensors[i];
       void *buffer = const_cast<void *>(lod_tensor.data<void>());
 
       if (dtype == -1) {
@@ -93,7 +93,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
 
   // Reduce All Tensor to trg in CPU
   ReduceLoDTensor func(lod_tensors, &trg);
-  VisitDataType(ToDataType(lod_tensors[0].type()), func);
+  VisitDataType(ToDataType(lod_tensors[0]->type()), func);
 
   for (size_t i = 0; i < local_scopes_.size(); ++i) {
     auto &scope =
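Holding const LoDTensor * instead of LoDTensor values means the collection loop no longer copies each tensor object out of its scope; the vector simply aliases the scope-owned tensors, and only the dereference sites change. A condensed sketch of the collection step from the hunk above:

std::vector<const LoDTensor *> lod_tensors;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
  auto &local_scope =
      *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
  // Point at the scope-owned tensor rather than copying it.
  lod_tensors.emplace_back(&local_scope.FindVar(var_name)->Get<LoDTensor>());
}
if (platform::is_gpu_place(lod_tensors[0]->place())) {
  // ... NCCL path, one all-reduce call per place ...
}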

paddle/fluid/framework/details/reduce_and_gather.h

Lines changed: 4 additions & 4 deletions
@@ -24,23 +24,23 @@ namespace framework {
 namespace details {
 
 struct ReduceLoDTensor {
-  const std::vector<LoDTensor> &src_tensors_;
+  const std::vector<const LoDTensor *> &src_tensors_;
   LoDTensor &dst_tensor_;
 
-  ReduceLoDTensor(const std::vector<LoDTensor> &src, LoDTensor *dst)
+  ReduceLoDTensor(const std::vector<const LoDTensor *> &src, LoDTensor *dst)
       : src_tensors_(src), dst_tensor_(*dst) {}
 
   template <typename T>
   void operator()() const {
     PADDLE_ENFORCE(!src_tensors_.empty());
-    auto &t0 = src_tensors_[0];
+    auto &t0 = *src_tensors_[0];
     PADDLE_ENFORCE_NE(t0.numel(), 0);
     dst_tensor_.Resize(t0.dims());
     T *dst = dst_tensor_.mutable_data<T>(platform::CPUPlace());
     std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
 
     for (size_t i = 1; i < src_tensors_.size(); ++i) {
-      auto &t = src_tensors_[i];
+      auto &t = *src_tensors_[i];
       PADDLE_ENFORCE_EQ(t.dims(), t0.dims());
       PADDLE_ENFORCE_EQ(t.type(), t0.type());
       std::transform(t.data<T>(), t.data<T>() + t.numel(), dst, dst,
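ReduceLoDTensor is a type-dispatch functor: VisitDataType selects T from the runtime tensor type and invokes operator()<T>, which accumulates every source tensor element-wise into the destination. With the new signature it consumes the pointer vector built in nccl_all_reduce_op_handle.cc above. A condensed usage sketch matching that CPU fallback (where trg lives in the real handle is elided here):

// CPU fallback of the all-reduce: reduce all per-device tensors into
// one CPU tensor, letting VisitDataType pick the element type T.
LoDTensor trg;  // reduce target; the functor resizes and fills it
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);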
