Skip to content

Commit 3dd0182

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/clean_matmul
2 parents c6a6d87 + dce0732 commit 3dd0182

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+1327
-431
lines changed

benchmark/cluster/vgg16/vgg16_fluid.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def str2bool(v):
8080
type=str,
8181
default="",
8282
help="Comma-separated list of hostname:port pairs")
83+
parser.add_argument(
84+
"--profile", action='store_true', help="If set, profile a few steps.")
8385

8486
# Flags for defining the tf.train.Server
8587
parser.add_argument(
@@ -183,8 +185,8 @@ def train_loop(exe, trainer_prog):
183185
start_time = time.time()
184186
num_samples = 0
185187
train_pass_acc.reset()
186-
for batch_id, data in enumerate(train_reader()):
187-
ts = time.time()
188+
189+
def run_step(batch_id, data):
188190
img_data = np.array(
189191
map(lambda x: x[0].reshape(data_shape), data)).astype(
190192
"float32")
@@ -196,14 +198,28 @@ def train_loop(exe, trainer_prog):
196198
feed={"pixel": img_data,
197199
"label": y_data},
198200
fetch_list=[avg_cost, batch_acc, batch_size])
201+
return loss, acc, b_size
202+
203+
if args.profile and args.task_index == 0:
204+
# warmup.
205+
for batch_id, data in enumerate(train_reader()):
206+
if batch_id > 5: break
207+
run_step(batch_id, data)
208+
with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
209+
for batch_id, data in enumerate(train_reader()):
210+
if batch_id > 5: break
211+
run_step(batch_id, data)
212+
213+
for batch_id, data in enumerate(train_reader()):
214+
ts = time.time()
215+
loss, acc, b_size = run_step(batch_id, data)
199216
iters += 1
200217
num_samples += len(data)
201218
train_pass_acc.add(value=acc, weight=b_size)
202219
print(
203-
"Task:%d Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
204-
"Speed = %.2f img/s " % (args.task_index, pass_id, iters,
205-
loss, acc,
206-
len(data) / (time.time() - ts))
220+
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
221+
"Speed = %.2f img/s" % (pass_id, iters, loss, acc,
222+
len(data) / (time.time() - ts))
207223
) # The accuracy is the accumulation of batches, but not the current batch.
208224

209225
pass_elapsed = time.time() - start_time

doc/fluid/design/dist_train/distributed_traing_review.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,3 @@ Codistillation is a technique that tries to scale the training further. A few tr
4242
[3] Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, et al. Google’s neural machine translation system: Bridging the gap between human and machine translation.
4343

4444
[4] LARGE SCALE DISTRIBUTED NEURAL NETWORK TRAINING THROUGH ONLINE DISTILLATION
45-
46-
47-
48-

paddle/fluid/framework/block_desc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ OpDesc *BlockDesc::InsertOp(size_t index) {
143143
}
144144

145145
void BlockDesc::RemoveOp(size_t s, size_t e) {
146-
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
146+
if (ops_.begin() + s >= ops_.end() || ops_.begin() + e > ops_.end()) {
147147
return;
148148
}
149149
need_update_ = true;

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ if(WITH_GPU)
1515
dynload_cuda)
1616
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
1717
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
18+
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
19+
1820
else()
1921
set(multi_devices_graph_builder_deps)
2022
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
23+
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
2124
endif()
2225

23-
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
2426
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
2527

