Skip to content

Commit c99fca5

Browse files
committed
Add No Mutex
1 parent 13de723 commit c99fca5

File tree

5 files changed

+80
-16
lines changed

5 files changed

+80
-16
lines changed

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -103,23 +103,50 @@ void BroadcastOpHandle::RunImpl() {
103103
});
104104
}
105105

106-
this->RunAndRecordEvent([&] {
107-
{
108-
platform::NCCLGroupGuard guard;
109-
for (auto &call : broadcast_calls) {
110-
call();
106+
// FIXME(zcd): a temporary fix for some language model that has sparse
107+
// parameter.
108+
bool use_mutex = true;
109+
if (in_var->IsType<paddle::framework::SelectedRows>()) {
110+
use_mutex = false;
111+
}
112+
if (use_mutex) {
113+
this->RunAndRecordEvent([&] {
114+
{
115+
platform::NCCLGroupGuard guard;
116+
for (auto &call : broadcast_calls) {
117+
call();
118+
}
111119
}
112-
}
113120

114-
if (!out_handle->IsTheSameVar(*in_var_handle)) {
115-
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
116-
->FindVar(out_var_handles[0]->name_);
117-
paddle::framework::TensorCopy(
118-
in_tensor, in_var_handle->place_,
119-
*(dev_ctxes_.at(in_var_handle->place_)),
120-
&VariableVisitor::GetMutableTensor(out_var));
121-
}
122-
});
121+
if (!out_handle->IsTheSameVar(*in_var_handle)) {
122+
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
123+
->FindVar(out_var_handles[0]->name_);
124+
paddle::framework::TensorCopy(
125+
in_tensor, in_var_handle->place_,
126+
*(dev_ctxes_.at(in_var_handle->place_)),
127+
&VariableVisitor::GetMutableTensor(out_var));
128+
}
129+
});
130+
} else {
131+
this->RunAndRecordEventNoMutex([&] {
132+
{
133+
platform::NCCLGroupGuard guard;
134+
for (auto &call : broadcast_calls) {
135+
call();
136+
}
137+
}
138+
139+
if (!out_handle->IsTheSameVar(*in_var_handle)) {
140+
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
141+
->FindVar(out_var_handles[0]->name_);
142+
paddle::framework::TensorCopy(
143+
in_tensor, in_var_handle->place_,
144+
*(dev_ctxes_.at(in_var_handle->place_)),
145+
&VariableVisitor::GetMutableTensor(out_var));
146+
}
147+
});
148+
}
149+
123150
#else
124151
PADDLE_THROW("CUDA is not enabled.");
125152
#endif

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,29 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
139139
#endif
140140
}
141141

142+
// Runs `callback` and, when this op handle has CUDA events, records one event
// per entry of dev_ctxes_ afterwards. Recording goes through
// CUDADeviceContext::RecordEventNoMutex rather than a locking variant.
// NOTE(review): presumably the plain RunAndRecordEvent overload serialises on
// a per-context mutex and this variant exists to avoid that -- confirm.
//
// Without PADDLE_WITH_CUDA the preprocessor leaves only the bare `callback()`
// call in the body.
void OpHandleBase::RunAndRecordEventNoMutex(
143+
const std::function<void()> &callback) {
144+
#ifdef PADDLE_WITH_CUDA
145+
if (!events_.empty()) { // Use event
146+
// Start with the user callback as the innermost callable.
std::function<void()> method = callback;
147+
148+
// Wrap once per device context. Each layer captures the previous `method`
// by value and hands it to RecordEventNoMutex together with the event
// looked up by that context's CUDA device id, so a single call to the
// outermost wrapper reaches `callback` exactly once through the chain.
for (auto &p : dev_ctxes_) {
149+
method = [method, p, this]() {
150+
static_cast<platform::CUDADeviceContext *>(p.second)
151+
->RecordEventNoMutex(
152+
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
153+
method);
154+
};
155+
}
156+
// Invoke the fully wrapped chain.
method();
157+
} else {
158+
#endif
159+
// No events to record (or CUDA disabled): just run the callback.
callback();
160+
#ifdef PADDLE_WITH_CUDA
161+
}
162+
#endif
163+
}
164+
142165
void OpHandleBase::RunAndRecordEvent(platform::Place p,
143166
const std::function<void()> &callback) {
144167
#ifdef PADDLE_WITH_CUDA

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ class OpHandleBase {
8585
protected:
8686
void RunAndRecordEvent(const std::function<void()> &callback);
8787

88+
// FIXME(zcd): A temporary fix for some language model that has sparse
89+
// parameter.
90+
void RunAndRecordEventNoMutex(const std::function<void()> &callback);
91+
8892
void RunAndRecordEvent(platform::Place p,
8993
const std::function<void()> &callback);
9094

paddle/fluid/framework/details/reduce_op_handle.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ void ReduceOpHandle::RunImpl() {
8080
}
8181

8282
if (pre_in_var->IsType<framework::SelectedRows>()) {
83-
this->RunAndRecordEvent([&] {
83+
// FIXME(zcd): A temporary fix for some language model that has sparse
84+
// parameter.
85+
this->RunAndRecordEventNoMutex([&] {
8486
std::vector<const SelectedRows *> in_selected_rows =
8587
GetInputValues<SelectedRows>(in_var_handles, var_scopes);
8688
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,

paddle/fluid/platform/device_context.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ class CUDADeviceContext : public DeviceContext {
106106
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
107107
}
108108

109+
// FIXME(zcd): A temporary fix for some language model that has sparse
110+
// parameter.
111+
// Invokes `callback` first and then records `ev` on stream_ via
// cudaEventRecord (result checked with PADDLE_ENFORCE). No lock is taken in
// this method. NOTE(review): callers are presumably responsible for any
// synchronisation the locking RecordEvent would have provided -- confirm.
template <typename Callback>
112+
void RecordEventNoMutex(cudaEvent_t ev, Callback callback) {
113+
callback();
114+
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
115+
}
116+
109117
private:
110118
CUDAPlace place_;
111119

0 commit comments

Comments
 (0)