
Commit 99acf1d

Author: chengduo
Merge pull request #10351 from chengduoZH/feature/update_sparse_parameter

Feature/update sparse parameter

2 parents 8f8a476 + 881e063, commit 99acf1d

15 files changed: +404 additions, -109 deletions

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 3 additions & 1 deletion

@@ -15,12 +15,14 @@ if(WITH_GPU)
       dynload_cuda)
   set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
   nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
+  nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
+
 else()
   set(multi_devices_graph_builder_deps)
   cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
+  cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 endif()
 
-cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 87 additions & 19 deletions

@@ -19,14 +19,12 @@
 namespace paddle {
 namespace framework {
 namespace details {
-BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
-                                     const std::vector<platform::Place> &places)
-    : local_scopes_(local_scopes), places_(places) {}
 
 void BroadcastOpHandle::RunImpl() {
-  // the input and output may have dummy var.
-  VarHandle *in_var_handle;
+  if (places_.size() == 1) return;
 
+  // The input and output may have dummy vars.
+  VarHandle *in_var_handle;
   {
     auto in_var_handles = DynamicCast<VarHandle>(inputs_);
     PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
@@ -55,27 +53,97 @@ void BroadcastOpHandle::RunImpl() {
 
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  for (auto *out : out_var_handles) {
-    if (*out == *in_var_handle) {
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  for (auto *out_var_handle : out_var_handles) {
+    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
       continue;
     }
-
-    auto &out_p = out->place_;
-    auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_);
+    auto t_out_p = out_var_handle->place_;
+    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                        ->FindVar(out_var_handle->name_);
     PADDLE_ENFORCE_NOT_NULL(out_var);
-    PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
-                      "Places must be all on CPU or all on CUDA.");
-
+    if (platform::is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }
     VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
                                                             in_tensor.type());
+  }
+
+  if (platform::is_cpu_place(in_tensor.place())) {
+    for (auto *out_var_handle : out_var_handles) {
+      if (out_var_handle->IsTheSameVar(*in_var_handle)) {
+        continue;
+      }
+      auto &out_p = out_var_handle->place_;
+      auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                          ->FindVar(out_var_handle->name_);
+
+      RunAndRecordEvent(out_p, [in_tensor, out_var] {
+        paddle::framework::TensorCopy(
+            in_tensor, platform::CPUPlace(),
+            &VariableVisitor::GetMutableTensor(out_var));
+      });
+    }
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    VarHandle *out_handle = nullptr;
+    int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
+    std::vector<std::function<void()>> broadcast_calls;
+
+    for (auto out_var_handle : out_var_handles) {
+      Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                              ->FindVar(out_var_handle->name_);
+
+      int dst_id =
+          boost::get<platform::CUDAPlace>(out_var_handle->place_).device;
+
+      auto &nccl_ctx = nccl_ctxs_->at(dst_id);
+
+      void *send_recv_buffer = nullptr;
+      if (root_id == dst_id) {
+        send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
+        out_handle = out_var_handle;
+      } else {
+        send_recv_buffer =
+            VariableVisitor::GetMutableTensor(out_var).mutable_data(
+                out_var_handle->place_);
+      }
+
+      int type = platform::ToNCCLDataType(in_tensor.type());
+      size_t numel = static_cast<size_t>(in_tensor.numel());
+      broadcast_calls.emplace_back(
+          [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
+            PADDLE_ENFORCE(platform::dynload::ncclBcast(
+                send_recv_buffer, numel, static_cast<ncclDataType_t>(type),
+                root_id, nccl_ctx.comm_, nccl_ctx.stream()));
+          });
+    }
 
-    auto dev_ctx = dev_ctxes_.at(out_p);
-    RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
-      paddle::framework::TensorCopy(
-          in_tensor, out_p, *(dev_ctx),
-          &VariableVisitor::GetMutableTensor(out_var));
+    this->RunAndRecordEvent([&] {
+      {
+        platform::NCCLGroupGuard guard;
+        for (auto &call : broadcast_calls) {
+          call();
+        }
+      }
+
+      if (!out_handle->IsTheSameVar(*in_var_handle)) {
+        auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+                           ->FindVar(out_var_handles[0]->name_);
+        paddle::framework::TensorCopy(
+            in_tensor, in_var_handle->place_,
+            *(dev_ctxes_.at(in_var_handle->place_)),
+            &VariableVisitor::GetMutableTensor(out_var));
+      }
     });
+#else
+    PADDLE_THROW("CUDA is not enabled.");
+#endif
   }
 }
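The GPU branch above stages one closure per destination device and only then issues them all inside a platform::NCCLGroupGuard, so every ncclBcast lands in a single NCCL group and no device's call blocks waiting for the others to launch. Below is a minimal, dependency-free C++ sketch of that stage-then-issue pattern; GroupGuard and the printed messages are illustrative stand-ins for the guard (which brackets the calls with ncclGroupStart()/ncclGroupEnd()), not Paddle or NCCL APIs.

#include <functional>
#include <iostream>
#include <vector>

struct GroupGuard {  // stand-in for platform::NCCLGroupGuard
  GroupGuard() { std::cout << "group start\n"; }  // ~ ncclGroupStart()
  ~GroupGuard() { std::cout << "group end\n"; }   // ~ ncclGroupEnd()
};

int main() {
  int root_id = 0;
  std::vector<std::function<void()>> broadcast_calls;

  // Stage one call per destination device; in the real handle the root
  // reuses the input buffer while the others pass fresh output buffers.
  for (int dst_id = 0; dst_id < 3; ++dst_id) {
    broadcast_calls.emplace_back([root_id, dst_id] {
      std::cout << "bcast root=" << root_id << " dst=" << dst_id << "\n";
    });
  }

  // Issue every staged call inside one group scope so the collective is
  // launched on all devices before any single call can block.
  {
    GroupGuard guard;
    for (auto &call : broadcast_calls) call();
  }
  return 0;
}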

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 22 additions & 1 deletion

@@ -24,14 +24,32 @@
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/platform/device_context.h"
 
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
 namespace paddle {
 namespace framework {
 namespace details {
 
 struct BroadcastOpHandle : public OpHandleBase {
  public:
+#ifdef PADDLE_WITH_CUDA
+  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places,
+                    const platform::NCCLContextMap *nccl_ctxs)
+      : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
+    if (nccl_ctxs_) {
+      for (auto &p_ctx : nccl_ctxs_->contexts_) {
+        dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
+      }
+    }
+  }
+#else
   BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
-                    const std::vector<platform::Place> &places);
+                    const std::vector<platform::Place> &places)
+      : local_scopes_(local_scopes), places_(places) {}
+#endif
 
   std::string Name() const override;
 
@@ -44,6 +62,9 @@ struct BroadcastOpHandle : public OpHandleBase {
 private:
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
+#ifdef PADDLE_WITH_CUDA
+  const platform::NCCLContextMap *nccl_ctxs_;
+#endif
 };
 }  // namespace details
 }  // namespace framework
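The header now compiles exactly one of two constructors: with CUDA, the handle receives a platform::NCCLContextMap and pre-populates dev_ctxes_ with each CUDAPlace's device context; without CUDA, the old two-argument constructor is defined inline. A compile-time sketch of that dual-constructor shape follows; WITH_GPU_STUB, CommContext, and HandleStub are placeholders invented for illustration, not Paddle code.

#include <string>
#include <vector>

// #define WITH_GPU_STUB  // toggle to exercise the other branch

struct CommContext {};  // stand-in for platform::NCCLContextMap

struct HandleStub {
#ifdef WITH_GPU_STUB
  // GPU build: the extra communicator argument is part of the signature.
  HandleStub(const std::vector<std::string> &scopes, const CommContext *comms)
      : scopes_(scopes), comms_(comms) {}
#else
  // CPU-only build: call sites never see the communicator parameter.
  explicit HandleStub(const std::vector<std::string> &scopes)
      : scopes_(scopes) {}
#endif

 private:
  std::vector<std::string> scopes_;
#ifdef WITH_GPU_STUB
  const CommContext *comms_;
#endif
};

int main() {
  std::vector<std::string> scopes{"s0", "s1"};
#ifdef WITH_GPU_STUB
  CommContext comms;
  HandleStub h(scopes, &comms);
#else
  HandleStub h(scopes);
#endif
  (void)h;
  return 0;
}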

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 33 additions & 3 deletions

@@ -35,15 +35,25 @@ struct TestBroadcastOpHandle {
   std::unique_ptr<OpHandleBase> op_handle_;
   std::vector<std::unique_ptr<VarHandleBase>> vars_;
   std::vector<p::Place> gpu_list_;
+  bool use_gpu_;
+#ifdef PADDLE_WITH_CUDA
+  std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
+#endif
 
   void WaitAll() {
     for (size_t j = 0; j < ctxs_.size(); ++j) {
       ctxs_[j]->Wait();
     }
+#ifdef PADDLE_WITH_CUDA
+    if (nccl_ctxs_) {
+      nccl_ctxs_->WaitAll();
+    }
+#endif
   }
 
   void InitCtxOnGpu(bool use_gpu) {
-    if (use_gpu) {
+    use_gpu_ = use_gpu;
+    if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
       int count = p::GetCUDADeviceCount();
       if (count <= 1) {
@@ -57,6 +67,7 @@ struct TestBroadcastOpHandle {
         gpu_list_.push_back(p);
         ctxs_.emplace_back(new p::CUDADeviceContext(p));
       }
+      nccl_ctxs_.reset(new platform::NCCLContextMap(gpu_list_));
 #else
       PADDLE_THROW("CUDA is not support.");
 #endif
@@ -67,6 +78,9 @@ struct TestBroadcastOpHandle {
         gpu_list_.push_back(p);
         ctxs_.emplace_back(new p::CPUDeviceContext(p));
       }
+#ifdef PADDLE_WITH_CUDA
+      nccl_ctxs_.reset(nullptr);
+#endif
     }
   }
 
@@ -82,7 +96,21 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");
 
-    op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
+    if (use_gpu_) {
+#ifdef PADDLE_WITH_CUDA
+      op_handle_.reset(
+          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+#else
+      PADDLE_THROW("CUDA is not support.");
+#endif
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      op_handle_.reset(
+          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+#else
+      op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
+#endif
+    }
 
     auto* in_var_handle =
         new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
@@ -97,7 +125,9 @@ struct TestBroadcastOpHandle {
     op_handle_->AddInput(dummy_var_handle);
 
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
-      op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
+      if (!use_gpu_) {
+        op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
+      }
       VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);

paddle/fluid/framework/details/gather_op_handle.cc

Lines changed: 31 additions & 37 deletions

@@ -25,6 +25,7 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
     : local_scopes_(local_scopes), places_(places) {}
 
 void GatherOpHandle::RunImpl() {
+  if (places_.size() == 1) return;
   // the input and output may have dummy var.
   auto in_var_handles = DynamicCast<VarHandle>(inputs_);
 
@@ -35,7 +36,6 @@ void GatherOpHandle::RunImpl() {
   VarHandle *out_var_handle;
   {
     auto out_var_handles = DynamicCast<VarHandle>(outputs_);
-
     PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
                       "The number of output should be one.");
     out_var_handle = out_var_handles.front();
@@ -50,68 +50,62 @@ void GatherOpHandle::RunImpl() {
   auto pre_in_var =
       var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(pre_in_var);
+
   PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
                  "Currently, gather_op only can gather SelectedRows.");
 
-  auto pre_place = in_0_handle->place_;
-  PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
-                    "The place of input and output should be the same.");
-
   // Wait input done, this Wait is asynchronous operation
   WaitInputVarGenerated(in_var_handles);
 
+  auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
   std::vector<int64_t> out_rows;
   std::vector<Tensor> in_tensors;
-  std::vector<platform::Place> in_places;
 
-  auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
-  // gather the inputs
+  // Gather the inputs
   for (auto *in_handle : in_var_handles) {
-    auto in_p = in_handle->place_;
-    in_places.push_back(in_p);
-    PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
-                      "Places must be all on CPU or all on CUDA.");
     auto *in_var =
         var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
-    auto &in_sr = in_var->Get<framework::SelectedRows>();
+    PADDLE_ENFORCE_NOT_NULL(in_var);
+    VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var);
 
-    PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
-                      "The type of input is not consistent.");
-    PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
-                      "The height of inputs is not consistent.");
-    PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
-                      "The dims of inputs is not consistent.");
+    auto &in_sr_value = in_var->Get<framework::SelectedRows>();
 
-    auto &in_sr_rows = in_sr.rows();
+    auto &in_sr_rows = in_sr_value.rows();
     out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
-
-    in_tensors.emplace_back(in_sr.value());
+    in_tensors.emplace_back(in_sr_value.value());
   }
 
-  // write the output
-  auto &out_place = out_var_handle->place_;
-  auto out_scope_idx = out_var_handle->scope_idx_;
-  auto out_var = var_scopes.at(out_scope_idx)->FindVar(out_var_handle->name_);
+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
+  platform::Place t_out_p = out_var_handle->place_;
+  if (platform::is_gpu_place(pre_in_value.place())) {
+    PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                   "Places of input and output must be all on GPU.");
+  } else {
+    t_out_p = platform::CPUPlace();
+  }
 
-  auto out = out_var->GetMutable<framework::SelectedRows>();
-  out->set_height(pre_in.height());
-  out->set_rows(out_rows);
+  auto out_var =
+      var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
+  PADDLE_ENFORCE_NOT_NULL(out_var);
+  auto out_value = out_var->GetMutable<framework::SelectedRows>();
+  out_value->set_height(pre_in_value.height());
+  out_value->set_rows(out_rows);
   size_t rows = out_rows.size();
-  DDim out_dim = pre_in.GetCompleteDims();
+  DDim out_dim = pre_in_value.GetCompleteDims();
   out_dim[0] = static_cast<int64_t>(rows);
-  out->mutable_value()->Resize(out_dim);
-  out->mutable_value()->mutable_data(out_place, pre_in.value().type());
-  Tensor *out_tensor = out->mutable_value();
+  out_value->mutable_value()->Resize(out_dim).mutable_data(
+      t_out_p, pre_in_value.value().type());
+  Tensor *out_tensor = out_value->mutable_value();
 
   // copy
-  auto dev_ctx = dev_ctxes_[out_place];
-  RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
+  auto dev_ctx = dev_ctxes_[out_var_handle->place_];
+  RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx,
+                                             t_out_p] {
     int s = 0, e = 0;
     for (size_t j = 0; j < in_tensors.size(); ++j) {
       e += in_tensors[j].dims()[0];
      auto sub_out = out_tensor->Slice(s, e);
-      paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
-                                    &sub_out);
+      paddle::framework::TensorCopy(in_tensors[j], t_out_p, *dev_ctx, &sub_out);
      s = e;
    }
  });
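The closing lambda in gather concatenates the SelectedRows values row-wise: each input occupies the next [s, e) row range of the output tensor, which Slice(s, e) exposes for TensorCopy. Below is a host-side sketch of that same loop under the assumption that std::vector can stand in for Tensor; nothing here is a Paddle API.

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  // Each inner vector plays the role of one input's value tensor.
  std::vector<std::vector<int>> in_tensors = {{1, 2}, {3}, {4, 5, 6}};

  // Size the output to the total number of gathered rows.
  size_t total = 0;
  for (auto &t : in_tensors) total += t.size();
  std::vector<int> out_tensor(total);

  // Copy each input into its [s, e) slice of the output, advancing the
  // offsets exactly as the Slice(s, e) + TensorCopy pair does above.
  size_t s = 0, e = 0;
  for (size_t j = 0; j < in_tensors.size(); ++j) {
    e += in_tensors[j].size();
    std::copy(in_tensors[j].begin(), in_tensors[j].end(),
              out_tensor.begin() + s);
    s = e;
  }

  for (int v : out_tensor) std::cout << v << ' ';  // prints: 1 2 3 4 5 6
  std::cout << '\n';
  return 0;
}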
