Skip to content

Commit 4738802

Browse files
committed
fix bugs
1 parent 2e32007 commit 4738802

File tree

4 files changed

+31
-1
lines changed

4 files changed

+31
-1
lines changed

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,24 @@ namespace paddle {
2020
namespace framework {
2121
namespace details {
2222

23+
#ifdef PADDLE_WITH_CUDA
24+
DataBalanceOpHandle::DataBalanceOpHandle(
25+
const std::vector<Scope *> &local_scopes,
26+
const std::vector<platform::Place> &places,
27+
const platform::NCCLContextMap *ctxs)
28+
: local_scopes_(local_scopes), places_(places) {
29+
if (ctxs) {
30+
for (auto &p : places_) {
31+
this->dev_ctxes_[p] = ctxs->DevCtx(p);
32+
}
33+
}
34+
}
35+
#else
2336
DataBalanceOpHandle::DataBalanceOpHandle(
2437
const std::vector<Scope *> &local_scopes,
2538
const std::vector<platform::Place> &places)
2639
: local_scopes_(local_scopes), places_(places) {}
40+
#endif
2741

2842
std::string DataBalanceOpHandle::Name() const { return "data balance"; }
2943

@@ -104,6 +118,7 @@ void DataBalanceOpHandle::RunImpl() {
104118
}
105119
}
106120
const auto &balance_plan = GetBalancePlan(device_sizes);
121+
107122
for (const auto &trans : balance_plan) {
108123
for (int data_idx = 0; data_idx < data_num; ++data_idx) {
109124
LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]];

paddle/fluid/framework/details/data_balance_op_handle.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,24 @@
1919
#include "paddle/fluid/framework/details/op_handle_base.h"
2020
#include "paddle/fluid/framework/lod_tensor.h"
2121
#include "paddle/fluid/framework/scope.h"
22+
#ifdef PADDLE_WITH_CUDA
23+
#include "paddle/fluid/platform/nccl_helper.h"
24+
#endif
2225

2326
namespace paddle {
2427
namespace framework {
2528
namespace details {
2629

2730
struct DataBalanceOpHandle : public OpHandleBase {
2831
public:
32+
#ifdef PADDLE_WITH_CUDA
2933
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
30-
const std::vector<platform::Place> &places);
34+
const std::vector<platform::Place> &places,
35+
const platform::NCCLContextMap *ctxs);
36+
#else
37+
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
38+
const std::vector<platform::Place> *places)
39+
#endif
3140

3241
std::string Name() const override;
3342

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
368368

369369
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
370370
SSAGraph *result, const std::vector<std::string> &datas) const {
371+
#ifdef PADDLE_WITH_CUDA
372+
result->ops_.emplace_back(
373+
new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
374+
#else
371375
result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
376+
#endif
372377
auto *op_handle = result->ops_.back().get();
373378
for (size_t i = 0; i < places_.size(); ++i) {
374379
auto &p = places_[i];

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
6060
#ifdef PADDLE_WITH_CUDA
6161
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
6262
for (auto &dev_ctx : dev_ctxes_) {
63+
PADDLE_ENFORCE_NOT_NULL(dev_ctx.second);
6364
dev_ctx.second->Wait();
6465
}
6566
} else {

0 commit comments

Comments
 (0)