
Commit 16658f7

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into refine-prefetch
2 parents: 83a577e + 1d19849

27 files changed: +269 −145 lines

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -13,14 +13,14 @@ cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
 
 if(WITH_GPU)
-  nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
+  nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
       dynload_cuda variable_visitor)
-  set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
   nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
   nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
 
 else()
-  set(multi_devices_graph_builder_deps)
+  cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
+      variable_visitor)
   cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
   cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 endif()
@@ -29,7 +29,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
 cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
 
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-    scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+    scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle)
 
 
 cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)

paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc renamed to paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 26 additions & 13 deletions
@@ -13,25 +13,33 @@
 // limitations under the License.
 #include <algorithm>
 
+#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/container_cast.h"
-#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
-NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
-    const std::vector<Scope *> &local_scopes,
-    const std::vector<platform::Place> &places,
-    const platform::NCCLContextMap &ctxs)
+
+#ifdef PADDLE_WITH_CUDA
+AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places,
+                                     const platform::NCCLContextMap *ctxs)
     : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
-  for (auto &p : places_) {
-    this->dev_ctxes_[p] = nccl_ctxs_.DevCtx(p);
+  if (nccl_ctxs_) {
+    for (auto &p : places_) {
+      this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
+    }
   }
 }
+#else
+AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places)
    : local_scopes_(local_scopes), places_(places) {}
+#endif
 
-void NCCLAllReduceOpHandle::RunImpl() {
+void AllReduceOpHandle::RunImpl() {
   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
   } else {
@@ -58,6 +66,8 @@ void NCCLAllReduceOpHandle::RunImpl() {
   }
 
   if (platform::is_gpu_place(lod_tensors[0]->place())) {
+#ifdef PADDLE_WITH_CUDA
+    PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
     int dtype = -1;
     size_t numel = 0;
     std::vector<std::function<void()>> all_reduce_calls;
@@ -75,7 +85,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
      }
 
      int dev_id = boost::get<platform::CUDAPlace>(p).device;
-      auto &nccl_ctx = nccl_ctxs_.at(dev_id);
+      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
      auto stream = nccl_ctx.stream();
      auto comm = nccl_ctx.comm_;
      all_reduce_calls.emplace_back([=] {
@@ -90,22 +100,25 @@ void NCCLAllReduceOpHandle::RunImpl() {
        call();
      }
    });
+#else
+    PADDLE_THROW("Not compiled with CUDA");
+#endif
  } else {  // Special handle CPU only Operator's gradient. Like CRF
    auto &trg = *this->local_scopes_[0]
                     ->FindVar(kLocalExecScopeName)
                     ->Get<Scope *>()
-                     ->Var()
+                     ->FindVar(out_var_handles[0]->name_)
                     ->GetMutable<framework::LoDTensor>();
 
    // Reduce All Tensor to trg in CPU
    ReduceLoDTensor func(lod_tensors, &trg);
    VisitDataType(ToDataType(lod_tensors[0]->type()), func);
 
-    for (size_t i = 0; i < local_scopes_.size(); ++i) {
+    for (size_t i = 1; i < local_scopes_.size(); ++i) {
      auto &scope =
          *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
      auto &p = places_[i];
-      auto *var = scope.FindVar(in_var_handles[i]->name_);
+      auto *var = scope.FindVar(out_var_handles[i]->name_);
      auto *dev_ctx = dev_ctxes_[p];
 
      RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
@@ -118,7 +131,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
  }
 }
 
-std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
+std::string AllReduceOpHandle::Name() const { return "all_reduce"; }
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
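For context, the CPU branch above now reduces every device's gradient into the out variable of scope 0 and then copies that result into the out variable of each remaining scope, which is why the copy loop starts at i = 1. A minimal standalone sketch of that shape (illustrative only; plain vectors stand in for LoDTensor, scopes, and the tensor-copy call):

#include <cstddef>
#include <vector>

// Hedged sketch, not the Paddle code: scope 0 holds the reduce target,
// scopes 1..N-1 receive a copy of the reduced result. Assumes all buffers
// have the same size.
void CpuAllReduce(std::vector<std::vector<float>> *per_scope_grad) {
  auto &trg = (*per_scope_grad)[0];  // reduce target lives in scope 0
  for (std::size_t i = 1; i < per_scope_grad->size(); ++i) {
    for (std::size_t j = 0; j < trg.size(); ++j) {
      trg[j] += (*per_scope_grad)[i][j];  // ReduceLoDTensor-style accumulation
    }
  }
  for (std::size_t i = 1; i < per_scope_grad->size(); ++i) {
    (*per_scope_grad)[i] = trg;  // broadcast the reduced result back
  }
}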

paddle/fluid/framework/details/nccl_all_reduce_op_handle.h renamed to paddle/fluid/framework/details/all_reduce_op_handle.h

Lines changed: 14 additions & 6 deletions
@@ -20,17 +20,23 @@
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
+#ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/nccl_helper.h"
+#endif
 
 namespace paddle {
 namespace framework {
 namespace details {
 
-struct NCCLAllReduceOpHandle : public OpHandleBase {
-  NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
-                        const std::vector<platform::Place> &places,
-                        const platform::NCCLContextMap &ctxs);
-
+struct AllReduceOpHandle : public OpHandleBase {
+#ifdef PADDLE_WITH_CUDA
+  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places,
+                    const platform::NCCLContextMap *ctxs);
+#else
+  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places);
+#endif
   std::string Name() const override;
 
   // Delay and buffer nccl_all_reduce together can significantly increase
@@ -43,7 +49,9 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
  private:
  std::vector<Scope *> local_scopes_;
  std::vector<platform::Place> places_;
-  const platform::NCCLContextMap &nccl_ctxs_;
+#ifdef PADDLE_WITH_CUDA
+  const platform::NCCLContextMap *nccl_ctxs_;
+#endif
 };
 
 }  // namespace details
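A minimal usage sketch of the new interface (hypothetical local variables; the real call site is InsertAllReduceOp in multi_devices_graph_builder.cc below). The NCCL context map is now passed by pointer and may be null, in which case the handle leaves its device contexts to be filled in by the graph builder:

#ifdef PADDLE_WITH_CUDA
  // nccl_ctxs may be nullptr; the constructor only registers NCCL device
  // contexts when a context map is actually provided.
  auto *op = new details::AllReduceOpHandle(local_scopes, places, nccl_ctxs);
#else
  auto *op = new details::AllReduceOpHandle(local_scopes, places);
#endif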

paddle/fluid/framework/details/execution_strategy.h

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ namespace details {
 
 struct ExecutionStrategy {
   size_t num_threads_{0};
-  bool use_event_{true};
+  bool use_cuda_{true};
   bool allow_op_delay_{false};
   size_t num_iteration_per_drop_scope_{100};
 };

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 29 additions & 22 deletions
@@ -17,6 +17,7 @@
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
@@ -26,10 +27,6 @@
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"
 
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
-#endif
-
 namespace paddle {
 namespace framework {
 namespace details {
@@ -243,7 +240,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
          CreateReduceOp(&result, g_name, 0);
          CreateBroadcastOp(&result, g_name, 0);
        } else {
-          InsertNCCLAllReduceOp(&result, g_name);
+          InsertAllReduceOp(&result, g_name);
        }
        break;
      }
@@ -286,6 +283,19 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
   return false;
 }
 
+void MultiDevSSAGraphBuilder::SetCommunicationContext(
+    OpHandleBase *op_handle, const platform::Place &p) const {
+#ifdef PADDLE_WITH_CUDA
+  if (nccl_ctxs_ == nullptr) {
+    op_handle->SetDeviceContext(p,
+                                platform::DeviceContextPool::Instance().Get(p));
+  }
+#else
+  op_handle->SetDeviceContext(p,
+                              platform::DeviceContextPool::Instance().Get(p));
+#endif
+}
+
 void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
@@ -300,15 +310,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
   op_handle->AddInput(in);
 
   for (size_t i = 0; i < places_.size(); ++i) {
-    auto &vars = result->vars_.at(i).at(p_name);
     auto &p = places_[i];
+    SetCommunicationContext(op_handle, p);
+    auto &vars = result->vars_.at(i).at(p_name);
     auto *out_var = new VarHandle(vars.size(), i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
-#ifndef ADDLE_WITH_CUDA
-    op_handle->SetDeviceContext(p,
-                                platform::DeviceContextPool::Instance().Get(p));
-#endif
   }
 }
 
@@ -320,15 +327,19 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
   CreateOpHandleIOs(result, op, dev_id);
 }
 
-void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
-    SSAGraph *result, const std::string &og) const {
+void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
+                                                const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
   result->ops_.emplace_back(
-      new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
+      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+#else
+  result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
+#endif
   auto *op_handle = result->ops_.back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
+    SetCommunicationContext(op_handle, p);
     auto &vars = result->vars_[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
@@ -338,9 +349,6 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }
-#else
-  PADDLE_ENFORCE("Not implemented");
-#endif
 }
 
 bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
@@ -379,7 +387,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
   for (size_t i = 0; i < places_.size(); ++i) {
 // Insert ScaleCost OpHandle
 #ifdef PADDLE_WITH_CUDA
-    auto *communication_dev_ctx = nccl_ctxs_->DevCtx(places_[i]);
+    auto *communication_dev_ctx =
+        nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
+                   : platform::DeviceContextPool::Instance().Get(places_[i]);
 #else
     auto *communication_dev_ctx =
         platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
@@ -424,12 +434,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   auto *op_handle = result->ops_.back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
-    auto &vars = result->vars_[i][og];
-#ifndef PADDLE_WITH_CUDA
     auto &p = places_[i];
-    op_handle->SetDeviceContext(p,
-                                platform::DeviceContextPool::Instance().Get(p));
-#endif
+    SetCommunicationContext(op_handle, p);
+    auto &vars = result->vars_[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 4 additions & 1 deletion
@@ -100,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
       const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
       const OpDesc &op) const;
 
-  void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
+  void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
 
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
                          size_t src_dev_id) const;
@@ -111,6 +111,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
  private:
   BuildStrategy strategy_;
+
+  void SetCommunicationContext(OpHandleBase *op_handle,
+                               const platform::Place &p) const;
 };
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 3 additions & 3 deletions
@@ -39,9 +39,9 @@ OpHandleBase::~OpHandleBase() {
 #endif
 }
 
-void OpHandleBase::Run(bool use_event) {
+void OpHandleBase::Run(bool use_cuda) {
 #ifdef PADDLE_WITH_CUDA
-  if (events_.empty() && use_event) {
+  if (events_.empty() && use_cuda) {
     for (auto &p : dev_ctxes_) {
       int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
       PADDLE_ENFORCE(cudaSetDevice(dev_id));
@@ -50,7 +50,7 @@ void OpHandleBase::Run(bool use_event) {
     }
   }
 #else
-  PADDLE_ENFORCE(!use_event);
+  PADDLE_ENFORCE(!use_cuda);
 #endif
 
   RunImpl();
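The flag renamed here is the one carried by ExecutionStrategy (see execution_strategy.h above); a hypothetical call site in an executor that owns an ExecutionStrategy named strategy_ would now read:

// Hedged sketch: forward the CUDA flag from the execution strategy instead of
// the old use_event_ flag when running an op handle.
op_handle->Run(strategy_.use_cuda_);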

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class OpHandleBase {
 
   virtual std::string Name() const = 0;
 
-  void Run(bool use_event);
+  void Run(bool use_cuda);
 
   virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx);
 

paddle/fluid/framework/details/reduce_and_gather.h

Lines changed: 3 additions & 1 deletion
@@ -37,7 +37,9 @@ struct ReduceLoDTensor {
     PADDLE_ENFORCE_NE(t0.numel(), 0);
     dst_tensor_.Resize(t0.dims());
     T *dst = dst_tensor_.mutable_data<T>(platform::CPUPlace());
-    std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
+    if (dst != t0.data<T>()) {
+      std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
+    }
 
     for (size_t i = 1; i < src_tensors_.size(); ++i) {
       auto &t = *src_tensors_[i];
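The guard added above matters because the CPU all-reduce path can now hand ReduceLoDTensor a destination that aliases the first source tensor, and copying a range onto itself is both pointless and an overlapping-range copy. A self-contained sketch of the same pattern (illustrative names only, not the Paddle types):

#include <algorithm>
#include <cstddef>
#include <vector>

// Accumulate several equally sized buffers into dst, skipping the initial
// copy when dst already aliases the first source.
void ReduceBuffers(const std::vector<const float *> &srcs, float *dst,
                   std::size_t n) {
  if (dst != srcs[0]) {
    std::copy(srcs[0], srcs[0] + n, dst);
  }
  for (std::size_t i = 1; i < srcs.size(); ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      dst[j] += srcs[i][j];
    }
  }
}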

paddle/fluid/framework/details/ssa_graph_builder_factory.h

Lines changed: 5 additions & 1 deletion
@@ -40,7 +40,11 @@ class SSAGraphBuilderFactory {
         loss_var_name_(loss_var_name),
         param_names_(param_names),
         local_scopes_(local_scopes),
-        strategy_(strategy) {}
+        strategy_(strategy) {
+#ifdef PADDLE_WITH_CUDA
+    nccl_ctxs_ = nullptr;
+#endif
+  }
 
 #ifdef PADDLE_WITH_CUDA
   void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
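With nccl_ctxs_ now default-initialized to nullptr, a hedged sketch of the intended wiring (hypothetical variable names, not code from this commit) is:

#ifdef PADDLE_WITH_CUDA
  // Only install a NCCL context map when one actually exists; otherwise the
  // factory keeps nccl_ctxs_ == nullptr and the built op handles fall back to
  // the plain per-place device contexts.
  if (nccl_ctx_map != nullptr) {
    builder_factory.SetNCCLContextMap(nccl_ctx_map);
  }
#endif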
