Commit c7b7291

Merge pull request #8758 from panyx0718/nccl
[Speed] Avoid init_nccl for every step.
2 parents 767acc6 + a4d68ed commit c7b7291

3 files changed: 54 additions, 32 deletions


paddle/fluid/operators/nccl/nccl_gpu_common.cc

Lines changed: 46 additions & 1 deletion
@@ -16,5 +16,50 @@ limitations under the License. */
 #include "paddle/fluid/platform/gpu_info.h"
 
 namespace paddle {
-namespace platform {}  // namespace platform
+namespace platform {
+namespace {
+// TODO(panyx0718): Where to destroy them.
+std::unique_ptr<std::vector<ncclComm_t>> global_comms;
+std::unique_ptr<std::unordered_map<int, int>> comm_id_map;
+bool inited = false;
+size_t last_num_gpus = -1;
+// TODO(panyx0718): Need to decide whether Paddle supports parallel
+// runs with different number GPUs. If true, current solution is not enough.
+std::mutex comm_mu;
+}
+
+int Communicator::GetCommId(int device_id) const {
+  std::lock_guard<std::mutex> guard(comm_mu);
+  return comm_id_map->at(device_id);
+}
+
+void Communicator::InitAll(const std::vector<int>& gpus) {
+  std::lock_guard<std::mutex> guard(comm_mu);
+  if (inited && last_num_gpus == gpus.size()) {
+    return;
+  }
+  last_num_gpus = gpus.size();
+  if (global_comms) {
+    for (size_t i = 0; i < global_comms->size(); ++i) {
+      // FIXME(dzh) : PADDLE_ENFORCE return void
+      dynload::ncclCommDestroy((*global_comms)[i]);
+    }
+  }
+  global_comms.reset(new std::vector<ncclComm_t>());
+  comm_id_map.reset(new std::unordered_map<int, int>());
+  global_comms->resize(gpus.size());
+  for (size_t i = 0; i < gpus.size(); ++i) {
+    (*comm_id_map)[gpus[i]] = i;
+  }
+  PADDLE_ENFORCE(
+      dynload::ncclCommInitAll(global_comms->data(), gpus.size(), gpus.data()));
+  inited = true;
+}
+
+const std::vector<ncclComm_t>& Communicator::comms() const {
+  std::lock_guard<std::mutex> guard(comm_mu);
+  return *global_comms;
+}
+
+}  // namespace platform
 }  // namespace paddle
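
The observable effect of this file is that repeated InitAll calls with an unchanged GPU count become no-ops, so the expensive ncclCommInitAll runs once rather than every training step. Below is a minimal standalone sketch of the same init-once, mutex-guarded caching idiom, with hypothetical names (FakeInitAll, g_init_calls) and a plain int standing in for ncclComm_t so it compiles without NCCL or Paddle; it illustrates the pattern, not the Paddle implementation.

// Standalone sketch of the caching pattern introduced in nccl_gpu_common.cc.
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <vector>

namespace {
std::unique_ptr<std::vector<int>> g_handles;  // stands in for the NCCL comms
bool g_inited = false;
size_t g_last_num_gpus = static_cast<size_t>(-1);
std::mutex g_mu;
int g_init_calls = 0;  // counts how often the expensive init actually runs
}  // namespace

// Analogue of Communicator::InitAll: setup runs only when the GPU set size
// changes; otherwise the cached handles are reused.
void FakeInitAll(const std::vector<int>& gpus) {
  std::lock_guard<std::mutex> guard(g_mu);
  if (g_inited && g_last_num_gpus == gpus.size()) {
    return;  // already initialized for this configuration, skip the cost
  }
  g_last_num_gpus = gpus.size();
  g_handles.reset(new std::vector<int>(gpus.begin(), gpus.end()));
  ++g_init_calls;  // a real implementation would call ncclCommInitAll here
  g_inited = true;
}

int main() {
  // Simulate many training steps that all request the same GPU set.
  for (int step = 0; step < 100; ++step) {
    FakeInitAll({0, 1, 2, 3});
  }
  assert(g_init_calls == 1);  // init ran once, not once per step
  return 0;
}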

paddle/fluid/operators/nccl/nccl_gpu_common.h

Lines changed: 3 additions & 26 deletions
@@ -29,39 +29,16 @@ limitations under the License. */
 
 namespace paddle {
 namespace platform {
-
 constexpr int kInvalidGPUId = -1;
 
 struct Communicator {
-  std::vector<ncclComm_t> comms_;
-  std::unordered_map<int, int> comm_id_map_;
-  bool inited_;
-
   Communicator() {}
 
-  int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
-
-  void InitAll(const std::vector<int>& gpus) {
-    comms_.resize(gpus.size());
-    inited_ = false;
-    for (size_t i = 0; i < gpus.size(); ++i) {
-      comm_id_map_[gpus[i]] = i;
-    }
-    PADDLE_ENFORCE(
-        dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
-    inited_ = true;
-  }
+  int GetCommId(int device_id) const;
 
-  ~Communicator() {
-    if (inited_) {
-      for (size_t i = 0; i < comms_.size(); ++i) {
-        // FIXME(dzh) : PADDLE_ENFORCE return void
-        dynload::ncclCommDestroy(comms_[i]);
-      }
-    }
-  }
+  void InitAll(const std::vector<int>& gpus);
 
-  DISABLE_COPY_AND_ASSIGN(Communicator);
+  const std::vector<ncclComm_t>& comms() const;
 };
 
 }  // namespace platform
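
After this change the header is a thin interface: the data members, destructor, and DISABLE_COPY_AND_ASSIGN are gone, and the communicator state lives in an anonymous namespace in the .cc file, so every Communicator instance shares one copy. A hedged sketch of that header/implementation idiom follows, using a hypothetical Counter type rather than the Paddle code; the comments mark where the header/implementation boundary would fall.

#include <cassert>
#include <mutex>

// --- counter.h (conceptually): declarations only, no data members ---
struct Counter {
  Counter() {}
  int Next();
};

// --- counter.cc (conceptually): state with internal linkage ---
namespace {
int value = 0;        // shared by all Counter instances
std::mutex value_mu;  // guards the shared state, as comm_mu does above
}  // namespace

int Counter::Next() {
  std::lock_guard<std::mutex> guard(value_mu);
  return ++value;
}

int main() {
  Counter a, b;
  assert(a.Next() == 1);
  assert(b.Next() == 2);  // b sees the same shared counter as a
  return 0;
}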

paddle/fluid/operators/nccl_op.cu.cc

Lines changed: 5 additions & 5 deletions
@@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
           ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
           outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
-          comm->comms_[idx], stream));
+          comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
       VLOG(1) << "gpu : "
@@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     std::hash<std::string> hasher;
     for (size_t i = 0; i < ins.size(); ++i) {
       if (root == platform::kInvalidGPUId) {
-        root = hasher(ins_names[i]) % comm->comms_.size();
+        root = hasher(ins_names[i]) % comm->comms().size();
       }
       T* recvbuffer = nullptr;
       if (root == gpu_id) {
@@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
 
       PADDLE_ENFORCE(platform::dynload::ncclReduce(
           ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
-          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
+          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
          stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
@@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
         VLOG(1) << " before ncclBcast";
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             (void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
-            root, comm->comms_[idx], stream));
+            root, comm->comms().at(idx), stream));
         VLOG(1) << " after ncclBcast";
         PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
@@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
 
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
-            NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
+            NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
         PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
         VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