2628
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@
1919
namespace paddle {
2020
namespace framework {
2121
namespace details {
22-
BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
23-
const std::vector<platform::Place> &places)
24-
: local_scopes_(local_scopes), places_(places) {}
2522

2623
void BroadcastOpHandle::RunImpl() {
27-
// the input and output may have dummy var.
28-
VarHandle *in_var_handle;
24+
if (places_.size() == 1) return;
2925

26+
// The input and output may have dummy vars.
27+
VarHandle *in_var_handle;
3028
{
3129
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
3230
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
@@ -55,27 +53,97 @@ void BroadcastOpHandle::RunImpl() {
5553

5654
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
5755

58-
for (auto *out : out_var_handles) {
59-
if (*out == *in_var_handle) {
56+
// NOTE: The tensors' Place of input and output must be all on GPU or all on
57+
// CPU.
58+
for (auto *out_var_handle : out_var_handles) {
59+
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
6060
continue;
6161
}
62-
63-
auto &out_p = out->place_;
64-
auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_);
62+
auto t_out_p = out_var_handle->place_;
63+
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
64+
->FindVar(out_var_handle->name_);
6565
PADDLE_ENFORCE_NOT_NULL(out_var);
66-
PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
67-
"Places must be all on CPU or all on CUDA.");
68-
66+
if (platform::is_gpu_place(in_tensor.place())) {
67+
PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
68+
"Places of input and output must be all on GPU.");
69+
} else {
70+
t_out_p = platform::CPUPlace();
71+
}
6972
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
70-
VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
73+
VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
7174
in_tensor.type());
75+
}
76+
77+
if (platform::is_cpu_place(in_tensor.place())) {
78+
for (auto *out_var_handle : out_var_handles) {
79+
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
80+
continue;
81+
}
82+
auto &out_p = out_var_handle->place_;
83+
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
84+
->FindVar(out_var_handle->name_);
85+
86+
RunAndRecordEvent(out_p, [in_tensor, out_var] {
87+
paddle::framework::TensorCopy(
88+
in_tensor, platform::CPUPlace(),
89+
&VariableVisitor::GetMutableTensor(out_var));
90+
});
91+
}
92+
} else {
93+
#ifdef PADDLE_WITH_CUDA
94+
VarHandle *out_handle = nullptr;
95+
int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
96+
std::vector<std::function<void()>> broadcast_calls;
97+
98+
for (auto out_var_handle : out_var_handles) {
99+
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
100+
->FindVar(out_var_handle->name_);
101+
102+
int dst_id =
103+
boost::get<platform::CUDAPlace>(out_var_handle->place_).device;
104+
105+
auto &nccl_ctx = nccl_ctxs_->at(dst_id);
106+
107+
void *send_recv_buffer = nullptr;
108+
if (root_id == dst_id) {
109+
send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
110+
out_handle = out_var_handle;
111+
} else {
112+
send_recv_buffer =
113+
VariableVisitor::GetMutableTensor(out_var).mutable_data(
114+
out_var_handle->place_);
115+
}
116+
117+
int type = platform::ToNCCLDataType(in_tensor.type());
118+
size_t numel = static_cast<size_t>(in_tensor.numel());
119+
broadcast_calls.emplace_back(
120+
[send_recv_buffer, numel, type, root_id, &nccl_ctx] {
121+
PADDLE_ENFORCE(platform::dynload::ncclBcast(
122+
send_recv_buffer, numel, static_cast<ncclDataType_t>(type),
123+
root_id, nccl_ctx.comm_, nccl_ctx.stream()));
124+
});
125+
}
72126

73-
auto dev_ctx = dev_ctxes_.at(out_p);
74-
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
75-
paddle::framework::TensorCopy(
76-
in_tensor, out_p, *(dev_ctx),
77-
&VariableVisitor::GetMutableTensor(out_var));
127+
this->RunAndRecordEvent([&] {
128+
{
129+
platform::NCCLGroupGuard guard;
130+
for (auto &call : broadcast_calls) {
131+
call();
132+
}
133+
}
134+
135+
if (!out_handle->IsTheSameVar(*in_var_handle)) {
136+
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
137+
->FindVar(out_var_handles[0]->name_);
138+
paddle::framework::TensorCopy(
139+
in_tensor, in_var_handle->place_,
140+
*(dev_ctxes_.at(in_var_handle->place_)),
141+
&VariableVisitor::GetMutableTensor(out_var));
142+
}
78143
});
144+
#else
145+
PADDLE_THROW("CUDA is not enabled.");
146+
#endif
79147
}
80148
}
81149

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,32 @@
2424
#include "paddle/fluid/framework/selected_rows.h"
2525
#include "paddle/fluid/platform/device_context.h"
2626

