Skip to content

Commit 9b3f48d

Browse files
authored
Merge pull request #11616 from chengduoZH/fix_parallel_exe
Enhance Parallel Executor stability
2 parents bcea248 + c99fca5 commit 9b3f48d

File tree

6 files changed

+88
-24
lines changed

6 files changed

+88
-24
lines changed

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 48 additions & 20 deletions
@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
   int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
   std::vector<std::function<void()>> broadcast_calls;
 
+  int type = platform::ToNCCLDataType(in_tensor.type());
+  size_t numel = static_cast<size_t>(in_tensor.numel());
+
   for (auto out_var_handle : out_var_handles) {
     Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
                             ->FindVar(out_var_handle->name_);
@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
       send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
       out_handle = out_var_handle;
     } else {
-      send_recv_buffer =
-          VariableVisitor::GetMutableTensor(out_var).mutable_data(
-              out_var_handle->place_);
+      send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
+                             .Resize(in_tensor.dims())
+                             .mutable_data(out_var_handle->place_);
     }
 
-    int type = platform::ToNCCLDataType(in_tensor.type());
-    size_t numel = static_cast<size_t>(in_tensor.numel());
     broadcast_calls.emplace_back(
         [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
           PADDLE_ENFORCE(platform::dynload::ncclBcast(
@@ -102,23 +103,50 @@ void BroadcastOpHandle::RunImpl() {
         });
   }
 
-  this->RunAndRecordEvent([&] {
-    {
-      platform::NCCLGroupGuard guard;
-      for (auto &call : broadcast_calls) {
-        call();
+  // FIXME(zcd): a temporary fix for some language model that has sparse
+  // parameter.
+  bool use_mutex = true;
+  if (in_var->IsType<paddle::framework::SelectedRows>()) {
+    use_mutex = false;
+  }
+  if (use_mutex) {
+    this->RunAndRecordEvent([&] {
+      {
+        platform::NCCLGroupGuard guard;
+        for (auto &call : broadcast_calls) {
+          call();
+        }
       }
-    }
 
-    if (!out_handle->IsTheSameVar(*in_var_handle)) {
-      auto out_var = var_scopes.at(in_var_handle->scope_idx_)
-                         ->FindVar(out_var_handles[0]->name_);
-      paddle::framework::TensorCopy(
-          in_tensor, in_var_handle->place_,
-          *(dev_ctxes_.at(in_var_handle->place_)),
-          &VariableVisitor::GetMutableTensor(out_var));
-    }
-  });
+      if (!out_handle->IsTheSameVar(*in_var_handle)) {
+        auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+                           ->FindVar(out_var_handles[0]->name_);
+        paddle::framework::TensorCopy(
+            in_tensor, in_var_handle->place_,
+            *(dev_ctxes_.at(in_var_handle->place_)),
+            &VariableVisitor::GetMutableTensor(out_var));
+      }
+    });
+  } else {
+    this->RunAndRecordEventNoMutex([&] {
+      {
+        platform::NCCLGroupGuard guard;
+        for (auto &call : broadcast_calls) {
+          call();
+        }
+      }
+
+      if (!out_handle->IsTheSameVar(*in_var_handle)) {
+        auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+                           ->FindVar(out_var_handles[0]->name_);
+        paddle::framework::TensorCopy(
+            in_tensor, in_var_handle->place_,
+            *(dev_ctxes_.at(in_var_handle->place_)),
+            &VariableVisitor::GetMutableTensor(out_var));
+      }
+    });
+  }
 #else
   PADDLE_THROW("CUDA is not enabled.");
 #endif

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 2 additions & 3 deletions
@@ -351,7 +351,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
       auto &prev_grad = vars.back();
       op_handle->AddInput(prev_grad.get());
 
-      auto var = new VarHandle(vars.size() - 1, i, og, p);
+      auto var = new VarHandle(vars.size(), i, og, p);
       vars.emplace_back(var);
       op_handle->AddOutput(var);
     }
@@ -447,8 +447,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
     op_handle->AddInput(prev_grad.get());
   }
   auto &vars = result->vars_[dst_dev_id][og];
-  auto var =
-      new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]);
+  auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
   return var;

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 23 additions & 0 deletions
@@ -139,6 +139,29 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 #endif
 }
 
+void OpHandleBase::RunAndRecordEventNoMutex(
+    const std::function<void()> &callback) {
+#ifdef PADDLE_WITH_CUDA
+  if (!events_.empty()) {  // Use event
+    std::function<void()> method = callback;
+
+    for (auto &p : dev_ctxes_) {
+      method = [method, p, this]() {
+        static_cast<platform::CUDADeviceContext *>(p.second)
+            ->RecordEventNoMutex(
+                events_.at(boost::get<platform::CUDAPlace>(p.first).device),
+                method);
+      };
+    }
+    method();
+  } else {
+#endif
+    callback();
+#ifdef PADDLE_WITH_CUDA
+  }
+#endif
+}
+
 void OpHandleBase::RunAndRecordEvent(platform::Place p,
                                      const std::function<void()> &callback) {
 #ifdef PADDLE_WITH_CUDA

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 4 additions & 0 deletions
@@ -85,6 +85,10 @@ class OpHandleBase {
  protected:
   void RunAndRecordEvent(const std::function<void()> &callback);
 
+  // FIXME(zcd): A temporary fix for some language model that has sparse
+  // parameter.
+  void RunAndRecordEventNoMutex(const std::function<void()> &callback);
+
   void RunAndRecordEvent(platform::Place p,
                          const std::function<void()> &callback);

paddle/fluid/framework/details/reduce_op_handle.cc

Lines changed: 3 additions & 1 deletion
@@ -80,7 +80,9 @@ void ReduceOpHandle::RunImpl() {
   }
 
   if (pre_in_var->IsType<framework::SelectedRows>()) {
-    this->RunAndRecordEvent([&] {
+    // FIXME(zcd): A temporary fix for some language model that has sparse
+    // parameter.
+    this->RunAndRecordEventNoMutex([&] {
       std::vector<const SelectedRows *> in_selected_rows =
           GetInputValues<SelectedRows>(in_var_handles, var_scopes);
       GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,

paddle/fluid/platform/device_context.h

Lines changed: 8 additions & 0 deletions
@@ -106,6 +106,14 @@ class CUDADeviceContext : public DeviceContext {
     PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
   }
 
+  // FIXME(zcd): A temporary fix for some language model that has sparse
+  // parameter.
+  template <typename Callback>
+  void RecordEventNoMutex(cudaEvent_t ev, Callback callback) {
+    callback();
+    PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
+  }
+
  private:
   CUDAPlace place_;
 

0 commit comments

Comments (0)