Commit 7d26dd8

Authored by chengduo
Enhance Parallel Executor stable (#11634)
* Fix Parallel Exe (VarHandle's version)
* Fix broadcast
* Enhance ParallelExecutor stability
1 parent b5c5a3a commit 7d26dd8

File tree

3 files changed: +17, -12 lines changed


paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 6 additions & 5 deletions
@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
     int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
     std::vector<std::function<void()>> broadcast_calls;
 
+    int type = platform::ToNCCLDataType(in_tensor.type());
+    size_t numel = static_cast<size_t>(in_tensor.numel());
+
     for (auto out_var_handle : out_var_handles) {
       Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
                               ->FindVar(out_var_handle->name_);
@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
         send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
         out_handle = out_var_handle;
       } else {
-        send_recv_buffer =
-            VariableVisitor::GetMutableTensor(out_var).mutable_data(
-                out_var_handle->place_);
+        send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
+                               .Resize(in_tensor.dims())
+                               .mutable_data(out_var_handle->place_);
       }
 
-      int type = platform::ToNCCLDataType(in_tensor.type());
-      size_t numel = static_cast<size_t>(in_tensor.numel());
       broadcast_calls.emplace_back(
           [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
             PADDLE_ENFORCE(platform::dynload::ncclBcast(
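
This hunk does two things: the NCCL data type and element count are computed once before the per-device loop (they depend only on the source tensor), and each destination tensor is resized to the source dims before mutable_data allocates its buffer, so every receiving device gets a correctly shaped buffer. Below is a minimal, self-contained C++ sketch of the hoist-then-queue pattern; it is illustrative only, not PaddlePaddle code, and all names in it are made up.

// Illustrative sketch only -- mirrors the pattern in the diff: hoist the
// loop-invariant type/numel out of the per-device loop, queue one callable
// per device, then fire the whole group.
#include <cstdio>
#include <functional>
#include <vector>

int main() {
  const int type = 7;         // stands in for platform::ToNCCLDataType(...)
  const size_t numel = 1024;  // stands in for in_tensor.numel()

  std::vector<std::function<void()>> broadcast_calls;
  for (int device = 0; device < 4; ++device) {
    broadcast_calls.emplace_back([device, numel, type] {
      // The real handle would issue ncclBcast on this device's stream here.
      std::printf("bcast dev=%d numel=%zu type=%d\n", device, numel, type);
    });
  }
  for (auto &call : broadcast_calls) call();  // grouped launch
  return 0;
}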

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 2 additions & 3 deletions
@@ -351,7 +351,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
       auto &prev_grad = vars.back();
       op_handle->AddInput(prev_grad.get());
 
-      auto var = new VarHandle(vars.size() - 1, i, og, p);
+      auto var = new VarHandle(vars.size(), i, og, p);
       vars.emplace_back(var);
       op_handle->AddOutput(var);
     }
@@ -447,8 +447,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
     op_handle->AddInput(prev_grad.get());
   }
   auto &vars = result->vars_[dst_dev_id][og];
-  auto var =
-      new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]);
+  auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
   return var;
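
Both hunks fix the same off-by-one in the VarHandle version: the new handle is about to be appended to vars, so its version should be vars.size() taken before the emplace_back (its index after insertion), not vars.size() - 1, which would repeat the version of the handle it depends on. A tiny standalone sketch of that bookkeeping follows; ToyVarHandle is a made-up stand-in, not the real VarHandle.

// Toy illustration of the version fix; ToyVarHandle stands in for VarHandle.
#include <cassert>
#include <memory>
#include <vector>

struct ToyVarHandle {
  explicit ToyVarHandle(size_t version) : version_(version) {}
  size_t version_;
};

int main() {
  std::vector<std::unique_ptr<ToyVarHandle>> vars;
  vars.emplace_back(new ToyVarHandle(0));  // existing gradient var, version 0

  // Fixed form: version == vars.size() before the push, i.e. the new slot's index.
  auto *var = new ToyVarHandle(vars.size());
  vars.emplace_back(var);

  assert(vars.back()->version_ == vars.size() - 1);  // version matches its slot
  return 0;
}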

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 9 additions & 4 deletions
@@ -11,8 +11,8 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
 #include "paddle/fluid/framework/details/op_handle_base.h"
+#include <map>
 
 namespace paddle {
 namespace framework {
@@ -122,11 +122,16 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 #ifdef PADDLE_WITH_CUDA
   if (!events_.empty()) {  // Use event
     std::function<void()> method = callback;
-
+    // NOTE(zcd): device context must be ordered here because RecordEvent
+    // will use a mutex to ensure the safe of multi-threads.
+    std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
     for (auto &p : dev_ctxes_) {
+      ordered_ctxes.emplace(p.second, p.first);
+    }
+    for (auto &p : ordered_ctxes) {
       method = [method, p, this]() {
-        static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
-            events_.at(boost::get<platform::CUDAPlace>(p.first).device),
+        static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent(
+            events_.at(boost::get<platform::CUDAPlace>(p.second).device),
             method);
       };
     }
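
Here RunAndRecordEvent still wraps the callback once per device context, but the contexts are first copied from the unordered dev_ctxes_ map into a std::map keyed by the context pointer. As the NOTE in the diff indicates, the point is that every caller then walks the contexts in one stable order, which matters because RecordEvent takes a per-context mutex; a fixed order keeps multi-threaded callers from acquiring those locks in conflicting orders. A standalone sketch of the re-keying step follows; DevCtx and the place names are illustrative, not the Paddle API.

// Illustrative only: copy an unordered container into a std::map so that
// iteration order is the same for every caller in the process.
#include <cstdio>
#include <map>
#include <string>
#include <unordered_map>

struct DevCtx { int id; };

int main() {
  DevCtx c0{0}, c1{1}, c2{2};
  std::unordered_map<std::string, DevCtx *> dev_ctxes = {
      {"place0", &c0}, {"place1", &c1}, {"place2", &c2}};

  // Re-key by context pointer, mirroring ordered_ctxes in the commit: every
  // thread now visits the contexts in one consistent order.
  std::map<DevCtx *, std::string> ordered_ctxes;
  for (auto &p : dev_ctxes) ordered_ctxes.emplace(p.second, p.first);

  for (auto &p : ordered_ctxes)
    std::printf("record event on ctx %d (%s)\n", p.first->id, p.second.c_str());
  return 0;
}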