27+
#ifdef PADDLE_WITH_CUDA
28+
#include "paddle/fluid/platform/nccl_helper.h"
29+
#endif
30+
2731
namespace paddle {
2832
namespace framework {
2933
namespace details {
3034

3135
struct BroadcastOpHandle : public OpHandleBase {
3236
public:
37+
#ifdef PADDLE_WITH_CUDA
38+
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
39+
const std::vector<platform::Place> &places,
40+
const platform::NCCLContextMap *nccl_ctxs)
41+
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
42+
if (nccl_ctxs_) {
43+
for (auto &p_ctx : nccl_ctxs_->contexts_) {
44+
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
45+
}
46+
}
47+
}
48+
#else
3349
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
34-
const std::vector<platform::Place> &places);
50+
const std::vector<platform::Place> &places)
51+
: local_scopes_(local_scopes), places_(places) {}
52+
#endif
3553

3654
std::string Name() const override;
3755

@@ -44,6 +62,9 @@ struct BroadcastOpHandle : public OpHandleBase {
4462
private:
4563
const std::vector<Scope *> &local_scopes_;
4664
const std::vector<platform::Place> &places_;
65+
#ifdef PADDLE_WITH_CUDA
66+
const platform::NCCLContextMap *nccl_ctxs_;
67+
#endif
4768
};
4869
} // namespace details
4970
} // namespace framework

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,25 @@ struct TestBroadcastOpHandle {
3535
std::unique_ptr<OpHandleBase> op_handle_;
3636
std::vector<std::unique_ptr<VarHandleBase>> vars_;
3737
std::vector<p::Place> gpu_list_;
38+
bool use_gpu_;
39+
#ifdef PADDLE_WITH_CUDA
40+
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
41+
#endif
3842

3943
void WaitAll() {
4044
for (size_t j = 0; j < ctxs_.size(); ++j) {
4145
ctxs_[j]->Wait();
4246
}
47+
#ifdef PADDLE_WITH_CUDA
48+
if (nccl_ctxs_) {
49+
nccl_ctxs_->WaitAll();
50+
}
51+
#endif
4352
}
4453

4554
void InitCtxOnGpu(bool use_gpu) {
46-
if (use_gpu) {
55+
use_gpu_ = use_gpu;
56+
if (use_gpu_) {
4757
#ifdef PADDLE_WITH_CUDA
4858
int count = p::GetCUDADeviceCount();
4959
if (count <= 1) {
@@ -57,6 +67,7 @@ struct TestBroadcastOpHandle {
5767
gpu_list_.push_back(p);
5868
ctxs_.emplace_back(new p::CUDADeviceContext(p));
5969
}
70+
nccl_ctxs_.reset(new platform::NCCLContextMap(gpu_list_));
6071
#else
6172
PADDLE_THROW("CUDA is not support.");
6273
#endif
@@ -67,6 +78,9 @@ struct TestBroadcastOpHandle {
6778
gpu_list_.push_back(p);
6879
ctxs_.emplace_back(new p::CPUDeviceContext(p));
6980
}
81+
#ifdef PADDLE_WITH_CUDA
82+
nccl_ctxs_.reset(nullptr);
83+
#endif
7084
}
7185
}
7286

@@ -82,7 +96,21 @@ struct TestBroadcastOpHandle {
8296
}
8397
param_scopes_[input_scope_idx]->Var("input");
8498

85-
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
99+
if (use_gpu_) {
100+
#ifdef PADDLE_WITH_CUDA
101+
op_handle_.reset(
102+
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
103+
#else
104+
PADDLE_THROW("CUDA is not support.");
105+
#endif
106+
} else {
107+
#ifdef PADDLE_WITH_CUDA
108+
op_handle_.reset(
109+
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
110+
#else
111+
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
112+
#endif
113+
}
86114

87115
auto* in_var_handle =
88116
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
@@ -97,7 +125,9 @@ struct TestBroadcastOpHandle {
97125
op_handle_->AddInput(dummy_var_handle);
98126

99127
for (size_t j = 0; j < gpu_list_.size(); ++j) {
100-
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
128+
if (!use_gpu_) {
129+
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
130+
}
101131
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
102132
vars_.emplace_back(out_var_handle);
103133
op_handle_->AddOutput(out_var_handle);

0 commit comments

Comments
 (0)